From b6df94ad220b9e48d306150184848db614cc3506 Mon Sep 17 00:00:00 2001
From: vjpai <vpai@google.com>
Date: Mon, 30 Nov 2015 15:52:50 -0800
Subject: [PATCH] WIP

---
 Makefile                             |   2 +-
 include/grpc++/support/byte_buffer.h |   3 +
 src/cpp/util/byte_buffer.cc          |   6 ++
 test/cpp/qps/client.h                |  83 ++++++++++++------
 test/cpp/qps/client_async.cc         | 120 +++++++++++++++++++++++++--
 test/cpp/qps/server.h                |   3 +-
 6 files changed, 179 insertions(+), 38 deletions(-)

diff --git a/Makefile b/Makefile
index 9d94ee8599..b48974379b 100644
--- a/Makefile
+++ b/Makefile
@@ -3077,7 +3077,7 @@ test_c: buildtests_c
 	$(Q) $(BINDIR)/$(CONFIG)/h2_full+poll_ping_pong_streaming_nosec_test || ( echo test h2_full+poll_ping_pong_streaming_nosec_test failed ; exit 1 )
 	$(E) "[RUN]     Testing h2_full+poll_registered_call_nosec_test"
 	$(Q) $(BINDIR)/$(CONFIG)/h2_full+poll_registered_call_nosec_test || ( echo test h2_full+poll_registered_call_nosec_test failed ; exit 1 )
-	$(E) "[RUN]     Testing h2_full+poll_request_with_flags_nosec_test"
+	$(E) "[RUN]     Testing h2_full+poll_riequest_with_flags_nosec_test"
 	$(Q) $(BINDIR)/$(CONFIG)/h2_full+poll_request_with_flags_nosec_test || ( echo test h2_full+poll_request_with_flags_nosec_test failed ; exit 1 )
 	$(E) "[RUN]     Testing h2_full+poll_request_with_payload_nosec_test"
 	$(Q) $(BINDIR)/$(CONFIG)/h2_full+poll_request_with_payload_nosec_test || ( echo test h2_full+poll_request_with_payload_nosec_test failed ; exit 1 )
diff --git a/include/grpc++/support/byte_buffer.h b/include/grpc++/support/byte_buffer.h
index 9d19b07708..6f29bdfcd5 100644
--- a/include/grpc++/support/byte_buffer.h
+++ b/include/grpc++/support/byte_buffer.h
@@ -66,6 +66,9 @@ class ByteBuffer GRPC_FINAL {
   /// Buffer size in bytes.
   size_t Length() const;
 
+  /// Move contents from \a bbuf and clear \a bbuf
+  void MoveFrom(ByteBuffer* bbuf);
+
  private:
   friend class SerializationTraits<ByteBuffer, void>;
 
diff --git a/src/cpp/util/byte_buffer.cc b/src/cpp/util/byte_buffer.cc
index 755234d7e8..4d73599542 100644
--- a/src/cpp/util/byte_buffer.cc
+++ b/src/cpp/util/byte_buffer.cc
@@ -79,4 +79,10 @@ size_t ByteBuffer::Length() const {
   }
 }
 
+void ByteBuffer::MoveFrom(ByteBuffer* bbuf) {
+  Clear(); // in case we already had something, but we shouldn't use this then
+  buffer_ = bbuf->buffer_;
+  bbuf->buffer_ = nullptr;
+}
+
 }  // namespace grpc
diff --git a/test/cpp/qps/client.h b/test/cpp/qps/client.h
index f4400692fe..30a8030f51 100644
--- a/test/cpp/qps/client.h
+++ b/test/cpp/qps/client.h
@@ -66,36 +66,66 @@ namespace testing {
 typedef std::chrono::high_resolution_clock grpc_time_source;
 typedef std::chrono::time_point<grpc_time_source> grpc_time;
 
+namespace ClientRequestCreation {
+template <class RequestType>
+void CreateRequest(RequestType *req, const PayloadConfig&) {
+  // this template must be specialized
+  // fail with an assertion rather than a compile-time
+  // check since these only happen at the beginning anyway
+  GPR_ASSERT(false);
+}
+    
+template <>
+void CreateRequest<SimpleRequest>(SimpleRequest *req, const PayloadConfig& payload_config) {
+  if (payload_config.has_bytebuf_params()) {
+    GPR_ASSERT(false);  // not appropriate for this specialization
+  } else if (payload_config.has_simple_params()) {
+    req->set_response_type(grpc::testing::PayloadType::COMPRESSABLE);
+    req->set_response_size(payload_config.simple_params().resp_size());
+    req->mutable_payload()->set_type(grpc::testing::PayloadType::COMPRESSABLE);
+    int size = payload_config.simple_params().req_size();
+    std::unique_ptr<char[]> body(new char[size]);
+    req->mutable_payload()->set_body(body.get(), size);
+  } else if (payload_config.has_complex_params()) {
+    GPR_ASSERT(false);  // not appropriate for this specialization
+  } else {
+    // default should be simple proto without payloads
+    req->set_response_type(grpc::testing::PayloadType::COMPRESSABLE);
+    req->set_response_size(0);
+    req->mutable_payload()->set_type(grpc::testing::PayloadType::COMPRESSABLE);
+  }
+}
+template <>
+void CreateRequest<ByteBuffer>(ByteBuffer *req, const PayloadConfig& payload_config) {
+  if (payload_config.has_bytebuf_params()) {
+    if (payload_config.req_size() > 0) {
+      std::unique_ptr<char> buf(new char[payload_config.req_size()]);
+      gpr_slice_from_copied_buffer(buf.get(), payload_config.req_size());
+      Slice slice(s, Slice::STEAL_REF);
+      std::unique_ptr<ByteBuffer> bbuf(new ByteBuffer(&slice, 1));
+      req->MoveFrom(bbuf.get());
+    } else {
+      GPR_ASSERT(false);  // not appropriate for this specialization
+    }
+  }
+}
+}
+ 
+template <class StubType, class RequestType>
 class Client {
  public:
-  explicit Client(const ClientConfig& config)
+   Client(const ClientConfig& config,
+	  std::function<std::unique_ptr<StubType>(std::shared_ptr<Channel>)> create_stub)
       : channels_(config.client_channels()),
+        create_stub_(create_stub),
         timer_(new Timer),
         interarrival_timer_() {
     for (int i = 0; i < config.client_channels(); i++) {
       channels_[i].init(config.server_targets(i % config.server_targets_size()),
                         config);
     }
-    if (config.payload_config().has_bytebuf_params()) {
-      GPR_ASSERT(false);  // not yet implemented
-    } else if (config.payload_config().has_simple_params()) {
-      request_.set_response_type(grpc::testing::PayloadType::COMPRESSABLE);
-      request_.set_response_size(
-          config.payload_config().simple_params().resp_size());
-      request_.mutable_payload()->set_type(
-          grpc::testing::PayloadType::COMPRESSABLE);
-      int size = config.payload_config().simple_params().req_size();
-      std::unique_ptr<char[]> body(new char[size]);
-      request_.mutable_payload()->set_body(body.get(), size);
-    } else if (config.payload_config().has_complex_params()) {
-      GPR_ASSERT(false);  // not yet implemented
-    } else {
-      // default should be simple proto without payloads
-      request_.set_response_type(grpc::testing::PayloadType::COMPRESSABLE);
-      request_.set_response_size(0);
-      request_.mutable_payload()->set_type(
-          grpc::testing::PayloadType::COMPRESSABLE);
-    }
+
+    ClientRequestCreation::CreateRequest<RequestType>(&request_, config.payload_config());
   }
   virtual ~Client() {}
 
@@ -134,7 +164,7 @@ class Client {
   }
 
  protected:
-  SimpleRequest request_;
+  RequestType request_;
   bool closed_loop_;
 
   class ClientChannelInfo {
@@ -154,16 +184,17 @@ class Client {
           target, config.security_params().server_host_override(),
           config.has_security_params(),
           !config.security_params().use_test_ca());
-      stub_ = BenchmarkService::NewStub(channel_);
+      stub_ = create_stub_(channel_);
     }
     Channel* get_channel() { return channel_.get(); }
-    BenchmarkService::Stub* get_stub() { return stub_.get(); }
+    StubType* get_stub() { return stub_.get(); }
 
    private:
     std::shared_ptr<Channel> channel_;
-    std::unique_ptr<BenchmarkService::Stub> stub_;
+    std::unique_ptr<StubType> stub_;
   };
   std::vector<ClientChannelInfo> channels_;
+  std::function<std::unique_ptr<StubType>(std::shared_ptr<Channel>)> create_stub_;
 
   void StartThreads(size_t num_threads) {
     for (size_t i = 0; i < num_threads; i++) {
@@ -306,7 +337,7 @@ class Client {
     size_t idx_;
     std::thread impl_;
   };
-
+  
   std::vector<std::unique_ptr<Thread>> threads_;
   std::unique_ptr<Timer> timer_;
 
diff --git a/test/cpp/qps/client_async.cc b/test/cpp/qps/client_async.cc
index 9594179822..c05774c410 100644
--- a/test/cpp/qps/client_async.cc
+++ b/test/cpp/qps/client_async.cc
@@ -147,13 +147,14 @@ class ClientRpcContextUnaryImpl : public ClientRpcContext {
 
 typedef std::forward_list<ClientRpcContext*> context_list;
 
-class AsyncClient : public Client {
+template <class StubType, class RequestType>
+class AsyncClient : public Client<StubType, RequestType> {
  public:
-  explicit AsyncClient(
+  AsyncClient(
       const ClientConfig& config,
-      std::function<ClientRpcContext*(int, BenchmarkService::Stub*,
-                                      const SimpleRequest&)> setup_ctx)
-      : Client(config),
+      std::function<ClientRpcContext*(int, StubType*, const RequestType&)> setup_ctx,
+      std::function<std::unique_ptr<StubType>(std::shared_ptr<Channel>)> create_stub)
+    : Client(config, create_stub),
         channel_lock_(new std::mutex[config.client_channels()]),
         contexts_(config.client_channels()),
         max_outstanding_per_channel_(config.outstanding_rpcs_per_channel()),
@@ -343,10 +344,10 @@ class AsyncClient : public Client {
   int pref_channel_inc_;
 };
 
-class AsyncUnaryClient GRPC_FINAL : public AsyncClient {
+class AsyncUnaryClient GRPC_FINAL : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
  public:
   explicit AsyncUnaryClient(const ClientConfig& config)
-      : AsyncClient(config, SetupCtx) {
+    : AsyncClient(config, SetupCtx, BenchmarkService::NewStub) {
     StartThreads(config.async_client_threads());
   }
   ~AsyncUnaryClient() GRPC_OVERRIDE { EndThreads(); }
@@ -437,10 +438,10 @@ class ClientRpcContextStreamingImpl : public ClientRpcContext {
       stream_;
 };
 
-class AsyncStreamingClient GRPC_FINAL : public AsyncClient {
+class AsyncStreamingClient GRPC_FINAL : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
  public:
   explicit AsyncStreamingClient(const ClientConfig& config)
-      : AsyncClient(config, SetupCtx) {
+      : AsyncClient(config, SetupCtx, BenchmarkService::NewStub) {
     // async streaming currently only supports closed loop
     GPR_ASSERT(closed_loop_);
 
@@ -467,12 +468,113 @@ class AsyncStreamingClient GRPC_FINAL : public AsyncClient {
   }
 };
 
+class ClientGenericRpcContextStreamingImpl : public ClientRpcContext {
+ public:
+  ClientGenericRpcContextStreamingImpl(
+      int channel_id, grpc::GenericStub* stub, const ByteBuffer& req,
+      std::function<std::unique_ptr<
+          grpc::GenericClientAsyncReaderWriter>(
+          grpc::GenericStub*, grpc::ClientContext*, const grpc::string& method_name,
+	  CompletionQueue*, void*)> start_req,
+      std::function<void(grpc::Status, ByteBuffer*)> on_done)
+      : ClientRpcContext(channel_id),
+        context_(),
+        stub_(stub),
+        req_(req),
+        response_(),
+        next_state_(&ClientGenericRpcContextStreamingImpl::ReqSent),
+        callback_(on_done),
+        start_req_(start_req),
+        start_(Timer::Now()) {}
+  ~ClientGenericRpcContextStreamingImpl() GRPC_OVERRIDE {}
+  bool RunNextState(bool ok, Histogram* hist) GRPC_OVERRIDE {
+    return (this->*next_state_)(ok, hist);
+  }
+  ClientRpcContext* StartNewClone() GRPC_OVERRIDE {
+    return new ClientGenericRpcContextStreamingImpl(channel_id_, stub_, req_,
+						    start_req_, callback_);
+  }
+  void Start(CompletionQueue* cq) GRPC_OVERRIDE {
+    const grpc::string kMethodName("/grpc.testing.BenchmarkService/StreamingCall");
+    stream_ = start_req_(stub_, &context_, kMethodName, cq, ClientRpcContext::tag(this));
+  }
+
+ private:
+  bool ReqSent(bool ok, Histogram*) { return StartWrite(ok); }
+  bool StartWrite(bool ok) {
+    if (!ok) {
+      return (false);
+    }
+    start_ = Timer::Now();
+    next_state_ = &ClientGenericRpcContextStreamingImpl::WriteDone;
+    stream_->Write(req_, ClientRpcContext::tag(this));
+    return true;
+  }
+  bool WriteDone(bool ok, Histogram*) {
+    if (!ok) {
+      return (false);
+    }
+    next_state_ = &ClientGenericRpcContextStreamingImpl::ReadDone;
+    stream_->Read(&response_, ClientRpcContext::tag(this));
+    return true;
+  }
+  bool ReadDone(bool ok, Histogram* hist) {
+    hist->Add((Timer::Now() - start_) * 1e9);
+    return StartWrite(ok);
+  }
+  grpc::ClientContext context_;
+  grpc::GenericStub* stub_;
+  ByteBuffer req_;
+  ByteBuffer response_;
+  bool (ClientGenericRpcContextStreamingImpl::*next_state_)(bool, Histogram*);
+  std::function<void(grpc::Status, ByteBuffer*)> callback_;
+  std::function<
+      std::unique_ptr<grpc::GenericClientAsyncReaderWriter>(
+          grpc::GenericStub*, grpc::ClientContext*, const grpc::string&, CompletionQueue*,
+          void*)> start_req_;
+  grpc::Status status_;
+  double start_;
+  std::unique_ptr<grpc::GenericClientAsyncReaderWriter> stream_;
+};
+
+class GenericAsyncStreamingClient GRPC_FINAL : public AsyncClient<grpc::GenericStub, ByteBuffer> {
+ public:
+  explicit GenericAsyncStreamingClient(const ClientConfig& config)
+    : AsyncClient(config, SetupCtx, grpc::GenericStub) {
+    // async streaming currently only supports closed loop
+    GPR_ASSERT(closed_loop_);
+
+    StartThreads(config.async_client_threads());
+  }
+
+  ~GenericAsyncStreamingClient() GRPC_OVERRIDE { EndThreads(); }
+
+ private:
+  static void CheckDone(grpc::Status s, ByteBuffer* response) {}
+  static std::unique_ptr<grpc::GenericClientAsyncReaderWriter>
+    StartReq(grpc::GenericStub* stub, grpc::ClientContext* ctx,
+	     const grpc::string& method_name, CompletionQueue* cq, void* tag) {
+    auto stream = stub->Call(ctx, method_name, cq, tag);
+    return stream;
+  };
+  static ClientRpcContext* SetupCtx(int channel_id,
+                                    grpc::GenericStub* stub,
+                                    const SimpleRequest& req) {
+    return new ClientRpcContextStreamingImpl<SimpleRequest, SimpleResponse>(
+        channel_id, stub, req, AsyncStreamingClient::StartReq,
+        AsyncStreamingClient::CheckDone);
+  }
+};
+
 std::unique_ptr<Client> CreateAsyncUnaryClient(const ClientConfig& args) {
   return std::unique_ptr<Client>(new AsyncUnaryClient(args));
 }
 std::unique_ptr<Client> CreateAsyncStreamingClient(const ClientConfig& args) {
   return std::unique_ptr<Client>(new AsyncStreamingClient(args));
 }
+std::unique_ptr<Client> CreateGenericAsyncStreamingClient(const ClientConfig& args) {
+  return std::unique_ptr<Client>(new GenericAsyncStreamingClient(args));
+}
 
 }  // namespace testing
 }  // namespace grpc
diff --git a/test/cpp/qps/server.h b/test/cpp/qps/server.h
index 6e81edc8ff..7c52443d4e 100644
--- a/test/cpp/qps/server.h
+++ b/test/cpp/qps/server.h
@@ -75,12 +75,11 @@ class Server {
   }
 
   static bool SetPayload(PayloadType type, int size, Payload* payload) {
-    PayloadType response_type = type;
     // TODO(yangg): Support UNCOMPRESSABLE payload.
     if (type != PayloadType::COMPRESSABLE) {
       return false;
     }
-    payload->set_type(response_type);
+    payload->set_type(type);
     std::unique_ptr<char[]> body(new char[size]());
     payload->set_body(body.get(), size);
     return true;
-- 
GitLab