From f9329217b1ce334a19b8720f70a17ce8f5d5db23 Mon Sep 17 00:00:00 2001
From: Yuchen Zeng <zyc@google.com>
Date: Fri, 9 Sep 2016 14:27:12 -0700
Subject: [PATCH] Support client streaming

---
 test/cpp/util/cli_call.cc          |  85 ++++++++----
 test/cpp/util/cli_call.h           |  25 ++++
 test/cpp/util/grpc_tool.cc         | 213 +++++++++++++++++++++--------
 test/cpp/util/proto_file_parser.cc |  27 ++++
 test/cpp/util/proto_file_parser.h  |   3 +
 5 files changed, 271 insertions(+), 82 deletions(-)

diff --git a/test/cpp/util/cli_call.cc b/test/cpp/util/cli_call.cc
index a02a8b2ee2..d9232ec4b6 100644
--- a/test/cpp/util/cli_call.cc
+++ b/test/cpp/util/cli_call.cc
@@ -37,8 +37,6 @@
 
 #include <grpc++/channel.h>
 #include <grpc++/client_context.h>
-#include <grpc++/completion_queue.h>
-#include <grpc++/generic/generic_stub.h>
 #include <grpc++/support/byte_buffer.h>
 #include <grpc/grpc.h>
 #include <grpc/slice.h>
@@ -50,49 +48,61 @@ namespace {
 void* tag(int i) { return (void*)(intptr_t)i; }
 }  // namespace
 
+enum CliCall::CallStatus : intptr_t { CREATE, PROCESS, FINISH };
+
 Status CliCall::Call(std::shared_ptr<grpc::Channel> channel,
                      const grpc::string& method, const grpc::string& request,
                      grpc::string* response,
                      const OutgoingMetadataContainer& metadata,
                      IncomingMetadataContainer* server_initial_metadata,
                      IncomingMetadataContainer* server_trailing_metadata) {
-  std::unique_ptr<grpc::GenericStub> stub(new grpc::GenericStub(channel));
-  grpc::ClientContext ctx;
+  CliCall call(channel, method, metadata);
+  call.Write(request);
+  call.WritesDone();
+  call.Read(response, server_initial_metadata);
+  return call.Finish(server_trailing_metadata);
+}
+
+CliCall::CliCall(std::shared_ptr<grpc::Channel> channel,
+                 const grpc::string& method,
+                 const OutgoingMetadataContainer& metadata)
+    : stub_(new grpc::GenericStub(channel)) {
   if (!metadata.empty()) {
     for (OutgoingMetadataContainer::const_iterator iter = metadata.begin();
          iter != metadata.end(); ++iter) {
-      ctx.AddMetadata(iter->first, iter->second);
+      ctx_.AddMetadata(iter->first, iter->second);
     }
   }
-  grpc::CompletionQueue cq;
-  std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call(
-      stub->Call(&ctx, method, &cq, tag(1)));
+  call_ = stub_->Call(&ctx_, method, &cq_, tag(1));
   void* got_tag;
   bool ok;
-  cq.Next(&got_tag, &ok);
+  cq_.Next(&got_tag, &ok);
   GPR_ASSERT(ok);
+}
+
+void CliCall::Write(const grpc::string& request) {
+  void* got_tag;
+  bool ok;
 
   grpc_slice s = grpc_slice_from_copied_string(request.c_str());
   grpc::Slice req_slice(s, grpc::Slice::STEAL_REF);
   grpc::ByteBuffer send_buffer(&req_slice, 1);
-  call->Write(send_buffer, tag(2));
-  cq.Next(&got_tag, &ok);
-  GPR_ASSERT(ok);
-  call->WritesDone(tag(3));
-  cq.Next(&got_tag, &ok);
+  call_->Write(send_buffer, tag(2));
+  cq_.Next(&got_tag, &ok);
   GPR_ASSERT(ok);
+}
+
+void CliCall::Read(grpc::string* response,
+                   IncomingMetadataContainer* server_initial_metadata) {
+  void* got_tag;
+  bool ok;
+
   grpc::ByteBuffer recv_buffer;
-  call->Read(&recv_buffer, tag(4));
-  cq.Next(&got_tag, &ok);
+  call_->Read(&recv_buffer, tag(4));
+  cq_.Next(&got_tag, &ok);
   if (!ok) {
-    std::cout << "Failed to read response." << std::endl;
-  }
-  grpc::Status status;
-  call->Finish(&status, tag(5));
-  cq.Next(&got_tag, &ok);
-  GPR_ASSERT(ok);
-
-  if (status.ok()) {
+    fprintf(stderr, "Failed to read response.");
+  } else {
     std::vector<grpc::Slice> slices;
     (void)recv_buffer.Dump(&slices);
 
@@ -101,10 +111,33 @@ Status CliCall::Call(std::shared_ptr<grpc::Channel> channel,
       response->append(reinterpret_cast<const char*>(slices[i].begin()),
                        slices[i].size());
     }
+    if (server_initial_metadata) {
+      *server_initial_metadata = ctx_.GetServerInitialMetadata();
+    }
+  }
+}
+
+void CliCall::WritesDone() {
+  void* got_tag;
+  bool ok;
+
+  call_->WritesDone(tag(3));
+  cq_.Next(&got_tag, &ok);
+  GPR_ASSERT(ok);
+}
+
+Status CliCall::Finish(IncomingMetadataContainer* server_trailing_metadata) {
+  void* got_tag;
+  bool ok;
+  grpc::Status status;
+
+  call_->Finish(&status, tag(5));
+  cq_.Next(&got_tag, &ok);
+  GPR_ASSERT(ok);
+  if (server_trailing_metadata) {
+    *server_trailing_metadata = ctx_.GetServerTrailingMetadata();
   }
 
-  *server_initial_metadata = ctx.GetServerInitialMetadata();
-  *server_trailing_metadata = ctx.GetServerTrailingMetadata();
   return status;
 }
 
