From 24804d25b479cc6a32ac40184d6a06e1258748c2 Mon Sep 17 00:00:00 2001
From: Noah Eisen <ncteisen@google.com>
Date: Mon, 21 Nov 2016 16:34:42 -0800
Subject: [PATCH] Add new flags to Python stress client

Add use_tls, use_test_ca, and server_host_override flags to Python
stress client. These are needed to run the stress client against a local
server that is using tls.
---
 .../grpcio_tests/tests/stress/client.py       | 26 ++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/src/python/grpcio_tests/tests/stress/client.py b/src/python/grpcio_tests/tests/stress/client.py
index 975f33b4c1..390ea13021 100644
--- a/src/python/grpcio_tests/tests/stress/client.py
+++ b/src/python/grpcio_tests/tests/stress/client.py
@@ -39,6 +39,7 @@ from src.proto.grpc.testing import metrics_pb2
 from src.proto.grpc.testing import test_pb2
 
 from tests.interop import methods
+from tests.interop import resources
 from tests.qps import histogram
 from tests.stress import metrics_server
 from tests.stress import test_runner
@@ -71,6 +72,16 @@ def _args():
       '--metrics_port',
       help='the port to listen for metrics requests on',
       default=8081, type=int)
+  parser.add_argument(
+      '--use_test_ca',
+      help='Whether to use our fake CA. Requires --use_tls=true',
+      default=False, type=bool)
+  parser.add_argument(
+      '--use_tls',
+      help='Whether to use TLS', default=False, type=bool)
+  parser.add_argument(
+      '--server_host_override', default="foo.test.google.fr",
+      help='the server host to which to claim to connect', type=str)
   return parser.parse_args()
 
 
@@ -90,6 +101,19 @@ def _parse_weighted_test_cases(test_case_args):
     weighted_test_cases[test_case] = int(weight)
   return weighted_test_cases
 
+def _get_channel(target, args):
+  if args.use_tls:
+    if args.use_test_ca:
+      root_certificates = resources.test_root_certificates()
+    else:
+      root_certificates = None  # will load default roots.
+    channel_credentials = grpc.ssl_channel_credentials(
+        root_certificates=root_certificates)
+    options = (('grpc.ssl_target_name_override', args.server_host_override,),)
+    return grpc.secure_channel(
+        target, channel_credentials, options=options)
+  else:
+    return grpc.insecure_channel(target)
 
 def run_test(args):
   test_cases = _parse_weighted_test_cases(args.test_cases)
@@ -108,7 +132,7 @@ def run_test(args):
 
   for test_server_target in test_server_targets:
     for _ in xrange(args.num_channels_per_server):
-      channel = grpc.insecure_channel(test_server_target)
+      channel = _get_channel(test_server_target, args)
       for _ in xrange(args.num_stubs_per_channel):
         stub = test_pb2.TestServiceStub(channel)
         runner = test_runner.TestRunner(stub, test_cases, hist,
-- 
GitLab