From 21e5d2b2f10bc6d175caf7ad44d729517528a78a Mon Sep 17 00:00:00 2001
From: vjpai <vpai@google.com>
Date: Tue, 2 Feb 2016 09:36:36 -0800
Subject: [PATCH] Add a Quit RPC so that we can conveniently shut down the
 workers from the driver.

---
 src/proto/grpc/testing/control.proto  |  3 +++
 src/proto/grpc/testing/services.proto |  3 +++
 test/cpp/qps/driver.cc                | 13 +++++++++++++
 test/cpp/qps/driver.h                 |  1 +
 test/cpp/qps/qps_driver.cc            |  7 +++++++
 test/cpp/qps/qps_worker.cc            | 24 +++++++++++++++++++++---
 test/cpp/qps/qps_worker.h             |  6 ++++++
 test/cpp/qps/worker.cc                |  2 +-
 8 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/src/proto/grpc/testing/control.proto b/src/proto/grpc/testing/control.proto
index 2f352e652f..d135cb4d32 100644
--- a/src/proto/grpc/testing/control.proto
+++ b/src/proto/grpc/testing/control.proto
@@ -150,3 +150,6 @@ message ServerStatus {
   // Number of cores on the server. See gpr_cpu_num_cores.
   int32 cores = 3;
 }
+
+message Void {
+}
diff --git a/src/proto/grpc/testing/services.proto b/src/proto/grpc/testing/services.proto
index af285ceab8..57cd9ecf76 100644
--- a/src/proto/grpc/testing/services.proto
+++ b/src/proto/grpc/testing/services.proto
@@ -62,4 +62,7 @@ service WorkerService {
   // and once the shutdown has finished, the OK status is sent to terminate
   // this RPC.
   rpc RunClient(stream ClientArgs) returns (stream ClientStatus);
+
+  // Quit this worker
+  rpc QuitWorker(Void) returns (Void);
 }
diff --git a/test/cpp/qps/driver.cc b/test/cpp/qps/driver.cc
index 490156aec2..370a3a834a 100644
--- a/test/cpp/qps/driver.cc
+++ b/test/cpp/qps/driver.cc
@@ -283,5 +283,18 @@ std::unique_ptr<ScenarioResult> RunScenario(
   delete[] servers;
   return result;
 }
+
+void RunQuit() {
+  // Get client, server lists
+  auto workers = get_hosts("QPS_WORKERS");
+  for (size_t i = 0; i < workers.size(); i++) {
+    auto stub = WorkerService::NewStub(
+       CreateChannel(workers[i], InsecureChannelCredentials()));
+    Void dummy;
+    grpc::ClientContext ctx;
+    GPR_ASSERT(stub->QuitWorker(&ctx, dummy, &dummy).ok());
+  }
+}
+
 }  // namespace testing
 }  // namespace grpc
diff --git a/test/cpp/qps/driver.h b/test/cpp/qps/driver.h
index 2a7cf805e5..4b2b400c0c 100644
--- a/test/cpp/qps/driver.h
+++ b/test/cpp/qps/driver.h
@@ -70,6 +70,7 @@ std::unique_ptr<ScenarioResult> RunScenario(
     const grpc::testing::ServerConfig& server_config, size_t num_servers,
     int warmup_seconds, int benchmark_seconds, int spawn_local_worker_count);
 
+void RunQuit();
 }  // namespace testing
 }  // namespace grpc
 
diff --git a/test/cpp/qps/qps_driver.cc b/test/cpp/qps/qps_driver.cc
index aa3cb68821..1fe37b1667 100644
--- a/test/cpp/qps/qps_driver.cc
+++ b/test/cpp/qps/qps_driver.cc
@@ -77,6 +77,8 @@ DEFINE_double(pareto_alpha, -1.0, "Pareto alpha value");
 
 DEFINE_bool(secure_test, false, "Run a secure test");
 
+DEFINE_bool(quit, false, "Quit the workers");
+
 using grpc::testing::ClientConfig;
 using grpc::testing::ServerConfig;
 using grpc::testing::ClientType;
