From 22a65e1a2b7c1782bc288d9e8af1c0c9f0f90795 Mon Sep 17 00:00:00 2001
From: Ken Payson <kpayson@google.com>
Date: Tue, 7 Jun 2016 19:06:05 -0700
Subject: [PATCH] Added python jwt_token_creds interop test

---
 setup.py                                   |  2 +-
 src/python/grpcio/grpc/_auth.py            | 15 ++++++++++++++-
 src/python/grpcio/tests/interop/client.py  |  3 +++
 src/python/grpcio/tests/interop/methods.py | 13 +++++++++++++
 tools/run_tests/run_interop_tests.py       |  2 +-
 5 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/setup.py b/setup.py
index f96824fa88..0e2646d5d2 100644
--- a/setup.py
+++ b/setup.py
@@ -202,7 +202,7 @@ TEST_PACKAGE_DATA = {
 }
 
 TESTS_REQUIRE = (
-    'oauth2client>=1.4.7',
+    'oauth2client>=2.1.0',
     'protobuf>=3.0.0a3',
     'coverage>=4.0',
 ) + INSTALL_REQUIRES
diff --git a/src/python/grpcio/grpc/_auth.py b/src/python/grpcio/grpc/_auth.py
index 3ae00ca23a..dea3221c9d 100644
--- a/src/python/grpcio/grpc/_auth.py
+++ b/src/python/grpcio/grpc/_auth.py
@@ -29,6 +29,7 @@
 
 """GRPCAuthMetadataPlugins for standard authentication."""
 
+import inspect
 from concurrent import futures
 
 import grpc
@@ -46,9 +47,21 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin):
     self._credentials = credentials
     self._pool = futures.ThreadPoolExecutor(max_workers=1)
 
+    # Hack to determine if these are JWT creds and we need to pass
+    # additional_claims when getting a token
+    if 'additional_claims' in inspect.getargspec(
+        credentials.get_access_token).args:
+      self._is_jwt = True
+    else:
+      self._is_jwt = False
+
   def __call__(self, context, callback):
     # MetadataPlugins cannot block (see grpc.beta.interfaces.py)
-    future = self._pool.submit(self._credentials.get_access_token)
+    if self._is_jwt:
+      future = self._pool.submit(self._credentials.get_access_token,
+                                 additional_claims={'aud': context.service_url})
+    else:
+      future = self._pool.submit(self._credentials.get_access_token)
     future.add_done_callback(lambda x: self._get_token_callback(callback, x))
 
   def _get_token_callback(self, callback, future):
diff --git a/src/python/grpcio/tests/interop/client.py b/src/python/grpcio/tests/interop/client.py
index e3d5545a02..8aa1ce30c1 100644
--- a/src/python/grpcio/tests/interop/client.py
+++ b/src/python/grpcio/tests/interop/client.py
@@ -76,6 +76,9 @@ def _stub(args):
     creds = oauth2client_client.GoogleCredentials.get_application_default()
     scoped_creds = creds.create_scoped([args.oauth_scope])
     call_creds = implementations.google_call_credentials(scoped_creds)
+  elif args.test_case == 'jwt_token_creds':
+    creds = oauth2client_client.GoogleCredentials.get_application_default()
+    call_creds = implementations.google_call_credentials(creds)
   else:
     call_creds = None
   if args.use_tls:
diff --git a/src/python/grpcio/tests/interop/methods.py b/src/python/grpcio/tests/interop/methods.py
index d5ef0c68bb..7eac511525 100644
--- a/src/python/grpcio/tests/interop/methods.py
+++ b/src/python/grpcio/tests/interop/methods.py
@@ -310,6 +310,16 @@ def _oauth2_auth_token(stub, args):
         (response.oauth_scope, args.oauth_scope))
 
 
+def _jwt_token_creds(stub, args):
+  json_key_filename = os.environ[
+      oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
+  wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
+  response = _large_unary_common_behavior(stub, True, False)
+  if wanted_email != response.username:
+    raise ValueError(
+        'expected username %s, got %s' % (wanted_email, response.username))
+
+
 def _per_rpc_creds(stub, args):
   json_key_filename = os.environ[
       oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
@@ -338,6 +348,7 @@ class TestCase(enum.Enum):
   EMPTY_STREAM = 'empty_stream'
   COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
   OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
+  JWT_TOKEN_CREDS = 'jwt_token_creds'
   PER_RPC_CREDS = 'per_rpc_creds'
   TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
 
@@ -364,6 +375,8 @@ class TestCase(enum.Enum):
       _compute_engine_creds(stub, args)
     elif self is TestCase.OAUTH2_AUTH_TOKEN:
       _oauth2_auth_token(stub, args)
+    elif self is TestCase.JWT_TOKEN_CREDS:
+      _jwt_token_creds(stub, args)
     elif self is TestCase.PER_RPC_CREDS:
       _per_rpc_creds(stub, args)
     else:
diff --git a/tools/run_tests/run_interop_tests.py b/tools/run_tests/run_interop_tests.py
index 053aabc9b5..5aaefb1ae1 100755
--- a/tools/run_tests/run_interop_tests.py
+++ b/tools/run_tests/run_interop_tests.py
@@ -317,7 +317,7 @@ class PythonLanguage:
             'PYTHONPATH': '{}/src/python/gens'.format(DOCKER_WORKDIR_ROOT)}
 
   def unimplemented_test_cases(self):
-    return _SKIP_ADVANCED + _SKIP_COMPRESSION + ['jwt_token_creds']
+    return _SKIP_ADVANCED + _SKIP_COMPRESSION
 
   def unimplemented_test_cases_server(self):
     return _SKIP_ADVANCED + _SKIP_COMPRESSION
-- 
GitLab