diff --git a/test/cpp/util/cli_call.h b/test/cpp/util/cli_call.h
index 65da86bd4e..3f328309a7 100644
--- a/test/cpp/util/cli_call.h
+++ b/test/cpp/util/cli_call.h
@@ -37,10 +37,15 @@
 #include <map>
 
 #include <grpc++/channel.h>
+#include <grpc++/completion_queue.h>
+#include <grpc++/generic/generic_stub.h>
 #include <grpc++/support/status.h>
 #include <grpc++/support/string_ref.h>
 
 namespace grpc {
+
+class ClientContext;
+
 namespace testing {
 
 class CliCall final {
@@ -48,12 +53,32 @@ class CliCall final {
   typedef std::multimap<grpc::string, grpc::string> OutgoingMetadataContainer;
   typedef std::multimap<grpc::string_ref, grpc::string_ref>
       IncomingMetadataContainer;
+
+  CliCall(std::shared_ptr<grpc::Channel> channel, const grpc::string& method,
+          const OutgoingMetadataContainer& metadata);
+
   static Status Call(std::shared_ptr<grpc::Channel> channel,
                      const grpc::string& method, const grpc::string& request,
                      grpc::string* response,
                      const OutgoingMetadataContainer& metadata,
                      IncomingMetadataContainer* server_initial_metadata,
                      IncomingMetadataContainer* server_trailing_metadata);
+
+  void Write(const grpc::string& request);
+
+  void WritesDone();
+
+  void Read(grpc::string* response,
+            IncomingMetadataContainer* server_initial_metadata);
+
+  Status Finish(IncomingMetadataContainer* server_trailing_metadata);
+
+ private:
+  enum CallStatus : intptr_t;
+  std::unique_ptr<grpc::GenericStub> stub_;
+  grpc::ClientContext ctx_;
+  std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;
+  grpc::CompletionQueue cq_;
 };
 
 }  // namespace testing
diff --git a/test/cpp/util/grpc_tool.cc b/test/cpp/util/grpc_tool.cc
index b9900ca1b7..8082d6027b 100644
--- a/test/cpp/util/grpc_tool.cc
+++ b/test/cpp/util/grpc_tool.cc
@@ -419,79 +419,180 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
   std::unique_ptr<grpc::testing::ProtoFileParser> parser;
   grpc::string serialized_request_proto;
 
-  if (argc == 3) {
-    request_text = argv[2];
-    if (!FLAGS_infile.empty()) {
-      fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
+  std::shared_ptr<grpc::Channel> channel =
+      FLAGS_remotedb
+          ? grpc::CreateChannel(server_address, cred.GetCredentials())
+          : nullptr;
+
+  parser.reset(new grpc::testing::ProtoFileParser(channel, FLAGS_proto_path,
+                                                  FLAGS_protofiles));
+
+  grpc::string formated_method_name =
+      parser->GetFormatedMethodName(method_name);
+
+  if (parser->HasError()) {
+    return false;
+  }
+
+  if (parser->IsStreaming(method_name, true /* is_request */)) {
+    fprintf(stderr, "streaming request\n");
+    std::istream* input_stream;
+    std::ifstream input_file;
+
+    if (argc == 3) {
+      request_text = argv[2];
+      if (!FLAGS_infile.empty()) {
+        fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
+      }
     }
-  } else {
-    std::stringstream input_stream;
+    // std::stringstream input_stream;
+
+    std::multimap<grpc::string, grpc::string> client_metadata;
+    ParseMetadataFlag(&client_metadata);
+    PrintMetadata(client_metadata, "Sending client initial metadata:");
+
+    CliCall call(channel, formated_method_name, client_metadata);
+
     if (FLAGS_infile.empty()) {
       if (isatty(STDIN_FILENO)) {
         fprintf(stderr, "reading request message from stdin...\n");
       }
-      input_stream << std::cin.rdbuf();
+      input_stream = &std::cin;
+      // rdbuf = std::cin.rdbuf();
+      // input_stream.rdbuf(std::cin.rdbuf());
+      // input_stream << std::cin.rdbuf();
+
     } else {
-      std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary);
-      input_stream << input_file.rdbuf();
+      input_file.open(FLAGS_infile, std::ios::in | std::ios::binary);
+      // rdbuf = input_file.rdbuf();
+      // input_stream.rdbuf(input_file.rdbuf());
+      input_stream = &input_file;
+      // input_file.close();
+    }
+    // request_text = input_stream.str();
+
+    std::stringstream request_ss;
+    grpc::string line;
+    while (!input_stream->eof() && getline(*input_stream, line)) {
+      if (line.length() == 0) {
+        // request_text = request_ss.str();
+        if (FLAGS_binary_input) {
+          serialized_request_proto = request_ss.str();
+        } else {
+          serialized_request_proto = parser->GetSerializedProtoFromMethod(
+              method_name, request_ss.str(), true /* is_request */);
+          if (parser->HasError()) {
+            return false;
+          }
+        }
+
+        request_ss.str(grpc::string());
+        request_ss.clear();
+
+        grpc::string response_text = parser->GetTextFormatFromMethod(
+            method_name, serialized_request_proto, true /* is_request */);
+        call.Write(serialized_request_proto);
+
+        fprintf(stderr, "%s", response_text.c_str());
+      } else {
+        request_ss << line << ' ';
+      }
+    }
+    if (input_file.is_open()) {
       input_file.close();
     }
-    request_text = input_stream.str();
-  }
 
-  std::shared_ptr<grpc::Channel> channel =
-      grpc::CreateChannel(server_address, cred.GetCredentials());
-  if (!FLAGS_binary_input || !FLAGS_binary_output) {
-    parser.reset(
-        new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr,
-                                           FLAGS_proto_path, FLAGS_protofiles));
-    if (parser->HasError()) {
-      return false;
+    call.WritesDone();
+
+    grpc::string serialized_response_proto;
+    std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
+        server_trailing_metadata;
+    call.Read(&serialized_response_proto, &server_initial_metadata);
+    Status status = call.Finish(&server_trailing_metadata);
+
+    PrintMetadata(server_initial_metadata,
+                  "Received initial metadata from server:");
+    PrintMetadata(server_trailing_metadata,
+                  "Received trailing metadata from server:");
+    if (status.ok()) {
+      fprintf(stderr, "Stream RPC succeeded with OK status\n");
+      if (FLAGS_binary_output) {
+        output_ss << serialized_response_proto;
+      } else {
+        grpc::string response_text = parser->GetTextFormatFromMethod(
+            method_name, serialized_response_proto, false /* is_request */);
+        if (parser->HasError()) {
+          return false;
+        }
+        output_ss << "Response: \n " << response_text << std::endl;
+      }
+    } else {
+      fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
+              status.error_code(), status.error_message().c_str());
     }
-  }
 
-  if (FLAGS_binary_input) {
-    serialized_request_proto = request_text;
-    formatted_method_name = method_name;
-  } else {
-    formatted_method_name = parser->GetFormattedMethodName(method_name);
-    serialized_request_proto = parser->GetSerializedProtoFromMethod(
-        method_name, request_text, true /* is_request */);
-    if (parser->HasError()) {
-      return false;
+  } else {  // parser->IsStreaming(method_name, true /* is_request */)
+    if (argc == 3) {
+      request_text = argv[2];
+      if (!FLAGS_infile.empty()) {
+        fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
+      }
+    } else {
+      std::stringstream input_stream;
+      if (FLAGS_infile.empty()) {
+        if (isatty(STDIN_FILENO)) {
+          fprintf(stderr, "reading request message from stdin...\n");
+        }
+        input_stream << std::cin.rdbuf();
+      } else {
+        std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary);
+        input_stream << input_file.rdbuf();
+        input_file.close();
+      }
+      request_text = input_stream.str();
     }
-  }
-  fprintf(stderr, "connecting to %s\n", server_address.c_str());
-
-  grpc::string serialized_response_proto;
-  std::multimap<grpc::string, grpc::string> client_metadata;
-  std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
-      server_trailing_metadata;
-  ParseMetadataFlag(&client_metadata);
-  PrintMetadata(client_metadata, "Sending client initial metadata:");
-  grpc::Status status = grpc::testing::CliCall::Call(
-      channel, formatted_method_name, serialized_request_proto,
-      &serialized_response_proto, client_metadata, &server_initial_metadata,
-      &server_trailing_metadata);
-  PrintMetadata(server_initial_metadata,
-                "Received initial metadata from server:");
-  PrintMetadata(server_trailing_metadata,
-                "Received trailing metadata from server:");
-  if (status.ok()) {
-    fprintf(stderr, "Rpc succeeded with OK status\n");
-    if (FLAGS_binary_output) {
-      output_ss << serialized_response_proto;
+
+    if (FLAGS_binary_input) {
+      serialized_request_proto = request_text;
     } else {
-      grpc::string response_text = parser->GetTextFormatFromMethod(
-          method_name, serialized_response_proto, false /* is_request */);
+      serialized_request_proto = parser->GetSerializedProtoFromMethod(
+          method_name, request_text, true /* is_request */);
       if (parser->HasError()) {
         return false;
       }
-      output_ss << "Response: \n " << response_text << std::endl;
     }
-  } else {
-    fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
-            status.error_code(), status.error_message().c_str());
+    fprintf(stderr, "connecting to %s\n", server_address.c_str());
+
+    grpc::string serialized_response_proto;
+    std::multimap<grpc::string, grpc::string> client_metadata;
+    std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
+        server_trailing_metadata;
+    ParseMetadataFlag(&client_metadata);
+    PrintMetadata(client_metadata, "Sending client initial metadata:");
+    grpc::Status status = grpc::testing::CliCall::Call(
+        channel, formated_method_name, serialized_request_proto,
+        &serialized_response_proto, client_metadata, &server_initial_metadata,
+        &server_trailing_metadata);
+    PrintMetadata(server_initial_metadata,
+                  "Received initial metadata from server:");
+    PrintMetadata(server_trailing_metadata,
+                  "Received trailing metadata from server:");
+    if (status.ok()) {
+      fprintf(stderr, "Rpc succeeded with OK status\n");
+      if (FLAGS_binary_output) {
+        output_ss << serialized_response_proto;
+      } else {
+        grpc::string response_text = parser->GetTextFormatFromMethod(
+            method_name, serialized_response_proto, false /* is_request */);
+        if (parser->HasError()) {
+          return false;
+        }
+        output_ss << "Response: \n " << response_text << std::endl;
+      }
+    } else {
+      fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
+              status.error_code(), status.error_message().c_str());
+    }
   }
 
   return callback(output_ss.str());
diff --git a/test/cpp/util/proto_file_parser.cc b/test/cpp/util/proto_file_parser.cc
index bc8a6083f4..41bf88cc14 100644
--- a/test/cpp/util/proto_file_parser.cc
+++ b/test/cpp/util/proto_file_parser.cc
@@ -144,12 +144,18 @@ ProtoFileParser::~ProtoFileParser() {}
 
 grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) {
   has_error_ = false;
+
+  if (known_methods_.find(method) != known_methods_.end()) {
+    return known_methods_[method];
+  }
+
   const protobuf::MethodDescriptor* method_descriptor = nullptr;
   for (auto it = service_desc_list_.begin(); it != service_desc_list_.end();
        it++) {
     const auto* service_desc = *it;
     for (int j = 0; j < service_desc->method_count(); j++) {
       const auto* method_desc = service_desc->method(j);
+      fprintf(stderr, "%s\n", method_desc->full_name().c_str());
       if (MethodNameMatch(method_desc->full_name(), method)) {
         if (method_descriptor) {
           std::ostringstream error_stream;
@@ -169,6 +175,8 @@ grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) {
     return "";
   }
 
+  known_methods_[method] = method_descriptor->full_name();
+
   return method_descriptor->full_name();
 }
 
@@ -205,6 +213,25 @@ grpc::string ProtoFileParser::GetMessageTypeFromMethod(
                     : method_desc->output_type()->full_name();
 }
 
+bool ProtoFileParser::IsStreaming(const grpc::string& method, bool is_request) {
+  has_error_ = false;
+
+  grpc::string full_method_name = GetFullMethodName(method);
+  if (has_error_) {
+    return false;
+  }
+
+  const protobuf::MethodDescriptor* method_desc =
+      desc_pool_->FindMethodByName(full_method_name);
+  if (!method_desc) {
+    LogError("Method not found");
+    return false;
+  }
+
+  return is_request ? method_desc->client_streaming()
+                    : method_desc->server_streaming();
+}
+
 grpc::string ProtoFileParser::GetSerializedProtoFromMethod(
     const grpc::string& method, const grpc::string& text_format_proto,
     bool is_request) {
diff --git a/test/cpp/util/proto_file_parser.h b/test/cpp/util/proto_file_parser.h
index c1070a37b5..23d311ef8f 100644
--- a/test/cpp/util/proto_file_parser.h
+++ b/test/cpp/util/proto_file_parser.h
@@ -84,6 +84,8 @@ class ProtoFileParser {
       const grpc::string& message_type_name,
       const grpc::string& serialized_proto);
 
+  bool IsStreaming(const grpc::string& method, bool is_request);
+
   bool HasError() const { return has_error_; }
 
   void LogError(const grpc::string& error_msg);
@@ -104,6 +106,7 @@ class ProtoFileParser {
   std::unique_ptr<protobuf::DynamicMessageFactory> dynamic_factory_;
   std::unique_ptr<grpc::protobuf::Message> request_prototype_;
   std::unique_ptr<grpc::protobuf::Message> response_prototype_;
+  std::unordered_map<grpc::string, grpc::string> known_methods_;
   std::vector<const protobuf::ServiceDescriptor*> service_desc_list_;
 };
 
-- 
GitLab