From d10bbb63f8d441319342b4d6e6d3d1f27b2533a1 Mon Sep 17 00:00:00 2001
From: Vijay Pai <vpai@google.com>
Date: Wed, 3 Feb 2016 11:26:23 -0800
Subject: [PATCH] Refactor to favor composition over inheritance Also make
 num_threads and num_rpcs as command-line flags

---
 test/cpp/end2end/thread_stress_test.cc | 81 ++++++++++++++++----------
 1 file changed, 49 insertions(+), 32 deletions(-)

diff --git a/test/cpp/end2end/thread_stress_test.cc b/test/cpp/end2end/thread_stress_test.cc
index f6d8475c70..01b45a5373 100644
--- a/test/cpp/end2end/thread_stress_test.cc
+++ b/test/cpp/end2end/thread_stress_test.cc
@@ -34,6 +34,7 @@
 #include <mutex>
 #include <thread>
 
+#include <gflags/gflags.h>
 #include <grpc++/channel.h>
 #include <grpc++/client_context.h>
 #include <grpc++/create_channel.h>
@@ -54,6 +55,9 @@ using grpc::testing::EchoRequest;
 using grpc::testing::EchoResponse;
 using std::chrono::system_clock;
 
+DEFINE_int32(num_threads, 100, "Number of threads");
+DEFINE_int32(num_rpcs, 1000, "Number of RPCs per thread");
+
 namespace grpc {
 namespace testing {
 
@@ -168,11 +172,10 @@ class TestServiceImplDupPkg
   }
 };
 
-class End2endTest : public ::testing::Test {
- protected:
-  End2endTest() : kMaxMessageSize_(8192) {}
-
-  void SetUp() GRPC_OVERRIDE {
+class CommonStressTest {
+ public:
+  CommonStressTest() : kMaxMessageSize_(8192) {}
+  void SetUp() {
     int port = grpc_pick_unused_port_or_die();
     server_address_ << "localhost:" << port;
     // Setup server
@@ -185,15 +188,15 @@ class End2endTest : public ::testing::Test {
     builder.RegisterService(&dup_pkg_service_);
     server_ = builder.BuildAndStart();
   }
-
-  void TearDown() GRPC_OVERRIDE { server_->Shutdown(); }
-
+  void TearDown() { server_->Shutdown(); }
   void ResetStub() {
     std::shared_ptr<Channel> channel =
         CreateChannel(server_address_.str(), InsecureChannelCredentials());
     stub_ = grpc::testing::EchoTestService::NewStub(channel);
   }
+  grpc::testing::EchoTestService::Stub* GetStub() { return stub_.get(); }
 
+ private:
   std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
   std::unique_ptr<Server> server_;
   std::ostringstream server_address_;
@@ -202,6 +205,16 @@ class End2endTest : public ::testing::Test {
   TestServiceImplDupPkg dup_pkg_service_;
 };
 
+class End2endTest : public ::testing::Test {
+ protected:
+  End2endTest() {}
+  void SetUp() GRPC_OVERRIDE { common_.SetUp(); }
+  void TearDown() GRPC_OVERRIDE { common_.TearDown(); }
+  void ResetStub() { common_.ResetStub(); }
+
+  CommonStressTest common_;
+};
+
 static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs) {
   EchoRequest request;
   EchoResponse response;
@@ -216,27 +229,29 @@ static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs) {
 }
 
 TEST_F(End2endTest, ThreadStress) {
-  ResetStub();
+  common_.ResetStub();
   std::vector<std::thread*> threads;
-  for (int i = 0; i < 100; ++i) {
-    threads.push_back(new std::thread(SendRpc, stub_.get(), 1000));
+  for (int i = 0; i < FLAGS_num_threads; ++i) {
+    threads.push_back(
+        new std::thread(SendRpc, common_.GetStub(), FLAGS_num_rpcs));
   }
-  for (int i = 0; i < 100; ++i) {
+  for (int i = 0; i < FLAGS_num_threads; ++i) {
     threads[i]->join();
     delete threads[i];
   }
 }
 
-class AsyncClientEnd2endTest : public End2endTest {
+class AsyncClientEnd2endTest : public ::testing::Test {
  protected:
   AsyncClientEnd2endTest() : rpcs_outstanding_(0) {}
 
+  void SetUp() GRPC_OVERRIDE { common_.SetUp(); }
   void TearDown() GRPC_OVERRIDE {
     void* ignored_tag;
     bool ignored_ok;
     while (cq_.Next(&ignored_tag, &ignored_ok))
       ;
-    End2endTest::TearDown();
+    common_.TearDown();
   }
 
   void Wait() {
@@ -260,7 +275,8 @@ class AsyncClientEnd2endTest : public End2endTest {
       AsyncClientCall* call = new AsyncClientCall;
       EchoRequest request;
       request.set_message("Hello");
-      call->response_reader = stub_->AsyncEcho(&call->context, request, &cq_);
+      call->response_reader =
+          common_.GetStub()->AsyncEcho(&call->context, request, &cq_);
       call->response_reader->Finish(&call->response, &call->status,
                                     (void*)call);
 
@@ -290,6 +306,7 @@ class AsyncClientEnd2endTest : public End2endTest {
     }
   }
 
+  CommonStressTest common_;
   CompletionQueue cq_;
   std::mutex mu_;
   std::condition_variable cv_;
@@ -297,27 +314,26 @@ class AsyncClientEnd2endTest : public End2endTest {
 };
 
 TEST_F(AsyncClientEnd2endTest, ThreadStress) {
-  ResetStub();
-  std::vector<std::thread*> threads;
-  for (int i = 0; i < 100; ++i) {
-    threads.push_back(new std::thread(
-        &AsyncClientEnd2endTest_ThreadStress_Test::AsyncSendRpc, this, 1000));
+  common_.ResetStub();
+  std::vector<std::thread*> send_threads, completion_threads;
+  for (int i = 0; i < FLAGS_num_threads; ++i) {
+    completion_threads.push_back(new std::thread(
+        &AsyncClientEnd2endTest_ThreadStress_Test::AsyncCompleteRpc, this));
   }
-  for (int i = 0; i < 100; ++i) {
-    threads[i]->join();
-    delete threads[i];
+  for (int i = 0; i < FLAGS_num_threads; ++i) {
+    send_threads.push_back(
+        new std::thread(&AsyncClientEnd2endTest_ThreadStress_Test::AsyncSendRpc,
+                        this, FLAGS_num_rpcs));
   }
-
-  threads.clear();
-
-  for (int i = 0; i < 100; ++i) {
-    threads.push_back(new std::thread(
-        &AsyncClientEnd2endTest_ThreadStress_Test::AsyncCompleteRpc, this));
+  for (int i = 0; i < FLAGS_num_threads; ++i) {
+    send_threads[i]->join();
+    delete send_threads[i];
   }
+
   Wait();
-  for (int i = 0; i < 100; ++i) {
-    threads[i]->join();
-    delete threads[i];
+  for (int i = 0; i < FLAGS_num_threads; ++i) {
+    completion_threads[i]->join();
+    delete completion_threads[i];
   }
 }
 
@@ -325,6 +341,7 @@ TEST_F(AsyncClientEnd2endTest, ThreadStress) {
 }  // namespace grpc
 
 int main(int argc, char** argv) {
+  ::google::ParseCommandLineFlags(&argc, &argv, true);
   grpc_test_init(argc, argv);
   ::testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();
-- 
GitLab