diff --git a/setup.py b/setup.py index f96824fa8834ff467d911a698b95f16aa7e05497..0e2646d5d2fdf3f963186bedb3600fce3b216af3 100644 --- a/setup.py +++ b/setup.py @@ -202,7 +202,7 @@ TEST_PACKAGE_DATA = { } TESTS_REQUIRE = ( - 'oauth2client>=1.4.7', + 'oauth2client>=2.1.0', 'protobuf>=3.0.0a3', 'coverage>=4.0', ) + INSTALL_REQUIRES diff --git a/src/python/grpcio/grpc/_auth.py b/src/python/grpcio/grpc/_auth.py index 3ae00ca23a7f53abdffc4361c6fee6d0cbd1b1de..dea3221c9d89720893a57c0d25ef02c5b144342d 100644 --- a/src/python/grpcio/grpc/_auth.py +++ b/src/python/grpcio/grpc/_auth.py @@ -29,6 +29,7 @@ """GRPCAuthMetadataPlugins for standard authentication.""" +import inspect from concurrent import futures import grpc @@ -46,9 +47,21 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin): self._credentials = credentials self._pool = futures.ThreadPoolExecutor(max_workers=1) + # Hack to determine if these are JWT creds and we need to pass + # additional_claims when getting a token + if 'additional_claims' in inspect.getargspec( + credentials.get_access_token).args: + self._is_jwt = True + else: + self._is_jwt = False + def __call__(self, context, callback): # MetadataPlugins cannot block (see grpc.beta.interfaces.py) - future = self._pool.submit(self._credentials.get_access_token) + if self._is_jwt: + future = self._pool.submit(self._credentials.get_access_token, + additional_claims={'aud': context.service_url}) + else: + future = self._pool.submit(self._credentials.get_access_token) future.add_done_callback(lambda x: self._get_token_callback(callback, x)) def _get_token_callback(self, callback, future): diff --git a/src/python/grpcio/tests/interop/client.py b/src/python/grpcio/tests/interop/client.py index e3d5545a020c5ce81625ff5b3c3bca1461fba89f..8aa1ce30c1a821f779b10c3b7a80bc5ef201a17c 100644 --- a/src/python/grpcio/tests/interop/client.py +++ b/src/python/grpcio/tests/interop/client.py @@ -76,6 +76,9 @@ def _stub(args): creds = oauth2client_client.GoogleCredentials.get_application_default() scoped_creds = creds.create_scoped([args.oauth_scope]) call_creds = implementations.google_call_credentials(scoped_creds) + elif args.test_case == 'jwt_token_creds': + creds = oauth2client_client.GoogleCredentials.get_application_default() + call_creds = implementations.google_call_credentials(creds) else: call_creds = None if args.use_tls: diff --git a/src/python/grpcio/tests/interop/methods.py b/src/python/grpcio/tests/interop/methods.py index d5ef0c68bb9ada1d500cdb5afca56355b72ee914..7eac5115258d2c2cbdbcec0ea2864496028ee43d 100644 --- a/src/python/grpcio/tests/interop/methods.py +++ b/src/python/grpcio/tests/interop/methods.py @@ -310,6 +310,16 @@ def _oauth2_auth_token(stub, args): (response.oauth_scope, args.oauth_scope)) +def _jwt_token_creds(stub, args): + json_key_filename = os.environ[ + oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS] + wanted_email = json.load(open(json_key_filename, 'rb'))['client_email'] + response = _large_unary_common_behavior(stub, True, False) + if wanted_email != response.username: + raise ValueError( + 'expected username %s, got %s' % (wanted_email, response.username)) + + def _per_rpc_creds(stub, args): json_key_filename = os.environ[ oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS] @@ -338,6 +348,7 @@ class TestCase(enum.Enum): EMPTY_STREAM = 'empty_stream' COMPUTE_ENGINE_CREDS = 'compute_engine_creds' OAUTH2_AUTH_TOKEN = 'oauth2_auth_token' + JWT_TOKEN_CREDS = 'jwt_token_creds' PER_RPC_CREDS = 'per_rpc_creds' TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server' @@ -364,6 +375,8 @@ class TestCase(enum.Enum): _compute_engine_creds(stub, args) elif self is TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token(stub, args) + elif self is TestCase.JWT_TOKEN_CREDS: + _jwt_token_creds(stub, args) elif self is TestCase.PER_RPC_CREDS: _per_rpc_creds(stub, args) else: diff --git a/tools/run_tests/run_interop_tests.py b/tools/run_tests/run_interop_tests.py index 053aabc9b5b80d68c671b7c148b7054b33bb9a75..5aaefb1ae144dad5e7be99275671ca28d87b0657 100755 --- a/tools/run_tests/run_interop_tests.py +++ b/tools/run_tests/run_interop_tests.py @@ -317,7 +317,7 @@ class PythonLanguage: 'PYTHONPATH': '{}/src/python/gens'.format(DOCKER_WORKDIR_ROOT)} def unimplemented_test_cases(self): - return _SKIP_ADVANCED + _SKIP_COMPRESSION + ['jwt_token_creds'] + return _SKIP_ADVANCED + _SKIP_COMPRESSION def unimplemented_test_cases_server(self): return _SKIP_ADVANCED + _SKIP_COMPRESSION