From 8c8d0aa1d881fbcf393a73f99b86ed29a866f8ff Mon Sep 17 00:00:00 2001
From: Craig Tiller <ctiller@google.com>
Date: Thu, 12 Feb 2015 11:38:36 -0800
Subject: [PATCH] Async API progress

---
 include/grpc++/impl/service_type.h | 18 +++++--
 include/grpc++/server.h            |  3 +-
 include/grpc++/server_builder.h    |  7 ++-
 src/compiler/cpp_generator.cc      | 67 +++++++++++++++++---------
 src/cpp/server/server.cc           | 75 +++++++++++++++++++-----------
 src/cpp/server/server_builder.cc   | 22 +++++----
 6 files changed, 129 insertions(+), 63 deletions(-)

diff --git a/include/grpc++/impl/service_type.h b/include/grpc++/impl/service_type.h
index 0684f322d8..30654553ad 100644
--- a/include/grpc++/impl/service_type.h
+++ b/include/grpc++/impl/service_type.h
@@ -37,19 +37,29 @@
 namespace grpc {
 
 class RpcService;
+class Server;
 
 class SynchronousService {
  public:
   virtual ~SynchronousService() {}
-  virtual RpcService *service() = 0;
+  virtual RpcService* service() = 0;
 };
 
 class AsynchronousService {
  public:
-  virtual ~AsynchronousService() {}
-  virtual RpcService *service() = 0;
+  AsynchronousService(CompletionQueue* cq, const char** method_names, size_t method_count) : cq_(cq), method_names_(method_names), method_count_(method_count) {}
+
+  CompletionQueue* completion_queue() const { return cq_; }
+
+ private:
+  friend class Server;
+  CompletionQueue* const cq_;
+  Server* server_ = nullptr;
+  const char**const method_names_;
+  size_t method_count_;
+  std::vector<void*> request_args_;
 };
 
 }  // namespace grpc
 
-#endif // __GRPCPP_IMPL_SERVICE_TYPE_H__
\ No newline at end of file
+#endif  // __GRPCPP_IMPL_SERVICE_TYPE_H__
\ No newline at end of file
diff --git a/include/grpc++/server.h b/include/grpc++/server.h
index 98f3f17197..77aac75076 100644
--- a/include/grpc++/server.h
+++ b/include/grpc++/server.h
@@ -53,7 +53,7 @@ class Message;
 }  // namespace google
 
 namespace grpc {
-class AsyncServerContext;
+class AsynchronousService;
 class RpcService;
 class RpcServiceMethod;
 class ServerCredentials;
@@ -79,6 +79,7 @@ class Server final : private CallHook {
   // Register a service. This call does not take ownership of the service.
   // The service must exist for the lifetime of the Server instance.
   bool RegisterService(RpcService* service);
+  bool RegisterAsyncService(AsynchronousService* service);
   // Add a listening port. Can be called multiple times.
   int AddPort(const grpc::string& addr);
   // Start the server.
diff --git a/include/grpc++/server_builder.h b/include/grpc++/server_builder.h
index 8b4c81bc87..a550a53afb 100644
--- a/include/grpc++/server_builder.h
+++ b/include/grpc++/server_builder.h
@@ -42,6 +42,7 @@
 namespace grpc {
 
 class AsynchronousService;
+class CompletionQueue;
 class RpcService;
 class Server;
 class ServerCredentials;
@@ -57,7 +58,11 @@ class ServerBuilder {
   // BuildAndStart().
   void RegisterService(SynchronousService* service);
 
-  void RegisterAsyncService(AsynchronousService *service);
+  // Register an asynchronous service. New calls will be delevered to cq.
+  // This call does not take ownership of the service or completion queue.
+  // The service and completion queuemust exist for the lifetime of the Server
+  // instance returned by BuildAndStart().
+  void RegisterAsyncService(AsynchronousService* service);
 
   // Add a listening port. Can be called multiple times.
   void AddPort(const grpc::string& addr);
diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc
index e29cfadcef..4a31ff949e 100644
--- a/src/compiler/cpp_generator.cc
+++ b/src/compiler/cpp_generator.cc
@@ -41,10 +41,18 @@
 #include <google/protobuf/descriptor.pb.h>
 #include <google/protobuf/io/printer.h>
 #include <google/protobuf/io/zero_copy_stream_impl_lite.h>
+#include <sstream>
 
 namespace grpc_cpp_generator {
 namespace {
 
+template <class T>
+std::string as_string(T x) {
+  std::ostringstream out;
+  out << x;
+  return out.str();
+}
+
 bool NoStreaming(const google::protobuf::MethodDescriptor *method) {
   return !method->client_streaming() && !method->server_streaming();
 }
@@ -113,6 +121,7 @@ std::string GetHeaderIncludes(const google::protobuf::FileDescriptor *file) {
       "#include <grpc++/status.h>\n"
       "\n"
       "namespace grpc {\n"
+      "class CompletionQueue;\n"
       "class ChannelInterface;\n"
       "class RpcService;\n"
       "class ServerContext;\n";
@@ -325,16 +334,13 @@ void PrintHeaderService(google::protobuf::io::Printer *printer,
       "class AsyncService final : public ::grpc::AsynchronousService {\n"
       " public:\n");
   printer->Indent();
-  printer->Print("AsyncService() : service_(nullptr) {}\n");
+  (*vars)["MethodCount"] = as_string(service->method_count());
+  printer->Print("explicit AsyncService(::grpc::CompletionQueue* cq);\n");
   printer->Print("~AsyncService();\n");
   for (int i = 0; i < service->method_count(); ++i) {
     PrintHeaderServerMethodAsync(printer, service->method(i), vars);
   }
-  printer->Print("::grpc::RpcService* service() override;\n");
   printer->Outdent();
-  printer->Print(
-      " private:\n"
-      "  ::grpc::RpcService* service_;\n");
   printer->Print("};\n");
 
   printer->Outdent();
@@ -369,7 +375,7 @@ void PrintSourceClientMethod(google::protobuf::io::Printer *printer,
                    "const $Request$& request, $Response$* response) {\n");
     printer->Print(*vars,
                    "return ::grpc::BlockingUnaryCall(channel(),"
-                   "::grpc::RpcMethod(\"/$Package$$Service$/$Method$\"), "
+                   "::grpc::RpcMethod($Service$_method_names[$Idx$]), "
                    "context, request, response);\n"
                    "}\n\n");
   } else if (ClientOnlyStreaming(method)) {
@@ -380,7 +386,7 @@ void PrintSourceClientMethod(google::protobuf::io::Printer *printer,
     printer->Print(*vars,
                    "  return new ::grpc::ClientWriter< $Request$>("
                    "channel(),"
-                   "::grpc::RpcMethod(\"/$Package$$Service$/$Method$\", "
+                   "::grpc::RpcMethod($Service$_method_names[$Idx$], "
                    "::grpc::RpcMethod::RpcType::CLIENT_STREAMING), "
                    "context, response);\n"
                    "}\n\n");
@@ -392,7 +398,7 @@ void PrintSourceClientMethod(google::protobuf::io::Printer *printer,
     printer->Print(*vars,
                    "  return new ::grpc::ClientReader< $Response$>("
                    "channel(),"
-                   "::grpc::RpcMethod(\"/$Package$$Service$/$Method$\", "
+                   "::grpc::RpcMethod($Service$_method_names[$Idx$], "
                    "::grpc::RpcMethod::RpcType::SERVER_STREAMING), "
                    "context, *request);\n"
                    "}\n\n");
@@ -405,7 +411,7 @@ void PrintSourceClientMethod(google::protobuf::io::Printer *printer,
         *vars,
         "  return new ::grpc::ClientReaderWriter< $Request$, $Response$>("
         "channel(),"
-        "::grpc::RpcMethod(\"/$Package$$Service$/$Method$\", "
+        "::grpc::RpcMethod($Service$_method_names[$Idx$], "
         "::grpc::RpcMethod::RpcType::BIDI_STREAMING), "
         "context);\n"
         "}\n\n");
@@ -462,9 +468,10 @@ void PrintSourceServerMethod(google::protobuf::io::Printer *printer,
   }
 }
 
-void PrintSourceServerAsyncMethod(google::protobuf::io::Printer *printer,
-                             const google::protobuf::MethodDescriptor *method,
-                             std::map<std::string, std::string> *vars) {
+void PrintSourceServerAsyncMethod(
+    google::protobuf::io::Printer *printer,
+    const google::protobuf::MethodDescriptor *method,
+    std::map<std::string, std::string> *vars) {
   (*vars)["Method"] = method->name();
   (*vars)["Request"] =
       grpc_cpp_generator::ClassName(method->input_type(), true);
@@ -494,11 +501,12 @@ void PrintSourceServerAsyncMethod(google::protobuf::io::Printer *printer,
                    "::grpc::CompletionQueue* cq, void* tag) {\n");
     printer->Print("}\n\n");
   } else if (BidiStreaming(method)) {
-    printer->Print(*vars,
-                   "void $Service$::AsyncService::Request$Method$("
-                   "::grpc::ServerContext* context, "
-                   "::grpc::ServerAsyncReaderWriter< $Response$, $Request$>* stream, "
-                   "::grpc::CompletionQueue* cq, void *tag) {\n");
+    printer->Print(
+        *vars,
+        "void $Service$::AsyncService::Request$Method$("
+        "::grpc::ServerContext* context, "
+        "::grpc::ServerAsyncReaderWriter< $Response$, $Request$>* stream, "
+        "::grpc::CompletionQueue* cq, void *tag) {\n");
     printer->Print("}\n\n");
   }
 }
@@ -507,6 +515,14 @@ void PrintSourceService(google::protobuf::io::Printer *printer,
                         const google::protobuf::ServiceDescriptor *service,
                         std::map<std::string, std::string> *vars) {
   (*vars)["Service"] = service->name();
+
+  printer->Print(*vars, "static const char* $Service$_method_names[] = {\n");
+  for (int i = 0; i < service->method_count(); ++i) {
+    (*vars)["Method"] = service->method(i)->name();
+    printer->Print(*vars, "  \"/$Package$$Service$/$Method$\",\n");
+  }
+  printer->Print(*vars, "};\n\n");
+
   printer->Print(
       *vars,
       "$Service$::Stub* $Service$::NewStub("
@@ -516,9 +532,17 @@ void PrintSourceService(google::protobuf::io::Printer *printer,
       "  return stub;\n"
       "};\n\n");
   for (int i = 0; i < service->method_count(); ++i) {
+    (*vars)["Idx"] = as_string(i);
     PrintSourceClientMethod(printer, service->method(i), vars);
   }
 
+  (*vars)["MethodCount"] = as_string(service->method_count());
+  printer->Print(
+      *vars,
+      "$Service$::AsyncService::AsyncService(::grpc::CompletionQueue* cq) : "
+      "::grpc::AsynchronousService(cq, $Service$_method_names, $MethodCount$) "
+      "{}\n\n");
+
   printer->Print(*vars,
                  "$Service$::Service::~Service() {\n"
                  "  delete service_;\n"
@@ -537,6 +561,7 @@ void PrintSourceService(google::protobuf::io::Printer *printer,
   printer->Print("service_ = new ::grpc::RpcService();\n");
   for (int i = 0; i < service->method_count(); ++i) {
     const google::protobuf::MethodDescriptor *method = service->method(i);
+    (*vars)["Idx"] = as_string(i);
     (*vars)["Method"] = method->name();
     (*vars)["Request"] =
         grpc_cpp_generator::ClassName(method->input_type(), true);
@@ -546,7 +571,7 @@ void PrintSourceService(google::protobuf::io::Printer *printer,
       printer->Print(
           *vars,
           "service_->AddMethod(new ::grpc::RpcServiceMethod(\n"
-          "    \"/$Package$$Service$/$Method$\",\n"
+          "    $Service$_method_names[$Idx$],\n"
           "    ::grpc::RpcMethod::NORMAL_RPC,\n"
           "    new ::grpc::RpcMethodHandler< $Service$::Service, $Request$, "
           "$Response$>(\n"
@@ -558,7 +583,7 @@ void PrintSourceService(google::protobuf::io::Printer *printer,
       printer->Print(
           *vars,
           "service_->AddMethod(new ::grpc::RpcServiceMethod(\n"
-          "    \"/$Package$$Service$/$Method$\",\n"
+          "    $Service$_method_names[$Idx$],\n"
           "    ::grpc::RpcMethod::CLIENT_STREAMING,\n"
           "    new ::grpc::ClientStreamingHandler< "
           "$Service$::Service, $Request$, $Response$>(\n"
@@ -571,7 +596,7 @@ void PrintSourceService(google::protobuf::io::Printer *printer,
       printer->Print(
           *vars,
           "service_->AddMethod(new ::grpc::RpcServiceMethod(\n"
-          "    \"/$Package$$Service$/$Method$\",\n"
+          "    $Service$_method_names[$Idx$],\n"
           "    ::grpc::RpcMethod::SERVER_STREAMING,\n"
           "    new ::grpc::ServerStreamingHandler< "
           "$Service$::Service, $Request$, $Response$>(\n"
@@ -584,7 +609,7 @@ void PrintSourceService(google::protobuf::io::Printer *printer,
       printer->Print(
           *vars,
           "service_->AddMethod(new ::grpc::RpcServiceMethod(\n"
-          "    \"/$Package$$Service$/$Method$\",\n"
+          "    $Service$_method_names[$Idx$],\n"
           "    ::grpc::RpcMethod::BIDI_STREAMING,\n"
           "    new ::grpc::BidiStreamingHandler< "
           "$Service$::Service, $Request$, $Response$>(\n"
diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc
index 90a2863b0c..20dd135a86 100644
--- a/src/cpp/server/server.cc
+++ b/src/cpp/server/server.cc
@@ -39,6 +39,7 @@
 #include <grpc/support/log.h>
 #include <grpc++/completion_queue.h>
 #include <grpc++/impl/rpc_service_method.h>
+#include <grpc++/impl/service_type.h>
 #include <grpc++/server_context.h>
 #include <grpc++/server_credentials.h>
 #include <grpc++/thread_pool_interface.h>
@@ -47,8 +48,8 @@
 
 namespace grpc {
 
-Server::Server(ThreadPoolInterface *thread_pool, bool thread_pool_owned,
-               ServerCredentials *creds)
+Server::Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned,
+               ServerCredentials* creds)
     : started_(false),
       shutdown_(false),
       num_running_cb_(0),
@@ -56,7 +57,8 @@ Server::Server(ThreadPoolInterface *thread_pool, bool thread_pool_owned,
       thread_pool_owned_(thread_pool_owned),
       secure_(creds != nullptr) {
   if (creds) {
-    server_ = grpc_secure_server_create(creds->GetRawCreds(), cq_.cq(), nullptr);
+    server_ =
+        grpc_secure_server_create(creds->GetRawCreds(), cq_.cq(), nullptr);
   } else {
     server_ = grpc_server_create(cq_.cq(), nullptr);
   }
@@ -81,10 +83,11 @@ Server::~Server() {
   }
 }
 
-bool Server::RegisterService(RpcService *service) {
+bool Server::RegisterService(RpcService* service) {
   for (int i = 0; i < service->GetMethodCount(); ++i) {
-    RpcServiceMethod *method = service->GetMethod(i);
-    void *tag = grpc_server_register_method(server_, method->name(), nullptr, cq_.cq());
+    RpcServiceMethod* method = service->GetMethod(i);
+    void* tag =
+        grpc_server_register_method(server_, method->name(), nullptr, cq_.cq());
     if (!tag) {
       gpr_log(GPR_DEBUG, "Attempt to register %s multiple times",
               method->name());
@@ -95,7 +98,24 @@ bool Server::RegisterService(RpcService *service) {
   return true;
 }
 
-int Server::AddPort(const grpc::string &addr) {
+bool Server::RegisterAsyncService(AsynchronousService* service) {
+  GPR_ASSERT(service->server_ == nullptr && "Can only register an asynchronous service against one server.");
+  service->server_ = this;
+  service->request_args_.reserve(service->method_count_);
+  for (size_t i = 0; i < service->method_count_; ++i) {
+    void* tag = grpc_server_register_method(server_, service->method_names_[i], nullptr,
+                                            service->completion_queue()->cq());
+    if (!tag) {
+      gpr_log(GPR_DEBUG, "Attempt to register %s multiple times",
+              service->method_names_[i]);
+      return false;
+    }
+    service->request_args_.push_back(tag);
+  }
+  return true;
+}
+
+int Server::AddPort(const grpc::string& addr) {
   GPR_ASSERT(!started_);
   if (secure_) {
     return grpc_server_add_secure_http2_port(server_, addr.c_str());
@@ -106,7 +126,7 @@ int Server::AddPort(const grpc::string &addr) {
 
 class Server::MethodRequestData final : public CompletionQueueTag {
  public:
-  MethodRequestData(RpcServiceMethod *method, void *tag)
+  MethodRequestData(RpcServiceMethod* method, void* tag)
       : method_(method),
         tag_(tag),
         has_request_payload_(method->method_type() == RpcMethod::NORMAL_RPC ||
@@ -118,33 +138,33 @@ class Server::MethodRequestData final : public CompletionQueueTag {
     grpc_metadata_array_init(&request_metadata_);
   }
 
-  static MethodRequestData *Wait(CompletionQueue *cq, bool *ok) {
-    void *tag = nullptr;
+  static MethodRequestData* Wait(CompletionQueue* cq, bool* ok) {
+    void* tag = nullptr;
     *ok = false;
     if (!cq->Next(&tag, ok)) {
       return nullptr;
     }
-    auto *mrd = static_cast<MethodRequestData *>(tag);
+    auto* mrd = static_cast<MethodRequestData*>(tag);
     GPR_ASSERT(mrd->in_flight_);
     return mrd;
   }
 
-  void Request(grpc_server *server) {
+  void Request(grpc_server* server) {
     GPR_ASSERT(!in_flight_);
     in_flight_ = true;
     cq_ = grpc_completion_queue_create();
     GPR_ASSERT(GRPC_CALL_OK ==
                grpc_server_request_registered_call(
                    server, tag_, &call_, &deadline_, &request_metadata_,
-                   has_request_payload_ ? &request_payload_ : nullptr, 
-                   cq_, this));
+                   has_request_payload_ ? &request_payload_ : nullptr, cq_,
+                   this));
   }
 
-  void FinalizeResult(void **tag, bool *status) override {}
+  void FinalizeResult(void** tag, bool* status) override {}
 
   class CallData {
    public:
-    explicit CallData(Server *server, MethodRequestData *mrd)
+    explicit CallData(Server* server, MethodRequestData* mrd)
         : cq_(mrd->cq_),
           call_(mrd->call_, server, &cq_),
           ctx_(mrd->deadline_, mrd->request_metadata_.metadata,
@@ -196,21 +216,21 @@ class Server::MethodRequestData final : public CompletionQueueTag {
     ServerContext ctx_;
     const bool has_request_payload_;
     const bool has_response_payload_;
-    grpc_byte_buffer *request_payload_;
-    RpcServiceMethod *const method_;
+    grpc_byte_buffer* request_payload_;
+    RpcServiceMethod* const method_;
   };
 
  private:
-  RpcServiceMethod *const method_;
-  void *const tag_;
+  RpcServiceMethod* const method_;
+  void* const tag_;
   bool in_flight_ = false;
   const bool has_request_payload_;
   const bool has_response_payload_;
-  grpc_call *call_;
+  grpc_call* call_;
   gpr_timespec deadline_;
   grpc_metadata_array request_metadata_;
-  grpc_byte_buffer *request_payload_;
-  grpc_completion_queue *cq_;
+  grpc_byte_buffer* request_payload_;
+  grpc_completion_queue* cq_;
 };
 
 bool Server::Start() {
@@ -220,7 +240,7 @@ bool Server::Start() {
 
   // Start processing rpcs.
   if (!methods_.empty()) {
-    for (auto &m : methods_) {
+    for (auto& m : methods_) {
       m.Request(server_);
     }
 
@@ -246,14 +266,13 @@ void Server::Shutdown() {
   }
 }
 
-void Server::PerformOpsOnCall(CallOpBuffer *buf, Call *call) {
+void Server::PerformOpsOnCall(CallOpBuffer* buf, Call* call) {
   static const size_t MAX_OPS = 8;
   size_t nops = MAX_OPS;
   grpc_op ops[MAX_OPS];
   buf->FillOps(ops, &nops);
   GPR_ASSERT(GRPC_CALL_OK ==
-             grpc_call_start_batch(call->call(), ops, nops,
-                                   buf));
+             grpc_call_start_batch(call->call(), ops, nops, buf));
 }
 
 void Server::ScheduleCallback() {
@@ -267,7 +286,7 @@ void Server::ScheduleCallback() {
 void Server::RunRpc() {
   // Wait for one more incoming rpc.
   bool ok;
-  auto *mrd = MethodRequestData::Wait(&cq_, &ok);
+  auto* mrd = MethodRequestData::Wait(&cq_, &ok);
   if (mrd) {
     ScheduleCallback();
     if (ok) {
diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc
index d6bcb9313a..dd23e929b1 100644
--- a/src/cpp/server/server_builder.cc
+++ b/src/cpp/server/server_builder.cc
@@ -43,25 +43,25 @@ namespace grpc {
 
 ServerBuilder::ServerBuilder() {}
 
-void ServerBuilder::RegisterService(SynchronousService *service) {
+void ServerBuilder::RegisterService(SynchronousService* service) {
   services_.push_back(service->service());
 }
 
-void ServerBuilder::RegisterAsyncService(AsynchronousService *service) {
+void ServerBuilder::RegisterAsyncService(AsynchronousService* service) {
   async_services_.push_back(service);
 }
 
-void ServerBuilder::AddPort(const grpc::string &addr) {
+void ServerBuilder::AddPort(const grpc::string& addr) {
   ports_.push_back(addr);
 }
 
 void ServerBuilder::SetCredentials(
-    const std::shared_ptr<ServerCredentials> &creds) {
+    const std::shared_ptr<ServerCredentials>& creds) {
   GPR_ASSERT(!creds_);
   creds_ = creds;
 }
 
-void ServerBuilder::SetThreadPool(ThreadPoolInterface *thread_pool) {
+void ServerBuilder::SetThreadPool(ThreadPoolInterface* thread_pool) {
   thread_pool_ = thread_pool;
 }
 
@@ -77,13 +77,19 @@ std::unique_ptr<Server> ServerBuilder::BuildAndStart() {
     thread_pool_ = new ThreadPool(cores);
     thread_pool_owned = true;
   }
-  std::unique_ptr<Server> server(new Server(thread_pool_, thread_pool_owned, creds_.get()));
-  for (auto *service : services_) {
+  std::unique_ptr<Server> server(
+      new Server(thread_pool_, thread_pool_owned, creds_.get()));
+  for (auto* service : services_) {
     if (!server->RegisterService(service)) {
       return nullptr;
     }
   }
-  for (auto &port : ports_) {
+  for (auto* service : async_services_) {
+    if (!server->RegisterAsyncService(service)) {
+      return nullptr;
+    }
+  }
+  for (auto& port : ports_) {
     if (!server->AddPort(port)) {
       return nullptr;
     }
-- 
GitLab