From db5b1cbc9483b78458ca529df5603bb15c982d63 Mon Sep 17 00:00:00 2001
From: Vijay Pai <vpai@google.com>
Date: Mon, 3 Oct 2016 09:18:46 -0700
Subject: [PATCH] Add all plumbing and typedef's for controlled server-side
 streaming

---
 .../grpc++/impl/codegen/method_handler_impl.h | 14 +++
 include/grpc++/impl/codegen/service_type.h    | 12 ++-
 include/grpc++/impl/codegen/sync_stream.h     | 39 +++++++-
 src/compiler/cpp_generator.cc                 | 97 ++++++++++++++++++-
 4 files changed, 154 insertions(+), 8 deletions(-)

diff --git a/include/grpc++/impl/codegen/method_handler_impl.h b/include/grpc++/impl/codegen/method_handler_impl.h
index 52f927631c..8f417f671a 100644
--- a/include/grpc++/impl/codegen/method_handler_impl.h
+++ b/include/grpc++/impl/codegen/method_handler_impl.h
@@ -236,6 +236,20 @@ class StreamedUnaryHandler
             ServerUnaryStreamer<RequestType, ResponseType>, true>(func) {}
 };
 
+template <class RequestType, class ResponseType>
+class SplitServerStreamingHandler
+    : public TemplatedBidiStreamingHandler<
+          ServerSplitStreamer<RequestType, ResponseType>, false> {
+ public:
+  explicit SplitServerStreamingHandler(
+      std::function<Status(ServerContext*,
+                           SplitServerStreamingHandler<RequestType,
+			   ResponseType>*)>
+          func)
+      : TemplatedBidiStreamingHandler<
+            ServerSplitStreamer<RequestType, ResponseType>, false>(func) {}
+};
+
 // Handle unknown method by returning UNIMPLEMENTED error.
 class UnknownMethodHandler : public MethodHandler {
  public:
diff --git a/include/grpc++/impl/codegen/service_type.h b/include/grpc++/impl/codegen/service_type.h
index 72b2225312..de0862fce2 100644
--- a/include/grpc++/impl/codegen/service_type.h
+++ b/include/grpc++/impl/codegen/service_type.h
@@ -147,14 +147,16 @@ class Service {
     methods_[index].reset();
   }
 
-  void MarkMethodStreamedUnary(int index,
-                               MethodHandler* streamed_unary_method) {
+  void MarkMethodStreamed(int index,
+			  MethodHandler* streamed_method) {
     GPR_CODEGEN_ASSERT(methods_[index] && methods_[index]->handler() &&
-                       "Cannot mark an async or generic method Streamed Unary");
-    methods_[index]->SetHandler(streamed_unary_method);
+                       "Cannot mark an async or generic method Streamed");
+    methods_[index]->SetHandler(streamed_method);
 
     // From the server's point of view, streamed unary is a special
-    // case of BIDI_STREAMING that has 1 read and 1 write, in that order.
+    // case of BIDI_STREAMING that has 1 read and 1 write, in that order,
+    // and split server-side streaming is BIDI_STREAMING with 1 read and
+    // any number of writes, in that order
     methods_[index]->SetMethodType(::grpc::RpcMethod::BIDI_STREAMING);
   }
 
diff --git a/include/grpc++/impl/codegen/sync_stream.h b/include/grpc++/impl/codegen/sync_stream.h
index e3c5a919b1..f2b1e9a3ad 100644
--- a/include/grpc++/impl/codegen/sync_stream.h
+++ b/include/grpc++/impl/codegen/sync_stream.h
@@ -538,7 +538,7 @@ class ServerReaderWriter GRPC_FINAL : public ServerReaderWriterInterface<W, R> {
 /// the \a NextMessageSize method to determine an upper-bound on the size of
 /// the message.
 /// A key difference relative to streaming: ServerUnaryStreamer
-///  must have exactly 1 Read and exactly 1 Write, in that order, to function
+/// must have exactly 1 Read and exactly 1 Write, in that order, to function
 /// correctly. Otherwise, the RPC is in error.
 template <class RequestType, class ResponseType>
 class ServerUnaryStreamer GRPC_FINAL
@@ -577,6 +577,43 @@ class ServerUnaryStreamer GRPC_FINAL
   bool write_done_;
 };
 
+/// A class to represent a flow-controlled server-side streaming call.
+/// This is something of a hybrid between server-side and bidi streaming.
+/// This is invoked through a server-side streaming call on the client side,
+/// but the server responds to it as though it were a bidi streaming call that
+/// must first have exactly 1 Read and then any number of Writes.
+template <class RequestType, class ResponseType>
+class ServerSplitStreamer GRPC_FINAL
+    : public ServerReaderWriterInterface<ResponseType, RequestType> {
+ public:
+  ServerSplitStreamer(Call* call, ServerContext* ctx)
+      : body_(call, ctx), read_done_(false) {}
+
+  void SendInitialMetadata() GRPC_OVERRIDE { body_.SendInitialMetadata(); }
+
+  bool NextMessageSize(uint32_t* sz) GRPC_OVERRIDE {
+    return body_.NextMessageSize(sz);
+  }
+
+  bool Read(RequestType* request) GRPC_OVERRIDE {
+    if (read_done_) {
+      return false;
+    }
+    read_done_ = true;
+    return body_.Read(request);
+  }
+
+  using WriterInterface<ResponseType>::Write;
+  bool Write(const ResponseType& response,
+             const WriteOptions& options) GRPC_OVERRIDE {
+    return !read_done_ && body_.Write(response, options);
+  }
+
+ private:
+  internal::ServerReaderWriterBody<ResponseType, RequestType> body_;
+  bool read_done_;
+};
+
 }  // namespace grpc
 
 #endif  // GRPCXX_IMPL_CODEGEN_SYNC_STREAM_H
diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc
index d0c35ea1ab..d0f164e58d 100644
--- a/src/compiler/cpp_generator.cc
+++ b/src/compiler/cpp_generator.cc
@@ -605,7 +605,7 @@ void PrintHeaderServerMethodAsync(Printer *printer, const Method *method,
   printer->Print(*vars, "};\n");
 }
 
-void PrintHeaderServerMethodStreamedUnary(
+ void PrintHeaderServerMethodStreamedUnary(
     Printer *printer, const Method *method,
     std::map<grpc::string, grpc::string> *vars) {
   (*vars)["Method"] = method->name();
@@ -624,7 +624,7 @@ void PrintHeaderServerMethodStreamedUnary(
     printer->Indent();
     printer->Print(*vars,
                    "WithStreamedUnaryMethod_$Method$() {\n"
-                   "  ::grpc::Service::MarkMethodStreamedUnary($Idx$,\n"
+                   "  ::grpc::Service::MarkMethodStreamed($Idx$,\n"
                    "    new ::grpc::StreamedUnaryHandler< $Request$, "
                    "$Response$>(std::bind"
                    "(&WithStreamedUnaryMethod_$Method$<BaseClass>::"
@@ -656,6 +656,57 @@ void PrintHeaderServerMethodStreamedUnary(
   }
 }
 
+void PrintHeaderServerMethodSplitStreaming(
+    Printer *printer, const Method *method,
+    std::map<grpc::string, grpc::string> *vars) {
+  (*vars)["Method"] = method->name();
+  (*vars)["Request"] = method->input_type_name();
+  (*vars)["Response"] = method->output_type_name();
+  if (method->ServerOnlyStreaming()) {
+    printer->Print(*vars, "template <class BaseClass>\n");
+    printer->Print(*vars,
+                   "class WithSplitStreamingMethod_$Method$ : "
+                   "public BaseClass {\n");
+    printer->Print(
+        " private:\n"
+        "  void BaseClassMustBeDerivedFromService(const Service *service) "
+        "{}\n");
+    printer->Print(" public:\n");
+    printer->Indent();
+    printer->Print(*vars,
+                   "WithSplitStreamingMethod_$Method$() {\n"
+                   "  ::grpc::Service::MarkMethodStreamed($Idx$,\n"
+                   "    new ::grpc::SplitServerStreamingHandler< $Request$, "
+                   "$Response$>(std::bind"
+                   "(&WithSplitStreamingMethod_$Method$<BaseClass>::"
+                   "Streamed$Method$, this, std::placeholders::_1, "
+                   "std::placeholders::_2)));\n"
+                   "}\n");
+    printer->Print(*vars,
+                   "~WithSplitStreamingMethod_$Method$() GRPC_OVERRIDE {\n"
+                   "  BaseClassMustBeDerivedFromService(this);\n"
+                   "}\n");
+    printer->Print(
+        *vars,
+        "// disable regular version of this method\n"
+        "::grpc::Status $Method$("
+        "::grpc::ServerContext* context, const $Request$* request, "
+        "$Response$* response) GRPC_FINAL GRPC_OVERRIDE {\n"
+        "  abort();\n"
+        "  return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"
+        "}\n");
+    printer->Print(*vars,
+                   "// replace default version of method with split streamed\n"
+                   "virtual ::grpc::Status Streamed$Method$("
+                   "::grpc::ServerContext* context, "
+                   "::grpc::ServerSplitStreamer< "
+                   "$Request$,$Response$>* server_split_streamer)"
+                   " = 0;\n");
+    printer->Outdent();
+    printer->Print(*vars, "};\n");
+  }
+}
+
 void PrintHeaderServerMethodGeneric(
     Printer *printer, const Method *method,
     std::map<grpc::string, grpc::string> *vars) {
@@ -844,6 +895,48 @@ void PrintHeaderService(Printer *printer, const Service *service,
   }
   printer->Print(" StreamedUnaryService;\n");
 
+  // Server side - controlled server-side streaming
+  for (int i = 0; i < service->method_count(); ++i) {
+    (*vars)["Idx"] = as_string(i);
+    PrintHeaderServerMethodSplitStreaming(printer, service->method(i).get(),
+					  vars);
+  }
+
+  printer->Print("typedef ");
+  for (int i = 0; i < service->method_count(); ++i) {
+    (*vars)["method_name"] = service->method(i).get()->name();
+    if (service->method(i)->ServerOnlyStreaming()) {
+      printer->Print(*vars, "WithSplitStreamingMethod_$method_name$<");
+    }
+  }
+  printer->Print("Service");
+  for (int i = 0; i < service->method_count(); ++i) {
+    if (service->method(i)->ServerOnlyStreaming()) {
+      printer->Print(" >");
+    }
+  }
+  printer->Print(" SplitStreamedService;\n");
+
+  // Server side - typedef for controlled both unary and server-side streaming
+  printer->Print("typedef ");
+  for (int i = 0; i < service->method_count(); ++i) {
+    (*vars)["method_name"] = service->method(i).get()->name();
+    if (service->method(i)->ServerOnlyStreaming()) {
+      printer->Print(*vars, "WithSplitStreamingMethod_$method_name$<");
+    }
+    if (service->method(i)->NoStreaming()) {
+      printer->Print(*vars, "WithStreamedUnaryMethod_$method_name$<");
+    }
+  }
+  printer->Print("Service");
+  for (int i = 0; i < service->method_count(); ++i) {
+    if (service->method(i)->NoStreaming() ||
+	service->method(i)->ServerOnlyStreaming()) {
+      printer->Print(" >");
+    }
+  }
+  printer->Print(" StreamedService;\n");
+
   printer->Outdent();
   printer->Print("};\n");
   printer->Print(service->GetTrailingComments().c_str());
-- 
GitLab