Skip to content
Snippets Groups Projects
Commit f9329217 authored by Yuchen Zeng's avatar Yuchen Zeng
Browse files

Support client streaming

parent 47f1f9e1
No related branches found
No related tags found
No related merge requests found
...@@ -37,8 +37,6 @@ ...@@ -37,8 +37,6 @@
#include <grpc++/channel.h> #include <grpc++/channel.h>
#include <grpc++/client_context.h> #include <grpc++/client_context.h>
#include <grpc++/completion_queue.h>
#include <grpc++/generic/generic_stub.h>
#include <grpc++/support/byte_buffer.h> #include <grpc++/support/byte_buffer.h>
#include <grpc/grpc.h> #include <grpc/grpc.h>
#include <grpc/slice.h> #include <grpc/slice.h>
...@@ -50,49 +48,61 @@ namespace { ...@@ -50,49 +48,61 @@ namespace {
void* tag(int i) { return (void*)(intptr_t)i; } void* tag(int i) { return (void*)(intptr_t)i; }
} // namespace } // namespace
enum CliCall::CallStatus : intptr_t { CREATE, PROCESS, FINISH };
Status CliCall::Call(std::shared_ptr<grpc::Channel> channel, Status CliCall::Call(std::shared_ptr<grpc::Channel> channel,
const grpc::string& method, const grpc::string& request, const grpc::string& method, const grpc::string& request,
grpc::string* response, grpc::string* response,
const OutgoingMetadataContainer& metadata, const OutgoingMetadataContainer& metadata,
IncomingMetadataContainer* server_initial_metadata, IncomingMetadataContainer* server_initial_metadata,
IncomingMetadataContainer* server_trailing_metadata) { IncomingMetadataContainer* server_trailing_metadata) {
std::unique_ptr<grpc::GenericStub> stub(new grpc::GenericStub(channel)); CliCall call(channel, method, metadata);
grpc::ClientContext ctx; call.Write(request);
call.WritesDone();
call.Read(response, server_initial_metadata);
return call.Finish(server_trailing_metadata);
}
CliCall::CliCall(std::shared_ptr<grpc::Channel> channel,
const grpc::string& method,
const OutgoingMetadataContainer& metadata)
: stub_(new grpc::GenericStub(channel)) {
if (!metadata.empty()) { if (!metadata.empty()) {
for (OutgoingMetadataContainer::const_iterator iter = metadata.begin(); for (OutgoingMetadataContainer::const_iterator iter = metadata.begin();
iter != metadata.end(); ++iter) { iter != metadata.end(); ++iter) {
ctx.AddMetadata(iter->first, iter->second); ctx_.AddMetadata(iter->first, iter->second);
} }
} }
grpc::CompletionQueue cq; call_ = stub_->Call(&ctx_, method, &cq_, tag(1));
std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call(
stub->Call(&ctx, method, &cq, tag(1)));
void* got_tag; void* got_tag;
bool ok; bool ok;
cq.Next(&got_tag, &ok); cq_.Next(&got_tag, &ok);
GPR_ASSERT(ok); GPR_ASSERT(ok);
}
void CliCall::Write(const grpc::string& request) {
void* got_tag;
bool ok;
grpc_slice s = grpc_slice_from_copied_string(request.c_str()); grpc_slice s = grpc_slice_from_copied_string(request.c_str());
grpc::Slice req_slice(s, grpc::Slice::STEAL_REF); grpc::Slice req_slice(s, grpc::Slice::STEAL_REF);
grpc::ByteBuffer send_buffer(&req_slice, 1); grpc::ByteBuffer send_buffer(&req_slice, 1);
call->Write(send_buffer, tag(2)); call_->Write(send_buffer, tag(2));
cq.Next(&got_tag, &ok); cq_.Next(&got_tag, &ok);
GPR_ASSERT(ok);
call->WritesDone(tag(3));
cq.Next(&got_tag, &ok);
GPR_ASSERT(ok); GPR_ASSERT(ok);
}
void CliCall::Read(grpc::string* response,
IncomingMetadataContainer* server_initial_metadata) {
void* got_tag;
bool ok;
grpc::ByteBuffer recv_buffer; grpc::ByteBuffer recv_buffer;
call->Read(&recv_buffer, tag(4)); call_->Read(&recv_buffer, tag(4));
cq.Next(&got_tag, &ok); cq_.Next(&got_tag, &ok);
if (!ok) { if (!ok) {
std::cout << "Failed to read response." << std::endl; fprintf(stderr, "Failed to read response.");
} } else {
grpc::Status status;
call->Finish(&status, tag(5));
cq.Next(&got_tag, &ok);
GPR_ASSERT(ok);
if (status.ok()) {
std::vector<grpc::Slice> slices; std::vector<grpc::Slice> slices;
(void)recv_buffer.Dump(&slices); (void)recv_buffer.Dump(&slices);
...@@ -101,10 +111,33 @@ Status CliCall::Call(std::shared_ptr<grpc::Channel> channel, ...@@ -101,10 +111,33 @@ Status CliCall::Call(std::shared_ptr<grpc::Channel> channel,
response->append(reinterpret_cast<const char*>(slices[i].begin()), response->append(reinterpret_cast<const char*>(slices[i].begin()),
slices[i].size()); slices[i].size());
} }
if (server_initial_metadata) {
*server_initial_metadata = ctx_.GetServerInitialMetadata();
}
}
}
void CliCall::WritesDone() {
void* got_tag;
bool ok;
call_->WritesDone(tag(3));
cq_.Next(&got_tag, &ok);
GPR_ASSERT(ok);
}
Status CliCall::Finish(IncomingMetadataContainer* server_trailing_metadata) {
void* got_tag;
bool ok;
grpc::Status status;
call_->Finish(&status, tag(5));
cq_.Next(&got_tag, &ok);
GPR_ASSERT(ok);
if (server_trailing_metadata) {
*server_trailing_metadata = ctx_.GetServerTrailingMetadata();
} }
*server_initial_metadata = ctx.GetServerInitialMetadata();
*server_trailing_metadata = ctx.GetServerTrailingMetadata();
return status; return status;
} }
......
...@@ -37,10 +37,15 @@ ...@@ -37,10 +37,15 @@
#include <map> #include <map>
#include <grpc++/channel.h> #include <grpc++/channel.h>
#include <grpc++/completion_queue.h>
#include <grpc++/generic/generic_stub.h>
#include <grpc++/support/status.h> #include <grpc++/support/status.h>
#include <grpc++/support/string_ref.h> #include <grpc++/support/string_ref.h>
namespace grpc { namespace grpc {
class ClientContext;
namespace testing { namespace testing {
class CliCall final { class CliCall final {
...@@ -48,12 +53,32 @@ class CliCall final { ...@@ -48,12 +53,32 @@ class CliCall final {
typedef std::multimap<grpc::string, grpc::string> OutgoingMetadataContainer; typedef std::multimap<grpc::string, grpc::string> OutgoingMetadataContainer;
typedef std::multimap<grpc::string_ref, grpc::string_ref> typedef std::multimap<grpc::string_ref, grpc::string_ref>
IncomingMetadataContainer; IncomingMetadataContainer;
CliCall(std::shared_ptr<grpc::Channel> channel, const grpc::string& method,
const OutgoingMetadataContainer& metadata);
static Status Call(std::shared_ptr<grpc::Channel> channel, static Status Call(std::shared_ptr<grpc::Channel> channel,
const grpc::string& method, const grpc::string& request, const grpc::string& method, const grpc::string& request,
grpc::string* response, grpc::string* response,
const OutgoingMetadataContainer& metadata, const OutgoingMetadataContainer& metadata,
IncomingMetadataContainer* server_initial_metadata, IncomingMetadataContainer* server_initial_metadata,
IncomingMetadataContainer* server_trailing_metadata); IncomingMetadataContainer* server_trailing_metadata);
void Write(const grpc::string& request);
void WritesDone();
void Read(grpc::string* response,
IncomingMetadataContainer* server_initial_metadata);
Status Finish(IncomingMetadataContainer* server_trailing_metadata);
private:
enum CallStatus : intptr_t;
std::unique_ptr<grpc::GenericStub> stub_;
grpc::ClientContext ctx_;
std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;
grpc::CompletionQueue cq_;
}; };
} // namespace testing } // namespace testing
......
...@@ -419,79 +419,180 @@ bool GrpcTool::CallMethod(int argc, const char** argv, ...@@ -419,79 +419,180 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
std::unique_ptr<grpc::testing::ProtoFileParser> parser; std::unique_ptr<grpc::testing::ProtoFileParser> parser;
grpc::string serialized_request_proto; grpc::string serialized_request_proto;
if (argc == 3) { std::shared_ptr<grpc::Channel> channel =
request_text = argv[2]; FLAGS_remotedb
if (!FLAGS_infile.empty()) { ? grpc::CreateChannel(server_address, cred.GetCredentials())
fprintf(stderr, "warning: request given in argv, ignoring --infile\n"); : nullptr;
parser.reset(new grpc::testing::ProtoFileParser(channel, FLAGS_proto_path,
FLAGS_protofiles));
grpc::string formated_method_name =
parser->GetFormatedMethodName(method_name);
if (parser->HasError()) {
return false;
}
if (parser->IsStreaming(method_name, true /* is_request */)) {
fprintf(stderr, "streaming request\n");
std::istream* input_stream;
std::ifstream input_file;
if (argc == 3) {
request_text = argv[2];
if (!FLAGS_infile.empty()) {
fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
}
} }
} else { // std::stringstream input_stream;
std::stringstream input_stream;
std::multimap<grpc::string, grpc::string> client_metadata;
ParseMetadataFlag(&client_metadata);
PrintMetadata(client_metadata, "Sending client initial metadata:");
CliCall call(channel, formated_method_name, client_metadata);
if (FLAGS_infile.empty()) { if (FLAGS_infile.empty()) {
if (isatty(STDIN_FILENO)) { if (isatty(STDIN_FILENO)) {
fprintf(stderr, "reading request message from stdin...\n"); fprintf(stderr, "reading request message from stdin...\n");
} }
input_stream << std::cin.rdbuf(); input_stream = &std::cin;
// rdbuf = std::cin.rdbuf();
// input_stream.rdbuf(std::cin.rdbuf());
// input_stream << std::cin.rdbuf();
} else { } else {
std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary); input_file.open(FLAGS_infile, std::ios::in | std::ios::binary);
input_stream << input_file.rdbuf(); // rdbuf = input_file.rdbuf();
// input_stream.rdbuf(input_file.rdbuf());
input_stream = &input_file;
// input_file.close();
}
// request_text = input_stream.str();
std::stringstream request_ss;
grpc::string line;
while (!input_stream->eof() && getline(*input_stream, line)) {
if (line.length() == 0) {
// request_text = request_ss.str();
if (FLAGS_binary_input) {
serialized_request_proto = request_ss.str();
} else {
serialized_request_proto = parser->GetSerializedProtoFromMethod(
method_name, request_ss.str(), true /* is_request */);
if (parser->HasError()) {
return false;
}
}
request_ss.str(grpc::string());
request_ss.clear();
grpc::string response_text = parser->GetTextFormatFromMethod(
method_name, serialized_request_proto, true /* is_request */);
call.Write(serialized_request_proto);
fprintf(stderr, "%s", response_text.c_str());
} else {
request_ss << line << ' ';
}
}
if (input_file.is_open()) {
input_file.close(); input_file.close();
} }
request_text = input_stream.str();
}
std::shared_ptr<grpc::Channel> channel = call.WritesDone();
grpc::CreateChannel(server_address, cred.GetCredentials());
if (!FLAGS_binary_input || !FLAGS_binary_output) { grpc::string serialized_response_proto;
parser.reset( std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr, server_trailing_metadata;
FLAGS_proto_path, FLAGS_protofiles)); call.Read(&serialized_response_proto, &server_initial_metadata);
if (parser->HasError()) { Status status = call.Finish(&server_trailing_metadata);
return false;
PrintMetadata(server_initial_metadata,
"Received initial metadata from server:");
PrintMetadata(server_trailing_metadata,
"Received trailing metadata from server:");
if (status.ok()) {
fprintf(stderr, "Stream RPC succeeded with OK status\n");
if (FLAGS_binary_output) {
output_ss << serialized_response_proto;
} else {
grpc::string response_text = parser->GetTextFormatFromMethod(
method_name, serialized_response_proto, false /* is_request */);
if (parser->HasError()) {
return false;
}
output_ss << "Response: \n " << response_text << std::endl;
}
} else {
fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
status.error_code(), status.error_message().c_str());
} }
}
if (FLAGS_binary_input) { } else { // parser->IsStreaming(method_name, true /* is_request */)
serialized_request_proto = request_text; if (argc == 3) {
formatted_method_name = method_name; request_text = argv[2];
} else { if (!FLAGS_infile.empty()) {
formatted_method_name = parser->GetFormattedMethodName(method_name); fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
serialized_request_proto = parser->GetSerializedProtoFromMethod( }
method_name, request_text, true /* is_request */); } else {
if (parser->HasError()) { std::stringstream input_stream;
return false; if (FLAGS_infile.empty()) {
if (isatty(STDIN_FILENO)) {
fprintf(stderr, "reading request message from stdin...\n");
}
input_stream << std::cin.rdbuf();
} else {
std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary);
input_stream << input_file.rdbuf();
input_file.close();
}
request_text = input_stream.str();
} }
}
fprintf(stderr, "connecting to %s\n", server_address.c_str()); if (FLAGS_binary_input) {
serialized_request_proto = request_text;
grpc::string serialized_response_proto;
std::multimap<grpc::string, grpc::string> client_metadata;
std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
server_trailing_metadata;
ParseMetadataFlag(&client_metadata);
PrintMetadata(client_metadata, "Sending client initial metadata:");
grpc::Status status = grpc::testing::CliCall::Call(
channel, formatted_method_name, serialized_request_proto,
&serialized_response_proto, client_metadata, &server_initial_metadata,
&server_trailing_metadata);
PrintMetadata(server_initial_metadata,
"Received initial metadata from server:");
PrintMetadata(server_trailing_metadata,
"Received trailing metadata from server:");
if (status.ok()) {
fprintf(stderr, "Rpc succeeded with OK status\n");
if (FLAGS_binary_output) {
output_ss << serialized_response_proto;
} else { } else {
grpc::string response_text = parser->GetTextFormatFromMethod( serialized_request_proto = parser->GetSerializedProtoFromMethod(
method_name, serialized_response_proto, false /* is_request */); method_name, request_text, true /* is_request */);
if (parser->HasError()) { if (parser->HasError()) {
return false; return false;
} }
output_ss << "Response: \n " << response_text << std::endl;
} }
} else { fprintf(stderr, "connecting to %s\n", server_address.c_str());
fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
status.error_code(), status.error_message().c_str()); grpc::string serialized_response_proto;
std::multimap<grpc::string, grpc::string> client_metadata;
std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
server_trailing_metadata;
ParseMetadataFlag(&client_metadata);
PrintMetadata(client_metadata, "Sending client initial metadata:");
grpc::Status status = grpc::testing::CliCall::Call(
channel, formated_method_name, serialized_request_proto,
&serialized_response_proto, client_metadata, &server_initial_metadata,
&server_trailing_metadata);
PrintMetadata(server_initial_metadata,
"Received initial metadata from server:");
PrintMetadata(server_trailing_metadata,
"Received trailing metadata from server:");
if (status.ok()) {
fprintf(stderr, "Rpc succeeded with OK status\n");
if (FLAGS_binary_output) {
output_ss << serialized_response_proto;
} else {
grpc::string response_text = parser->GetTextFormatFromMethod(
method_name, serialized_response_proto, false /* is_request */);
if (parser->HasError()) {
return false;
}
output_ss << "Response: \n " << response_text << std::endl;
}
} else {
fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
status.error_code(), status.error_message().c_str());
}
} }
return callback(output_ss.str()); return callback(output_ss.str());
......
...@@ -144,12 +144,18 @@ ProtoFileParser::~ProtoFileParser() {} ...@@ -144,12 +144,18 @@ ProtoFileParser::~ProtoFileParser() {}
grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) { grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) {
has_error_ = false; has_error_ = false;
if (known_methods_.find(method) != known_methods_.end()) {
return known_methods_[method];
}
const protobuf::MethodDescriptor* method_descriptor = nullptr; const protobuf::MethodDescriptor* method_descriptor = nullptr;
for (auto it = service_desc_list_.begin(); it != service_desc_list_.end(); for (auto it = service_desc_list_.begin(); it != service_desc_list_.end();
it++) { it++) {
const auto* service_desc = *it; const auto* service_desc = *it;
for (int j = 0; j < service_desc->method_count(); j++) { for (int j = 0; j < service_desc->method_count(); j++) {
const auto* method_desc = service_desc->method(j); const auto* method_desc = service_desc->method(j);
fprintf(stderr, "%s\n", method_desc->full_name().c_str());
if (MethodNameMatch(method_desc->full_name(), method)) { if (MethodNameMatch(method_desc->full_name(), method)) {
if (method_descriptor) { if (method_descriptor) {
std::ostringstream error_stream; std::ostringstream error_stream;
...@@ -169,6 +175,8 @@ grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) { ...@@ -169,6 +175,8 @@ grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) {
return ""; return "";
} }
known_methods_[method] = method_descriptor->full_name();
return method_descriptor->full_name(); return method_descriptor->full_name();
} }
...@@ -205,6 +213,25 @@ grpc::string ProtoFileParser::GetMessageTypeFromMethod( ...@@ -205,6 +213,25 @@ grpc::string ProtoFileParser::GetMessageTypeFromMethod(
: method_desc->output_type()->full_name(); : method_desc->output_type()->full_name();
} }
bool ProtoFileParser::IsStreaming(const grpc::string& method, bool is_request) {
has_error_ = false;
grpc::string full_method_name = GetFullMethodName(method);
if (has_error_) {
return false;
}
const protobuf::MethodDescriptor* method_desc =
desc_pool_->FindMethodByName(full_method_name);
if (!method_desc) {
LogError("Method not found");
return false;
}
return is_request ? method_desc->client_streaming()
: method_desc->server_streaming();
}
grpc::string ProtoFileParser::GetSerializedProtoFromMethod( grpc::string ProtoFileParser::GetSerializedProtoFromMethod(
const grpc::string& method, const grpc::string& text_format_proto, const grpc::string& method, const grpc::string& text_format_proto,
bool is_request) { bool is_request) {
......
...@@ -84,6 +84,8 @@ class ProtoFileParser { ...@@ -84,6 +84,8 @@ class ProtoFileParser {
const grpc::string& message_type_name, const grpc::string& message_type_name,
const grpc::string& serialized_proto); const grpc::string& serialized_proto);
bool IsStreaming(const grpc::string& method, bool is_request);
bool HasError() const { return has_error_; } bool HasError() const { return has_error_; }
void LogError(const grpc::string& error_msg); void LogError(const grpc::string& error_msg);
...@@ -104,6 +106,7 @@ class ProtoFileParser { ...@@ -104,6 +106,7 @@ class ProtoFileParser {
std::unique_ptr<protobuf::DynamicMessageFactory> dynamic_factory_; std::unique_ptr<protobuf::DynamicMessageFactory> dynamic_factory_;
std::unique_ptr<grpc::protobuf::Message> request_prototype_; std::unique_ptr<grpc::protobuf::Message> request_prototype_;
std::unique_ptr<grpc::protobuf::Message> response_prototype_; std::unique_ptr<grpc::protobuf::Message> response_prototype_;
std::unordered_map<grpc::string, grpc::string> known_methods_;
std::vector<const protobuf::ServiceDescriptor*> service_desc_list_; std::vector<const protobuf::ServiceDescriptor*> service_desc_list_;
}; };
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment