From 9785c8f9c43bddebb4e01162f2f0e702521b861c Mon Sep 17 00:00:00 2001
From: Noah Eisen <ncteisen@google.com>
Date: Tue, 25 Oct 2016 12:04:24 -0700
Subject: [PATCH] Implement the advanced interop tests for Python

Add the code for three new interop tests: unimplemented_method,
unimplemented_service, and custom_metadata.

Fix and refactor the code for status_code_and_message.
---
 .../grpcio_tests/tests/interop/client.py      |   5 +-
 .../grpcio_tests/tests/interop/methods.py     | 188 +++++++++++++-----
 tools/run_tests/run_interop_tests.py          |   4 +-
 3 files changed, 145 insertions(+), 52 deletions(-)

diff --git a/src/python/grpcio_tests/tests/interop/client.py b/src/python/grpcio_tests/tests/interop/client.py
index 9d61d18975..4fbf58f7d9 100644
--- a/src/python/grpcio_tests/tests/interop/client.py
+++ b/src/python/grpcio_tests/tests/interop/client.py
@@ -106,7 +106,10 @@ def _stub(args):
         (('grpc.ssl_target_name_override', args.server_host_override,),))
   else:
     channel = grpc.insecure_channel(target)
-  return test_pb2.TestServiceStub(channel)
+  if args.test_case == "unimplemented_service":
+    return test_pb2.UnimplementedServiceStub(channel)
+  else:
+    return test_pb2.TestServiceStub(channel)
 
 
 def _test_case_from_arg(test_case_arg):
diff --git a/src/python/grpcio_tests/tests/interop/methods.py b/src/python/grpcio_tests/tests/interop/methods.py
index 7edd75c56c..52e56f3502 100644
--- a/src/python/grpcio_tests/tests/interop/methods.py
+++ b/src/python/grpcio_tests/tests/interop/methods.py
@@ -44,25 +44,43 @@ from src.proto.grpc.testing import empty_pb2
 from src.proto.grpc.testing import messages_pb2
 from src.proto.grpc.testing import test_pb2
 
+_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
+_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"
+
+def _maybe_echo_metadata(servicer_context):
+  """Copies metadata from request to response if it is present."""
+  invocation_metadata = dict(servicer_context.invocation_metadata())
+  if _INITIAL_METADATA_KEY in invocation_metadata:
+    initial_metadatum = (
+        _INITIAL_METADATA_KEY, invocation_metadata[_INITIAL_METADATA_KEY])
+    servicer_context.send_initial_metadata((initial_metadatum,))
+  if _TRAILING_METADATA_KEY in invocation_metadata:
+    trailing_metadatum = (
+      _TRAILING_METADATA_KEY, invocation_metadata[_TRAILING_METADATA_KEY])
+    servicer_context.set_trailing_metadata((trailing_metadatum,))
+
+def _maybe_echo_status_and_message(request, servicer_context):
+  """Sets the response context code and details if the request asks for them"""
+  if request.HasField('response_status'):
+    servicer_context.set_code(request.response_status.code)
+    servicer_context.set_details(request.response_status.message)
 
 class TestService(test_pb2.TestServiceServicer):
 
   def EmptyCall(self, request, context):
+    _maybe_echo_metadata(context)
     return empty_pb2.Empty()
 
   def UnaryCall(self, request, context):
-    if request.HasField('response_status'):
-      context.set_code(request.response_status.code)
-      context.set_details(request.response_status.message)
+    _maybe_echo_metadata(context)
+    _maybe_echo_status_and_message(request, context)
     return messages_pb2.SimpleResponse(
         payload=messages_pb2.Payload(
             type=messages_pb2.COMPRESSABLE,
             body=b'\x00' * request.response_size))
 
   def StreamingOutputCall(self, request, context):
-    if request.HasField('response_status'):
-      context.set_code(request.response_status.code)
-      context.set_details(request.response_status.message)
+    _maybe_echo_status_and_message(request, context)
     for response_parameters in request.response_parameters:
       yield messages_pb2.StreamingOutputCallResponse(
           payload=messages_pb2.Payload(
@@ -78,10 +96,9 @@ class TestService(test_pb2.TestServiceServicer):
         aggregated_payload_size=aggregate_size)
 
   def FullDuplexCall(self, request_iterator, context):
+    _maybe_echo_metadata(context)
     for request in request_iterator:
-      if request.HasField('response_status'):
-        context.set_code(request.response_status.code)
-        context.set_details(request.response_status.message)
+      _maybe_echo_status_and_message(request, context)
       for response_parameters in request.response_parameters:
         yield messages_pb2.StreamingOutputCallResponse(
             payload=messages_pb2.Payload(
@@ -94,23 +111,46 @@ class TestService(test_pb2.TestServiceServicer):
     return self.FullDuplexCall(request_iterator, context)
 
 
+def _expect_status_code(call, expected_code):
+  if call.code() != expected_code:
+    raise ValueError(
+      'expected code %s, got %s' % (expected_code, call.code()))
+
+
+def _expect_status_details(call, expected_details):
+  if call.details() != expected_details:
+    raise ValueError(
+      'expected message %s, got %s' % (expected_details, call.details()))
+
+
+def _validate_status_code_and_details(call, expected_code, expected_details):
+  _expect_status_code(call, expected_code)
+  _expect_status_details(call, expected_details)
+
+
+def _validate_payload_type_and_length(response, expected_type, expected_length):
+  if response.payload.type is not expected_type:
+    raise ValueError(
+      'expected payload type %s, got %s' %
+          (expected_type, type(response.payload.type)))
+  elif len(response.payload.body) != expected_length:
+    raise ValueError(
+      'expected payload body size %d, got %d' %
+          (expected_length, len(response.payload.body)))
+
+
 def _large_unary_common_behavior(
     stub, fill_username, fill_oauth_scope, call_credentials):
+  size = 314159
   request = messages_pb2.SimpleRequest(
-      response_type=messages_pb2.COMPRESSABLE, response_size=314159,
+      response_type=messages_pb2.COMPRESSABLE, response_size=size,
       payload=messages_pb2.Payload(body=b'\x00' * 271828),
       fill_username=fill_username, fill_oauth_scope=fill_oauth_scope)
   response_future = stub.UnaryCall.future(
       request, credentials=call_credentials)
   response = response_future.result()
-  if response.payload.type is not messages_pb2.COMPRESSABLE:
-    raise ValueError(
-        'response payload type is "%s"!' % type(response.payload.type))
-  elif len(response.payload.body) != 314159:
-    raise ValueError(
-        'response body of incorrect size %d!' % len(response.payload.body))
-  else:
-    return response
+  _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
+  return response
 
 
 def _empty_unary(stub):
@@ -152,12 +192,9 @@ def _server_streaming(stub):
   )
   response_iterator = stub.StreamingOutputCall(request)
   for index, response in enumerate(response_iterator):
-    if response.payload.type != messages_pb2.COMPRESSABLE:
-      raise ValueError(
-          'response body of invalid type %s!' % response.payload.type)
-    elif len(response.payload.body) != sizes[index]:
-      raise ValueError(
-          'response body of invalid size %d!' % len(response.payload.body))
+    _validate_payload_type_and_length(
+        response, messages_pb2.COMPRESSABLE, sizes[index])
+
 
 def _cancel_after_begin(stub):
   sizes = (27182, 8, 1828, 45904,)
@@ -224,12 +261,8 @@ def _ping_pong(stub):
           payload=messages_pb2.Payload(body=b'\x00' * payload_size))
       pipe.add(request)
       response = next(response_iterator)
-      if response.payload.type != messages_pb2.COMPRESSABLE:
-        raise ValueError(
-            'response body of invalid type %s!' % response.payload.type)
-      if len(response.payload.body) != response_size:
-        raise ValueError(
-            'response body of invalid size %d!' % len(response.payload.body))
+      _validate_payload_type_and_length(
+          response, messages_pb2.COMPRESSABLE, response_size)
 
 
 def _cancel_after_first_response(stub):
@@ -291,36 +324,84 @@ def _empty_stream(stub):
 
 
 def _status_code_and_message(stub):
-  message = 'test status message'
+  details = 'test status message'
   code = 2
   status = grpc.StatusCode.UNKNOWN # code = 2
+
+  # Test with a UnaryCall
   request = messages_pb2.SimpleRequest(
       response_type=messages_pb2.COMPRESSABLE,
       response_size=1,
       payload=messages_pb2.Payload(body=b'\x00'),
-      response_status=messages_pb2.EchoStatus(code=code, message=message)
+      response_status=messages_pb2.EchoStatus(code=code, message=details)
   )
   response_future = stub.UnaryCall.future(request)
-  if response_future.code() != status:
-    raise ValueError(
-      'expected code %s, got %s' % (status, response_future.code()))
-  elif response_future.details() != message:
-    raise ValueError(
-      'expected message %s, got %s' % (message, response_future.details()))
+  _validate_status_code_and_details(response_future, status, details)
 
-  request = messages_pb2.StreamingOutputCallRequest(
+  # Test with a FullDuplexCall
+  with _Pipe() as pipe:
+    response_iterator = stub.FullDuplexCall(pipe)
+    request = messages_pb2.StreamingOutputCallRequest(
+        response_type=messages_pb2.COMPRESSABLE,
+        response_parameters=(
+            messages_pb2.ResponseParameters(size=1),),
+        payload=messages_pb2.Payload(body=b'\x00'),
+        response_status=messages_pb2.EchoStatus(code=code, message=details))
+    pipe.add(request)   # sends the initial request.
+  # Dropping out of with block closes the pipe
+  _validate_status_code_and_details(response_iterator, status, details)
+
+
+def _unimplemented_method(test_service_stub):
+  response_future = (
+      test_service_stub.UnimplementedCall.future(empty_pb2.Empty()))
+  _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)
+
+
+def _unimplemented_service(unimplemented_service_stub):
+  response_future = (
+      unimplemented_service_stub.UnimplementedCall.future(empty_pb2.Empty()))
+  _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)
+
+
+def _custom_metadata(stub):
+  initial_metadata_value = "test_initial_metadata_value"
+  trailing_metadata_value = "\x0a\x0b\x0a\x0b\x0a\x0b"
+  metadata = (
+      (_INITIAL_METADATA_KEY, initial_metadata_value),
+      (_TRAILING_METADATA_KEY, trailing_metadata_value))
+
+  def _validate_metadata(response):
+    initial_metadata = dict(response.initial_metadata())
+    if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
+      raise ValueError(
+        'expected initial metadata %s, got %s' % (
+            initial_metadata_value, initial_metadata[_INITIAL_METADATA_KEY]))
+    trailing_metadata = dict(response.trailing_metadata())
+    if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
+      raise ValueError(
+        'expected trailing metadata %s, got %s' % (
+            trailing_metadata_value, initial_metadata[_TRAILING_METADATA_KEY]))
+
+  # Testing with UnaryCall
+  request = messages_pb2.SimpleRequest(
       response_type=messages_pb2.COMPRESSABLE,
-      response_parameters=(
-          messages_pb2.ResponseParameters(size=1),),
-      response_status=messages_pb2.EchoStatus(code=code, message=message))
-  response_iterator = stub.StreamingOutputCall(request)
-  if response_future.code() != status:
-    raise ValueError(
-      'expected code %s, got %s' % (status, response_iterator.code()))
-  elif response_future.details() != message:
-    raise ValueError(
-      'expected message %s, got %s' % (message, response_iterator.details()))
+      response_size=1,
+      payload=messages_pb2.Payload(body=b'\x00'))
+  response_future = stub.UnaryCall.future(request, metadata=metadata)
+  _validate_metadata(response_future)
 
