From f9329217b1ce334a19b8720f70a17ce8f5d5db23 Mon Sep 17 00:00:00 2001 From: Yuchen Zeng <zyc@google.com> Date: Fri, 9 Sep 2016 14:27:12 -0700 Subject: [PATCH] Support client streaming --- test/cpp/util/cli_call.cc | 85 ++++++++---- test/cpp/util/cli_call.h | 25 ++++ test/cpp/util/grpc_tool.cc | 213 +++++++++++++++++++++-------- test/cpp/util/proto_file_parser.cc | 27 ++++ test/cpp/util/proto_file_parser.h | 3 + 5 files changed, 271 insertions(+), 82 deletions(-) diff --git a/test/cpp/util/cli_call.cc b/test/cpp/util/cli_call.cc index a02a8b2ee2..d9232ec4b6 100644 --- a/test/cpp/util/cli_call.cc +++ b/test/cpp/util/cli_call.cc @@ -37,8 +37,6 @@ #include <grpc++/channel.h> #include <grpc++/client_context.h> -#include <grpc++/completion_queue.h> -#include <grpc++/generic/generic_stub.h> #include <grpc++/support/byte_buffer.h> #include <grpc/grpc.h> #include <grpc/slice.h> @@ -50,49 +48,61 @@ namespace { void* tag(int i) { return (void*)(intptr_t)i; } } // namespace +enum CliCall::CallStatus : intptr_t { CREATE, PROCESS, FINISH }; + Status CliCall::Call(std::shared_ptr<grpc::Channel> channel, const grpc::string& method, const grpc::string& request, grpc::string* response, const OutgoingMetadataContainer& metadata, IncomingMetadataContainer* server_initial_metadata, IncomingMetadataContainer* server_trailing_metadata) { - std::unique_ptr<grpc::GenericStub> stub(new grpc::GenericStub(channel)); - grpc::ClientContext ctx; + CliCall call(channel, method, metadata); + call.Write(request); + call.WritesDone(); + call.Read(response, server_initial_metadata); + return call.Finish(server_trailing_metadata); +} + +CliCall::CliCall(std::shared_ptr<grpc::Channel> channel, + const grpc::string& method, + const OutgoingMetadataContainer& metadata) + : stub_(new grpc::GenericStub(channel)) { if (!metadata.empty()) { for (OutgoingMetadataContainer::const_iterator iter = metadata.begin(); iter != metadata.end(); ++iter) { - ctx.AddMetadata(iter->first, iter->second); + ctx_.AddMetadata(iter->first, iter->second); } } - grpc::CompletionQueue cq; - std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call( - stub->Call(&ctx, method, &cq, tag(1))); + call_ = stub_->Call(&ctx_, method, &cq_, tag(1)); void* got_tag; bool ok; - cq.Next(&got_tag, &ok); + cq_.Next(&got_tag, &ok); GPR_ASSERT(ok); +} + +void CliCall::Write(const grpc::string& request) { + void* got_tag; + bool ok; grpc_slice s = grpc_slice_from_copied_string(request.c_str()); grpc::Slice req_slice(s, grpc::Slice::STEAL_REF); grpc::ByteBuffer send_buffer(&req_slice, 1); - call->Write(send_buffer, tag(2)); - cq.Next(&got_tag, &ok); - GPR_ASSERT(ok); - call->WritesDone(tag(3)); - cq.Next(&got_tag, &ok); + call_->Write(send_buffer, tag(2)); + cq_.Next(&got_tag, &ok); GPR_ASSERT(ok); +} + +void CliCall::Read(grpc::string* response, + IncomingMetadataContainer* server_initial_metadata) { + void* got_tag; + bool ok; + grpc::ByteBuffer recv_buffer; - call->Read(&recv_buffer, tag(4)); - cq.Next(&got_tag, &ok); + call_->Read(&recv_buffer, tag(4)); + cq_.Next(&got_tag, &ok); if (!ok) { - std::cout << "Failed to read response." << std::endl; - } - grpc::Status status; - call->Finish(&status, tag(5)); - cq.Next(&got_tag, &ok); - GPR_ASSERT(ok); - - if (status.ok()) { + fprintf(stderr, "Failed to read response."); + } else { std::vector<grpc::Slice> slices; (void)recv_buffer.Dump(&slices); @@ -101,10 +111,33 @@ Status CliCall::Call(std::shared_ptr<grpc::Channel> channel, response->append(reinterpret_cast<const char*>(slices[i].begin()), slices[i].size()); } + if (server_initial_metadata) { + *server_initial_metadata = ctx_.GetServerInitialMetadata(); + } + } +} + +void CliCall::WritesDone() { + void* got_tag; + bool ok; + + call_->WritesDone(tag(3)); + cq_.Next(&got_tag, &ok); + GPR_ASSERT(ok); +} + +Status CliCall::Finish(IncomingMetadataContainer* server_trailing_metadata) { + void* got_tag; + bool ok; + grpc::Status status; + + call_->Finish(&status, tag(5)); + cq_.Next(&got_tag, &ok); + GPR_ASSERT(ok); + if (server_trailing_metadata) { + *server_trailing_metadata = ctx_.GetServerTrailingMetadata(); } - *server_initial_metadata = ctx.GetServerInitialMetadata(); - *server_trailing_metadata = ctx.GetServerTrailingMetadata(); return status; } diff --git a/test/cpp/util/cli_call.h b/test/cpp/util/cli_call.h index 65da86bd4e..3f328309a7 100644 --- a/test/cpp/util/cli_call.h +++ b/test/cpp/util/cli_call.h @@ -37,10 +37,15 @@ #include <map> #include <grpc++/channel.h> +#include <grpc++/completion_queue.h> +#include <grpc++/generic/generic_stub.h> #include <grpc++/support/status.h> #include <grpc++/support/string_ref.h> namespace grpc { + +class ClientContext; + namespace testing { class CliCall final { @@ -48,12 +53,32 @@ class CliCall final { typedef std::multimap<grpc::string, grpc::string> OutgoingMetadataContainer; typedef std::multimap<grpc::string_ref, grpc::string_ref> IncomingMetadataContainer; + + CliCall(std::shared_ptr<grpc::Channel> channel, const grpc::string& method, + const OutgoingMetadataContainer& metadata); + static Status Call(std::shared_ptr<grpc::Channel> channel, const grpc::string& method, const grpc::string& request, grpc::string* response, const OutgoingMetadataContainer& metadata, IncomingMetadataContainer* server_initial_metadata, IncomingMetadataContainer* server_trailing_metadata); + + void Write(const grpc::string& request); + + void WritesDone(); + + void Read(grpc::string* response, + IncomingMetadataContainer* server_initial_metadata); + + Status Finish(IncomingMetadataContainer* server_trailing_metadata); + + private: + enum CallStatus : intptr_t; + std::unique_ptr<grpc::GenericStub> stub_; + grpc::ClientContext ctx_; + std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_; + grpc::CompletionQueue cq_; }; } // namespace testing diff --git a/test/cpp/util/grpc_tool.cc b/test/cpp/util/grpc_tool.cc index b9900ca1b7..8082d6027b 100644 --- a/test/cpp/util/grpc_tool.cc +++ b/test/cpp/util/grpc_tool.cc @@ -419,79 +419,180 @@ bool GrpcTool::CallMethod(int argc, const char** argv, std::unique_ptr<grpc::testing::ProtoFileParser> parser; grpc::string serialized_request_proto; - if (argc == 3) { - request_text = argv[2]; - if (!FLAGS_infile.empty()) { - fprintf(stderr, "warning: request given in argv, ignoring --infile\n"); + std::shared_ptr<grpc::Channel> channel = + FLAGS_remotedb + ? grpc::CreateChannel(server_address, cred.GetCredentials()) + : nullptr; + + parser.reset(new grpc::testing::ProtoFileParser(channel, FLAGS_proto_path, + FLAGS_protofiles)); + + grpc::string formated_method_name = + parser->GetFormatedMethodName(method_name); + + if (parser->HasError()) { + return false; + } + + if (parser->IsStreaming(method_name, true /* is_request */)) { + fprintf(stderr, "streaming request\n"); + std::istream* input_stream; + std::ifstream input_file; + + if (argc == 3) { + request_text = argv[2]; + if (!FLAGS_infile.empty()) { + fprintf(stderr, "warning: request given in argv, ignoring --infile\n"); + } } - } else { - std::stringstream input_stream; + // std::stringstream input_stream; + + std::multimap<grpc::string, grpc::string> client_metadata; + ParseMetadataFlag(&client_metadata); + PrintMetadata(client_metadata, "Sending client initial metadata:"); + + CliCall call(channel, formated_method_name, client_metadata); + if (FLAGS_infile.empty()) { if (isatty(STDIN_FILENO)) { fprintf(stderr, "reading request message from stdin...\n"); } - input_stream << std::cin.rdbuf(); + input_stream = &std::cin; + // rdbuf = std::cin.rdbuf(); + // input_stream.rdbuf(std::cin.rdbuf()); + // input_stream << std::cin.rdbuf(); + } else { - std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary); - input_stream << input_file.rdbuf(); + input_file.open(FLAGS_infile, std::ios::in | std::ios::binary); + // rdbuf = input_file.rdbuf(); + // input_stream.rdbuf(input_file.rdbuf()); + input_stream = &input_file; + // input_file.close(); + } + // request_text = input_stream.str(); + + std::stringstream request_ss; + grpc::string line; + while (!input_stream->eof() && getline(*input_stream, line)) { + if (line.length() == 0) { + // request_text = request_ss.str(); + if (FLAGS_binary_input) { + serialized_request_proto = request_ss.str(); + } else { + serialized_request_proto = parser->GetSerializedProtoFromMethod( + method_name, request_ss.str(), true /* is_request */); + if (parser->HasError()) { + return false; + } + } + + request_ss.str(grpc::string()); + request_ss.clear(); + + grpc::string response_text = parser->GetTextFormatFromMethod( + method_name, serialized_request_proto, true /* is_request */); + call.Write(serialized_request_proto); + + fprintf(stderr, "%s", response_text.c_str()); + } else { + request_ss << line << ' '; + } + } + if (input_file.is_open()) { input_file.close(); } - request_text = input_stream.str(); - } - std::shared_ptr<grpc::Channel> channel = - grpc::CreateChannel(server_address, cred.GetCredentials()); - if (!FLAGS_binary_input || !FLAGS_binary_output) { - parser.reset( - new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr, - FLAGS_proto_path, FLAGS_protofiles)); - if (parser->HasError()) { - return false; + call.WritesDone(); + + grpc::string serialized_response_proto; + std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata, + server_trailing_metadata; + call.Read(&serialized_response_proto, &server_initial_metadata); + Status status = call.Finish(&server_trailing_metadata); + + PrintMetadata(server_initial_metadata, + "Received initial metadata from server:"); + PrintMetadata(server_trailing_metadata, + "Received trailing metadata from server:"); + if (status.ok()) { + fprintf(stderr, "Stream RPC succeeded with OK status\n"); + if (FLAGS_binary_output) { + output_ss << serialized_response_proto; + } else { + grpc::string response_text = parser->GetTextFormatFromMethod( + method_name, serialized_response_proto, false /* is_request */); + if (parser->HasError()) { + return false; + } + output_ss << "Response: \n " << response_text << std::endl; + } + } else { + fprintf(stderr, "Rpc failed with status code %d, error message: %s\n", + status.error_code(), status.error_message().c_str()); } - } - if (FLAGS_binary_input) { - serialized_request_proto = request_text; - formatted_method_name = method_name; - } else { - formatted_method_name = parser->GetFormattedMethodName(method_name); - serialized_request_proto = parser->GetSerializedProtoFromMethod( - method_name, request_text, true /* is_request */); - if (parser->HasError()) { - return false; + } else { // parser->IsStreaming(method_name, true /* is_request */) + if (argc == 3) { + request_text = argv[2]; + if (!FLAGS_infile.empty()) { + fprintf(stderr, "warning: request given in argv, ignoring --infile\n"); + } + } else { + std::stringstream input_stream; + if (FLAGS_infile.empty()) { + if (isatty(STDIN_FILENO)) { + fprintf(stderr, "reading request message from stdin...\n"); + } + input_stream << std::cin.rdbuf(); + } else { + std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary); + input_stream << input_file.rdbuf(); + input_file.close(); + } + request_text = input_stream.str(); } - } - fprintf(stderr, "connecting to %s\n", server_address.c_str()); - - grpc::string serialized_response_proto; - std::multimap<grpc::string, grpc::string> client_metadata; - std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata, - server_trailing_metadata; - ParseMetadataFlag(&client_metadata); - PrintMetadata(client_metadata, "Sending client initial metadata:"); - grpc::Status status = grpc::testing::CliCall::Call( - channel, formatted_method_name, serialized_request_proto, - &serialized_response_proto, client_metadata, &server_initial_metadata, - &server_trailing_metadata); - PrintMetadata(server_initial_metadata, - "Received initial metadata from server:"); - PrintMetadata(server_trailing_metadata, - "Received trailing metadata from server:"); - if (status.ok()) { - fprintf(stderr, "Rpc succeeded with OK status\n"); - if (FLAGS_binary_output) { - output_ss << serialized_response_proto; + + if (FLAGS_binary_input) { + serialized_request_proto = request_text; } else { - grpc::string response_text = parser->GetTextFormatFromMethod( - method_name, serialized_response_proto, false /* is_request */); + serialized_request_proto = parser->GetSerializedProtoFromMethod( + method_name, request_text, true /* is_request */); if (parser->HasError()) { return false; } - output_ss << "Response: \n " << response_text << std::endl; } - } else { - fprintf(stderr, "Rpc failed with status code %d, error message: %s\n", - status.error_code(), status.error_message().c_str()); + fprintf(stderr, "connecting to %s\n", server_address.c_str()); + + grpc::string serialized_response_proto; + std::multimap<grpc::string, grpc::string> client_metadata; + std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata, + server_trailing_metadata; + ParseMetadataFlag(&client_metadata); + PrintMetadata(client_metadata, "Sending client initial metadata:"); + grpc::Status status = grpc::testing::CliCall::Call( + channel, formated_method_name, serialized_request_proto, + &serialized_response_proto, client_metadata, &server_initial_metadata, + &server_trailing_metadata); + PrintMetadata(server_initial_metadata, + "Received initial metadata from server:"); + PrintMetadata(server_trailing_metadata, + "Received trailing metadata from server:"); + if (status.ok()) { + fprintf(stderr, "Rpc succeeded with OK status\n"); + if (FLAGS_binary_output) { + output_ss << serialized_response_proto; + } else { + grpc::string response_text = parser->GetTextFormatFromMethod( + method_name, serialized_response_proto, false /* is_request */); + if (parser->HasError()) { + return false; + } + output_ss << "Response: \n " << response_text << std::endl; + } + } else { + fprintf(stderr, "Rpc failed with status code %d, error message: %s\n", + status.error_code(), status.error_message().c_str()); + } } return callback(output_ss.str()); diff --git a/test/cpp/util/proto_file_parser.cc b/test/cpp/util/proto_file_parser.cc index bc8a6083f4..41bf88cc14 100644 --- a/test/cpp/util/proto_file_parser.cc +++ b/test/cpp/util/proto_file_parser.cc @@ -144,12 +144,18 @@ ProtoFileParser::~ProtoFileParser() {} grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) { has_error_ = false; + + if (known_methods_.find(method) != known_methods_.end()) { + return known_methods_[method]; + } + const protobuf::MethodDescriptor* method_descriptor = nullptr; for (auto it = service_desc_list_.begin(); it != service_desc_list_.end(); it++) { const auto* service_desc = *it; for (int j = 0; j < service_desc->method_count(); j++) { const auto* method_desc = service_desc->method(j); + fprintf(stderr, "%s\n", method_desc->full_name().c_str()); if (MethodNameMatch(method_desc->full_name(), method)) { if (method_descriptor) { std::ostringstream error_stream; @@ -169,6 +175,8 @@ grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) { return ""; } + known_methods_[method] = method_descriptor->full_name(); + return method_descriptor->full_name(); } @@ -205,6 +213,25 @@ grpc::string ProtoFileParser::GetMessageTypeFromMethod( : method_desc->output_type()->full_name(); } +bool ProtoFileParser::IsStreaming(const grpc::string& method, bool is_request) { + has_error_ = false; + + grpc::string full_method_name = GetFullMethodName(method); + if (has_error_) { + return false; + } + + const protobuf::MethodDescriptor* method_desc = + desc_pool_->FindMethodByName(full_method_name); + if (!method_desc) { + LogError("Method not found"); + return false; + } + + return is_request ? method_desc->client_streaming() + : method_desc->server_streaming(); +} + grpc::string ProtoFileParser::GetSerializedProtoFromMethod( const grpc::string& method, const grpc::string& text_format_proto, bool is_request) { diff --git a/test/cpp/util/proto_file_parser.h b/test/cpp/util/proto_file_parser.h index c1070a37b5..23d311ef8f 100644 --- a/test/cpp/util/proto_file_parser.h +++ b/test/cpp/util/proto_file_parser.h @@ -84,6 +84,8 @@ class ProtoFileParser { const grpc::string& message_type_name, const grpc::string& serialized_proto); + bool IsStreaming(const grpc::string& method, bool is_request); + bool HasError() const { return has_error_; } void LogError(const grpc::string& error_msg); @@ -104,6 +106,7 @@ class ProtoFileParser { std::unique_ptr<protobuf::DynamicMessageFactory> dynamic_factory_; std::unique_ptr<grpc::protobuf::Message> request_prototype_; std::unique_ptr<grpc::protobuf::Message> response_prototype_; + std::unordered_map<grpc::string, grpc::string> known_methods_; std::vector<const protobuf::ServiceDescriptor*> service_desc_list_; }; -- GitLab