@@ -89,6 +91,11 @@ namespace grpc {
 namespace testing {
 
 static void QpsDriver() {
+  if (FLAGS_quit) {
+    RunQuit();
+    return;
+  }
+
   RpcType rpc_type;
   GPR_ASSERT(RpcType_Parse(FLAGS_rpc_type, &rpc_type));
 
diff --git a/test/cpp/qps/qps_worker.cc b/test/cpp/qps/qps_worker.cc
index 6316605aaf..f33b13b5b2 100644
--- a/test/cpp/qps/qps_worker.cc
+++ b/test/cpp/qps/qps_worker.cc
@@ -107,8 +107,8 @@ static std::unique_ptr<Server> CreateServer(const ServerConfig& config) {
 
 class WorkerServiceImpl GRPC_FINAL : public WorkerService::Service {
  public:
-  explicit WorkerServiceImpl(int server_port)
-      : acquired_(false), server_port_(server_port) {}
+  WorkerServiceImpl(int server_port, QpsWorker *worker)
+    : acquired_(false), server_port_(server_port), worker_(worker) {}
 
   Status RunClient(ServerContext* ctx,
                    ServerReaderWriter<ClientStatus, ClientArgs>* stream)
@@ -138,6 +138,16 @@ class WorkerServiceImpl GRPC_FINAL : public WorkerService::Service {
     return ret;
   }
 
+  Status QuitWorker(ServerContext *ctx, const Void*, Void*) GRPC_OVERRIDE {
+    InstanceGuard g(this);
+    if (!g.Acquired()) {
+      return Status(StatusCode::RESOURCE_EXHAUSTED, "");
+    }
+
+    worker_->MarkDone();
+    return Status::OK;
+  }
+  
  private:
   // Protect against multiple clients using this worker at once.
   class InstanceGuard {
@@ -248,10 +258,12 @@ class WorkerServiceImpl GRPC_FINAL : public WorkerService::Service {
   std::mutex mu_;
   bool acquired_;
   int server_port_;
+  QpsWorker *worker_;
 };
 
 QpsWorker::QpsWorker(int driver_port, int server_port) {
-  impl_.reset(new WorkerServiceImpl(server_port));
+  impl_.reset(new WorkerServiceImpl(server_port, this));
+  gpr_atm_rel_store(&done_, static_cast<gpr_atm>(0));
 
   char* server_address = NULL;
   gpr_join_host_port(&server_address, "::", driver_port);
@@ -267,5 +279,11 @@ QpsWorker::QpsWorker(int driver_port, int server_port) {
 
 QpsWorker::~QpsWorker() {}
 
+bool QpsWorker::Done() const {
+  return (gpr_atm_acq_load(&done_) != static_cast<gpr_atm>(0));
+}
+void QpsWorker::MarkDone() {
+  gpr_atm_rel_store(&done_, static_cast<gpr_atm>(1));
+}
 }  // namespace testing
 }  // namespace grpc
diff --git a/test/cpp/qps/qps_worker.h b/test/cpp/qps/qps_worker.h
index 27de69fa65..f14a5c95ad 100644
--- a/test/cpp/qps/qps_worker.h
+++ b/test/cpp/qps/qps_worker.h
@@ -36,6 +36,8 @@
 
 #include <memory>
 
+#include <grpc/support/atm.h>
+
 namespace grpc {
 
 class Server;
@@ -49,9 +51,13 @@ class QpsWorker {
   explicit QpsWorker(int driver_port, int server_port = 0);
   ~QpsWorker();
 
+  bool Done() const;
+  void MarkDone();
  private:
   std::unique_ptr<WorkerServiceImpl> impl_;
   std::unique_ptr<Server> server_;
+
+  gpr_atm done_;
 };
 
 }  // namespace testing
diff --git a/test/cpp/qps/worker.cc b/test/cpp/qps/worker.cc
index a1e73e9abe..f42cfe3255 100644
--- a/test/cpp/qps/worker.cc
+++ b/test/cpp/qps/worker.cc
@@ -56,7 +56,7 @@ namespace testing {
 static void RunServer() {
   QpsWorker worker(FLAGS_driver_port, FLAGS_server_port);
 
-  while (!got_sigint) {
+  while (!got_sigint && !worker.Done()) {
     gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
                                  gpr_time_from_seconds(5, GPR_TIMESPAN)));
   }
-- 
GitLab