From d37f642f359cb7fd7405831e675abb93fd4704e2 Mon Sep 17 00:00:00 2001
From: Yuchen Zeng <zyc@google.com>
Date: Fri, 9 Sep 2016 20:05:37 -0700
Subject: [PATCH] Support server streaming

Skip unparsable input

Add tests for uni-directional stream calls

Simplify client stream handling
---
 test/cpp/util/cli_call.cc          |  42 ++++++------
 test/cpp/util/cli_call.h           |   3 +-
 test/cpp/util/grpc_cli.cc          |   4 +-
 test/cpp/util/grpc_tool.cc         | 105 ++++++++++++++++-------------
 test/cpp/util/grpc_tool_test.cc    |  95 ++++++++++++++++++++++++++
 test/cpp/util/proto_file_parser.cc |   1 -
 6 files changed, 178 insertions(+), 72 deletions(-)

diff --git a/test/cpp/util/cli_call.cc b/test/cpp/util/cli_call.cc
index d9232ec4b6..1101abe3c9 100644
--- a/test/cpp/util/cli_call.cc
+++ b/test/cpp/util/cli_call.cc
@@ -48,8 +48,6 @@ namespace {
 void* tag(int i) { return (void*)(intptr_t)i; }
 }  // namespace
 
-enum CliCall::CallStatus : intptr_t { CREATE, PROCESS, FINISH };
-
 Status CliCall::Call(std::shared_ptr<grpc::Channel> channel,
                      const grpc::string& method, const grpc::string& request,
                      grpc::string* response,
@@ -59,7 +57,9 @@ Status CliCall::Call(std::shared_ptr<grpc::Channel> channel,
   CliCall call(channel, method, metadata);
   call.Write(request);
   call.WritesDone();
-  call.Read(response, server_initial_metadata);
+  if (!call.Read(response, server_initial_metadata)) {
+    fprintf(stderr, "Failed to read response.\n");
+  }
   return call.Finish(server_trailing_metadata);
 }
 
@@ -92,36 +92,36 @@ void CliCall::Write(const grpc::string& request) {
   GPR_ASSERT(ok);
 }
 
-void CliCall::Read(grpc::string* response,
+bool CliCall::Read(grpc::string* response,
                    IncomingMetadataContainer* server_initial_metadata) {
   void* got_tag;
   bool ok;
 
   grpc::ByteBuffer recv_buffer;
-  call_->Read(&recv_buffer, tag(4));
-  cq_.Next(&got_tag, &ok);
-  if (!ok) {
-    fprintf(stderr, "Failed to read response.");
-  } else {
-    std::vector<grpc::Slice> slices;
-    (void)recv_buffer.Dump(&slices);
-
-    response->clear();
-    for (size_t i = 0; i < slices.size(); i++) {
-      response->append(reinterpret_cast<const char*>(slices[i].begin()),
-                       slices[i].size());
-    }
-    if (server_initial_metadata) {
-      *server_initial_metadata = ctx_.GetServerInitialMetadata();
-    }
+  call_->Read(&recv_buffer, tag(3));
+
+  if (!cq_.Next(&got_tag, &ok) || !ok) {
+    return false;
+  }
+  std::vector<grpc::Slice> slices;
+  recv_buffer.Dump(&slices);
+
+  response->clear();
+  for (size_t i = 0; i < slices.size(); i++) {
+    response->append(reinterpret_cast<const char*>(slices[i].begin()),
+                     slices[i].size());
+  }
+  if (server_initial_metadata) {
+    *server_initial_metadata = ctx_.GetServerInitialMetadata();
   }
+  return true;
 }
 
 void CliCall::WritesDone() {
   void* got_tag;
   bool ok;
 
-  call_->WritesDone(tag(3));
+  call_->WritesDone(tag(4));
   cq_.Next(&got_tag, &ok);
   GPR_ASSERT(ok);
 }
diff --git a/test/cpp/util/cli_call.h b/test/cpp/util/cli_call.h
index 3f328309a7..34fa88433f 100644
--- a/test/cpp/util/cli_call.h
+++ b/test/cpp/util/cli_call.h
@@ -68,13 +68,12 @@ class CliCall final {
 
   void WritesDone();
 
-  void Read(grpc::string* response,
+  bool Read(grpc::string* response,
             IncomingMetadataContainer* server_initial_metadata);
 
   Status Finish(IncomingMetadataContainer* server_trailing_metadata);
 
  private:
-  enum CallStatus : intptr_t;
   std::unique_ptr<grpc::GenericStub> stub_;
   grpc::ClientContext ctx_;
   std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;
diff --git a/test/cpp/util/grpc_cli.cc b/test/cpp/util/grpc_cli.cc
index fe248601ee..fe68ccb619 100644
--- a/test/cpp/util/grpc_cli.cc
+++ b/test/cpp/util/grpc_cli.cc
@@ -83,10 +83,10 @@ DEFINE_string(outfile, "", "Output file (default is stdout)");
 static bool SimplePrint(const grpc::string& outfile,
                         const grpc::string& output) {
   if (outfile.empty()) {
-    std::cout << output;
+    std::cout << output << std::endl;
   } else {
     std::ofstream output_file(outfile, std::ios::trunc | std::ios::binary);
-    output_file << output;
+    output_file << output << std::endl;
     output_file.close();
   }
   return true;
diff --git a/test/cpp/util/grpc_tool.cc b/test/cpp/util/grpc_tool.cc
index 8082d6027b..762f8e8c23 100644
--- a/test/cpp/util/grpc_tool.cc
+++ b/test/cpp/util/grpc_tool.cc
@@ -418,6 +418,7 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
   grpc::string formatted_method_name;
   std::unique_ptr<grpc::testing::ProtoFileParser> parser;
   grpc::string serialized_request_proto;
+  bool print_mode = false;
 
   std::shared_ptr<grpc::Channel> channel =
       FLAGS_remotedb
@@ -435,17 +436,19 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
   }
 
   if (parser->IsStreaming(method_name, true /* is_request */)) {
-    fprintf(stderr, "streaming request\n");
+    // TODO(zyc): Support BidiStream
+    if (parser->IsStreaming(method_name, false /* is_request */)) {
+      fprintf(stderr,
+              "Bidirectional-streaming method is not supported.");
+      return false;
+    }
+
     std::istream* input_stream;
     std::ifstream input_file;
 
     if (argc == 3) {
       request_text = argv[2];
-      if (!FLAGS_infile.empty()) {
-        fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
-      }
     }
-    // std::stringstream input_stream;
 
     std::multimap<grpc::string, grpc::string> client_metadata;
     ParseMetadataFlag(&client_metadata);
@@ -455,47 +458,47 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
 
     if (FLAGS_infile.empty()) {
       if (isatty(STDIN_FILENO)) {
-        fprintf(stderr, "reading request message from stdin...\n");
+        print_mode = true;
+        fprintf(stderr, "reading streaming request message from stdin...\n");
       }
       input_stream = &std::cin;
-      // rdbuf = std::cin.rdbuf();
-      // input_stream.rdbuf(std::cin.rdbuf());
-      // input_stream << std::cin.rdbuf();
-
     } else {
       input_file.open(FLAGS_infile, std::ios::in | std::ios::binary);
-      // rdbuf = input_file.rdbuf();
-      // input_stream.rdbuf(input_file.rdbuf());
       input_stream = &input_file;
-      // input_file.close();
     }
-    // request_text = input_stream.str();
 
     std::stringstream request_ss;
     grpc::string line;
-    while (!input_stream->eof() && getline(*input_stream, line)) {
-      if (line.length() == 0) {
-        // request_text = request_ss.str();
+    while (!request_text.empty() ||
+           (!input_stream->eof() && getline(*input_stream, line))) {
+      if (!request_text.empty()) {
         if (FLAGS_binary_input) {
-          serialized_request_proto = request_ss.str();
+          serialized_request_proto = request_text;
+          request_text.clear();
         } else {
           serialized_request_proto = parser->GetSerializedProtoFromMethod(
-              method_name, request_ss.str(), true /* is_request */);
+              method_name, request_text, true /* is_request */);
+          request_text.clear();
           if (parser->HasError()) {
-            return false;
+            if (print_mode) {
+              fprintf(stderr, "Failed to parse request.\n");
+            }
+            continue;
           }
         }
 
-        request_ss.str(grpc::string());
-        request_ss.clear();
-
-        grpc::string response_text = parser->GetTextFormatFromMethod(
-            method_name, serialized_request_proto, true /* is_request */);
         call.Write(serialized_request_proto);
-
-        fprintf(stderr, "%s", response_text.c_str());
+        if (print_mode) {
+          fprintf(stderr, "Request sent.\n");
+        }
       } else {
-        request_ss << line << ' ';
+        if (line.length() == 0) {
+          request_text = request_ss.str();
+          request_ss.str(grpc::string());
+          request_ss.clear();
+        } else {
+          request_ss << line << ' ';
+        }
       }
     }
     if (input_file.is_open()) {
@@ -507,7 +510,9 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
     grpc::string serialized_response_proto;
     std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
         server_trailing_metadata;
-    call.Read(&serialized_response_proto, &server_initial_metadata);
+    if (!call.Read(&serialized_response_proto, &server_trailing_metadata)) {
+      fprintf(stderr, "Failed to read response.\n");
+    }
     Status status = call.Finish(&server_trailing_metadata);
 
     PrintMetadata(server_initial_metadata,
@@ -524,7 +529,7 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
         if (parser->HasError()) {
           return false;
         }
-        output_ss << "Response: \n " << response_text << std::endl;
+        output_ss << response_text;
       }
     } else {
       fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
@@ -569,32 +574,40 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
         server_trailing_metadata;
     ParseMetadataFlag(&client_metadata);
     PrintMetadata(client_metadata, "Sending client initial metadata:");
-    grpc::Status status = grpc::testing::CliCall::Call(
-        channel, formated_method_name, serialized_request_proto,
-        &serialized_response_proto, client_metadata, &server_initial_metadata,
-        &server_trailing_metadata);
-    PrintMetadata(server_initial_metadata,
-                  "Received initial metadata from server:");
-    PrintMetadata(server_trailing_metadata,
-                  "Received trailing metadata from server:");
-    if (status.ok()) {
-      fprintf(stderr, "Rpc succeeded with OK status\n");
-      if (FLAGS_binary_output) {
-        output_ss << serialized_response_proto;
-      } else {
-        grpc::string response_text = parser->GetTextFormatFromMethod(
+
+    CliCall call(channel, formated_method_name, client_metadata);
+    call.Write(serialized_request_proto);
+    call.WritesDone();
+
+    for (bool receive_initial_metadata = true; call.Read(
+             &serialized_response_proto,
+             receive_initial_metadata ? &server_initial_metadata : nullptr);
+         receive_initial_metadata = false) {
+      if (!FLAGS_binary_output) {
+        serialized_response_proto = parser->GetTextFormatFromMethod(
             method_name, serialized_response_proto, false /* is_request */);
         if (parser->HasError()) {
           return false;
         }
-        output_ss << "Response: \n " << response_text << std::endl;
       }
+      if (receive_initial_metadata) {
+        PrintMetadata(server_initial_metadata,
+                      "Received initial metadata from server:");
+      }
+      if (!callback(serialized_response_proto)) {
+        return false;
+      }
+    }
+    Status status = call.Finish(&server_trailing_metadata);
+    if (status.ok()) {
+      fprintf(stderr, "Rpc succeeded with OK status\n");
+      return true;
     } else {
       fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
               status.error_code(), status.error_message().c_str());
+      return false;
     }
   }
-
   return callback(output_ss.str());
 }
 
diff --git a/test/cpp/util/grpc_tool_test.cc b/test/cpp/util/grpc_tool_test.cc
index 33ce611a60..e2eebd4089 100644
--- a/test/cpp/util/grpc_tool_test.cc
+++ b/test/cpp/util/grpc_tool_test.cc
@@ -102,6 +102,8 @@ DECLARE_bool(l);
 
 namespace {
 
+const int kNumResponseStreamsMsgs = 3;
+
 class TestCliCredentials final : public grpc::testing::CliCredentials {
  public:
   std::shared_ptr<grpc::ChannelCredentials> GetCredentials() const override {
@@ -137,6 +139,48 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service {
     response->set_message(request->message());
     return Status::OK;
   }
+
+  Status RequestStream(ServerContext* context,
+                       ServerReader<EchoRequest>* reader,
+                       EchoResponse* response) GRPC_OVERRIDE {
+    EchoRequest request;
+    response->set_message("");
+    if (!context->client_metadata().empty()) {
+      for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator
+               iter = context->client_metadata().begin();
+           iter != context->client_metadata().end(); ++iter) {
+        context->AddInitialMetadata(ToString(iter->first),
+                                    ToString(iter->second));
+      }
+    }
+    context->AddTrailingMetadata("trailing_key", "trailing_value");
+    while (reader->Read(&request)) {
+      response->mutable_message()->append(request.message());
+    }
+
+    return Status::OK;
+  }
+
+  Status ResponseStream(ServerContext* context, const EchoRequest* request,
+                        ServerWriter<EchoResponse>* writer) GRPC_OVERRIDE {
+    if (!context->client_metadata().empty()) {
+      for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator
+               iter = context->client_metadata().begin();
+           iter != context->client_metadata().end(); ++iter) {
+        context->AddInitialMetadata(ToString(iter->first),
+                                    ToString(iter->second));
+      }
+    }
+    context->AddTrailingMetadata("trailing_key", "trailing_value");
+
+    EchoResponse response;
+    for (int i = 0; i < kNumResponseStreamsMsgs; i++) {
+      response.set_message(request->message() + grpc::to_string(i));
+      writer->Write(response);
+    }
+
+    return Status::OK;
+  }
 };
 
 }  // namespace
@@ -388,6 +432,57 @@ TEST_F(GrpcToolTest, ParseCommand) {
   ShutdownServer();
 }
 
+TEST_F(GrpcToolTest, CallCommandRequestStream) {
+  // Test input: grpc_cli call localhost:<port> RequestStream "message:
+  // 'Hello0'"
+  std::stringstream output_stream;
+
+  const grpc::string server_address = SetUpServer();
+  const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+                        "RequestStream", "message: 'Hello0'"};
+
+  // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n"
+  std::streambuf* orig = std::cin.rdbuf();
+  std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n");
+  std::cin.rdbuf(ss.rdbuf());
+
+  EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+                                   std::bind(PrintStream, &output_stream,
+                                             std::placeholders::_1)));
+
+  // Expected output: "message: \"Hello0Hello1Hello2\""
+  EXPECT_TRUE(NULL != strstr(output_stream.str().c_str(),
+                             "message: \"Hello0Hello1Hello2\""));
+  std::cin.rdbuf(orig);
+  ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandResponseStream) {
+  // Test input: grpc_cli call localhost:<port> ResponseStream "message:
+  // 'Hello'"
+  std::stringstream output_stream;
+
+  const grpc::string server_address = SetUpServer();
+  const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+                        "ResponseStream", "message: 'Hello'"};
+
+  EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+                                   std::bind(PrintStream, &output_stream,
+                                             std::placeholders::_1)));
+
+  fprintf(stderr, "%s\n", output_stream.str().c_str());
+  // Expected output: "message: \"Hello{n}\""
+
+  for (int i = 0; i < kNumResponseStreamsMsgs; i++) {
+    grpc::string expected_response_text =
+        "message: \"Hello" + grpc::to_string(i) + "\"\n\n";
+    EXPECT_TRUE(NULL != strstr(output_stream.str().c_str(),
+                               expected_response_text.c_str()));
+  }
+
+  ShutdownServer();
+}
+
 TEST_F(GrpcToolTest, TooFewArguments) {
   // Test input "grpc_cli call Echo"
   std::stringstream output_stream;
diff --git a/test/cpp/util/proto_file_parser.cc b/test/cpp/util/proto_file_parser.cc
index 41bf88cc14..9f1f05595e 100644
--- a/test/cpp/util/proto_file_parser.cc
+++ b/test/cpp/util/proto_file_parser.cc
@@ -155,7 +155,6 @@ grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) {
     const auto* service_desc = *it;
     for (int j = 0; j < service_desc->method_count(); j++) {
       const auto* method_desc = service_desc->method(j);
-      fprintf(stderr, "%s\n", method_desc->full_name().c_str());
       if (MethodNameMatch(method_desc->full_name(), method)) {
         if (method_descriptor) {
           std::ostringstream error_stream;
-- 
GitLab