+  # Testing with FullDuplexCall
+  with _Pipe() as pipe:
+    response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
+    request = messages_pb2.StreamingOutputCallRequest(
+        response_type=messages_pb2.COMPRESSABLE,
+        response_parameters=(
+            messages_pb2.ResponseParameters(size=1),))
+    pipe.add(request)   # Sends the request
+    next(response_iterator)    # Causes server to send trailing metadata
+  # Dropping out of the with block closes the pipe
+  _validate_metadata(response_iterator)
 
 def _compute_engine_creds(stub, args):
   response = _large_unary_common_behavior(stub, True, True, None)
@@ -381,6 +462,9 @@ class TestCase(enum.Enum):
   CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
   EMPTY_STREAM = 'empty_stream'
   STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
+  UNIMPLEMENTED_METHOD = 'unimplemented_method'
+  UNIMPLEMENTED_SERVICE = 'unimplemented_service'
+  CUSTOM_METADATA = "custom_metadata"
   COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
   OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
   JWT_TOKEN_CREDS = 'jwt_token_creds'
@@ -408,6 +492,12 @@ class TestCase(enum.Enum):
       _empty_stream(stub)
     elif self is TestCase.STATUS_CODE_AND_MESSAGE:
       _status_code_and_message(stub)
+    elif self is TestCase.UNIMPLEMENTED_METHOD:
+      _unimplemented_method(stub)
+    elif self is TestCase.UNIMPLEMENTED_SERVICE:
+      _unimplemented_service(stub)
+    elif self is TestCase.CUSTOM_METADATA:
+      _custom_metadata(stub)
     elif self is TestCase.COMPUTE_ENGINE_CREDS:
       _compute_engine_creds(stub, args)
     elif self is TestCase.OAUTH2_AUTH_TOKEN:
diff --git a/tools/run_tests/run_interop_tests.py b/tools/run_tests/run_interop_tests.py
index 29f6533398..0c6efda1f4 100755
--- a/tools/run_tests/run_interop_tests.py
+++ b/tools/run_tests/run_interop_tests.py
@@ -385,10 +385,10 @@ class PythonLanguage:
             'PYTHONPATH': '{}/src/python/gens'.format(DOCKER_WORKDIR_ROOT)}
 
   def unimplemented_test_cases(self):
-    return _SKIP_ADVANCED + _SKIP_COMPRESSION
+    return _SKIP_COMPRESSION
 
   def unimplemented_test_cases_server(self):
-    return _SKIP_ADVANCED + _SKIP_COMPRESSION
+    return _SKIP_COMPRESSION
 
   def __str__(self):
     return 'python'
-- 
GitLab