Skip to content
Snippets Groups Projects
Commit 09bf5f45 authored by Jan Tattermusch's avatar Jan Tattermusch
Browse files

Merge pull request #6831 from kpayson64/python_jwt_creds

Added python jwt_token_creds interop test
parents 2c57371b 22a65e1a
No related branches found
No related tags found
No related merge requests found
...@@ -202,7 +202,7 @@ TEST_PACKAGE_DATA = { ...@@ -202,7 +202,7 @@ TEST_PACKAGE_DATA = {
} }
TESTS_REQUIRE = ( TESTS_REQUIRE = (
'oauth2client>=1.4.7', 'oauth2client>=2.1.0',
'protobuf>=3.0.0a3', 'protobuf>=3.0.0a3',
'coverage>=4.0', 'coverage>=4.0',
) + INSTALL_REQUIRES ) + INSTALL_REQUIRES
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
"""GRPCAuthMetadataPlugins for standard authentication.""" """GRPCAuthMetadataPlugins for standard authentication."""
import inspect
from concurrent import futures from concurrent import futures
import grpc import grpc
...@@ -46,9 +47,21 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin): ...@@ -46,9 +47,21 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin):
self._credentials = credentials self._credentials = credentials
self._pool = futures.ThreadPoolExecutor(max_workers=1) self._pool = futures.ThreadPoolExecutor(max_workers=1)
# Hack to determine if these are JWT creds and we need to pass
# additional_claims when getting a token
if 'additional_claims' in inspect.getargspec(
credentials.get_access_token).args:
self._is_jwt = True
else:
self._is_jwt = False
def __call__(self, context, callback): def __call__(self, context, callback):
# MetadataPlugins cannot block (see grpc.beta.interfaces.py) # MetadataPlugins cannot block (see grpc.beta.interfaces.py)
future = self._pool.submit(self._credentials.get_access_token) if self._is_jwt:
future = self._pool.submit(self._credentials.get_access_token,
additional_claims={'aud': context.service_url})
else:
future = self._pool.submit(self._credentials.get_access_token)
future.add_done_callback(lambda x: self._get_token_callback(callback, x)) future.add_done_callback(lambda x: self._get_token_callback(callback, x))
def _get_token_callback(self, callback, future): def _get_token_callback(self, callback, future):
......
...@@ -76,6 +76,9 @@ def _stub(args): ...@@ -76,6 +76,9 @@ def _stub(args):
creds = oauth2client_client.GoogleCredentials.get_application_default() creds = oauth2client_client.GoogleCredentials.get_application_default()
scoped_creds = creds.create_scoped([args.oauth_scope]) scoped_creds = creds.create_scoped([args.oauth_scope])
call_creds = implementations.google_call_credentials(scoped_creds) call_creds = implementations.google_call_credentials(scoped_creds)
elif args.test_case == 'jwt_token_creds':
creds = oauth2client_client.GoogleCredentials.get_application_default()
call_creds = implementations.google_call_credentials(creds)
else: else:
call_creds = None call_creds = None
if args.use_tls: if args.use_tls:
......
...@@ -310,6 +310,16 @@ def _oauth2_auth_token(stub, args): ...@@ -310,6 +310,16 @@ def _oauth2_auth_token(stub, args):
(response.oauth_scope, args.oauth_scope)) (response.oauth_scope, args.oauth_scope))
def _jwt_token_creds(stub, args):
json_key_filename = os.environ[
oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
response = _large_unary_common_behavior(stub, True, False)
if wanted_email != response.username:
raise ValueError(
'expected username %s, got %s' % (wanted_email, response.username))
def _per_rpc_creds(stub, args): def _per_rpc_creds(stub, args):
json_key_filename = os.environ[ json_key_filename = os.environ[
oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS] oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
...@@ -338,6 +348,7 @@ class TestCase(enum.Enum): ...@@ -338,6 +348,7 @@ class TestCase(enum.Enum):
EMPTY_STREAM = 'empty_stream' EMPTY_STREAM = 'empty_stream'
COMPUTE_ENGINE_CREDS = 'compute_engine_creds' COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
OAUTH2_AUTH_TOKEN = 'oauth2_auth_token' OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
JWT_TOKEN_CREDS = 'jwt_token_creds'
PER_RPC_CREDS = 'per_rpc_creds' PER_RPC_CREDS = 'per_rpc_creds'
TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server' TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
...@@ -364,6 +375,8 @@ class TestCase(enum.Enum): ...@@ -364,6 +375,8 @@ class TestCase(enum.Enum):
_compute_engine_creds(stub, args) _compute_engine_creds(stub, args)
elif self is TestCase.OAUTH2_AUTH_TOKEN: elif self is TestCase.OAUTH2_AUTH_TOKEN:
_oauth2_auth_token(stub, args) _oauth2_auth_token(stub, args)
elif self is TestCase.JWT_TOKEN_CREDS:
_jwt_token_creds(stub, args)
elif self is TestCase.PER_RPC_CREDS: elif self is TestCase.PER_RPC_CREDS:
_per_rpc_creds(stub, args) _per_rpc_creds(stub, args)
else: else:
......
...@@ -317,7 +317,7 @@ class PythonLanguage: ...@@ -317,7 +317,7 @@ class PythonLanguage:
'PYTHONPATH': '{}/src/python/gens'.format(DOCKER_WORKDIR_ROOT)} 'PYTHONPATH': '{}/src/python/gens'.format(DOCKER_WORKDIR_ROOT)}
def unimplemented_test_cases(self): def unimplemented_test_cases(self):
return _SKIP_ADVANCED + _SKIP_COMPRESSION + ['jwt_token_creds'] return _SKIP_ADVANCED + _SKIP_COMPRESSION
def unimplemented_test_cases_server(self): def unimplemented_test_cases_server(self):
return _SKIP_ADVANCED + _SKIP_COMPRESSION return _SKIP_ADVANCED + _SKIP_COMPRESSION
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment