From 99355ce1cc6926d6a39b6660e8c60ad49e53d510 Mon Sep 17 00:00:00 2001
From: Nathaniel Manista <nathaniel@google.com>
Date: Mon, 31 Aug 2015 05:11:52 +0000
Subject: [PATCH] Metadata plumbing and serialization tweaks

(1) Plumb the metadata transformer given at the Beta API through to the
InvocationLink where it will be used.

(2) In both InvocationLink and ServiceLink, if there isn't a registered
serializer or deserializer, just pass the payload through rather than
ignoring the entire RPC.
---
 src/python/grpcio/grpc/_links/invocation.py   | 43 ++++++++++++-------
 src/python/grpcio/grpc/_links/service.py      | 16 ++++---
 src/python/grpcio/grpc/beta/_stub.py          | 24 ++++++-----
 src/python/grpcio/grpc/beta/beta.py           |  4 +-
 .../_core_over_links_base_interface_test.py   |  2 +-
 ...ver_core_over_links_face_interface_test.py |  2 +-
 .../_links/_lonely_invocation_link_test.py    |  6 +--
 .../grpc_test/_links/_transmission_test.py    |  6 +--
 8 files changed, 60 insertions(+), 43 deletions(-)

diff --git a/src/python/grpcio/grpc/_links/invocation.py b/src/python/grpcio/grpc/_links/invocation.py
index 729b987dd1..1676fe7941 100644
--- a/src/python/grpcio/grpc/_links/invocation.py
+++ b/src/python/grpcio/grpc/_links/invocation.py
@@ -41,6 +41,8 @@ from grpc.framework.foundation import logging_pool
 from grpc.framework.foundation import relay
 from grpc.framework.interfaces.links import links
 
+_IDENTITY = lambda x: x
+
 _STOP = _intermediary_low.Event.Kind.STOP
 _WRITE = _intermediary_low.Event.Kind.WRITE_ACCEPTED
 _COMPLETE = _intermediary_low.Event.Kind.COMPLETE_ACCEPTED
@@ -95,11 +97,12 @@ def _no_longer_due(kind, rpc_state, key, rpc_states):
 class _Kernel(object):
 
   def __init__(
-      self, channel, host, request_serializers, response_deserializers,
-      ticket_relay):
+      self, channel, host, metadata_transformer, request_serializers,
+      response_deserializers, ticket_relay):
     self._lock = threading.Lock()
     self._channel = channel
     self._host = host
+    self._metadata_transformer = metadata_transformer
     self._request_serializers = request_serializers
     self._response_deserializers = response_deserializers
     self._relay = ticket_relay
@@ -225,20 +228,17 @@ class _Kernel(object):
     else:
       return
 
-    request_serializer = self._request_serializers.get((group, method))
-    response_deserializer = self._response_deserializers.get((group, method))
-    if request_serializer is None or response_deserializer is None:
-      cancellation_ticket = links.Ticket(
-          operation_id, 0, None, None, None, None, None, None, None, None, None,
-          None, links.Ticket.Termination.CANCELLATION)
-      self._relay.add_value(cancellation_ticket)
-      return
+    transformed_initial_metadata = self._metadata_transformer(initial_metadata)
+    request_serializer = self._request_serializers.get(
+        (group, method), _IDENTITY)
+    response_deserializer = self._response_deserializers.get(
+        (group, method), _IDENTITY)
 
     call = _intermediary_low.Call(
         self._channel, self._completion_queue, '/%s/%s' % (group, method),
         self._host, time.time() + timeout)
-    if initial_metadata is not None:
-      for metadata_key, metadata_value in initial_metadata:
+    if transformed_initial_metadata is not None:
+      for metadata_key, metadata_value in transformed_initial_metadata:
         call.add_metadata(metadata_key, metadata_value)
     call.invoke(self._completion_queue, operation_id, operation_id)
     if payload is None:
@@ -336,10 +336,15 @@ class InvocationLink(links.Link, activated.Activated):
 class _InvocationLink(InvocationLink):
 
   def __init__(
-      self, channel, host, request_serializers, response_deserializers):
+      self, channel, host, metadata_transformer, request_serializers,
+      response_deserializers):
     self._relay = relay.relay(None)
     self._kernel = _Kernel(
-        channel, host, request_serializers, response_deserializers, self._relay)
+        channel, host,
+        _IDENTITY if metadata_transformer is None else metadata_transformer,
+        {} if request_serializers is None else request_serializers,
+        {} if response_deserializers is None else response_deserializers,
+        self._relay)
 
   def _start(self):
     self._relay.start()
@@ -376,12 +381,17 @@ class _InvocationLink(InvocationLink):
     self._stop()
 
 
-def invocation_link(channel, host, request_serializers, response_deserializers):
+def invocation_link(
+    channel, host, metadata_transformer, request_serializers,
+    response_deserializers):
   """Creates an InvocationLink.
 
   Args:
     channel: An _intermediary_low.Channel for use by the link.
     host: The host to specify when invoking RPCs.
+    metadata_transformer: A callable that takes an invocation-side initial
+      metadata value and returns another metadata value to send in its place.
+      May be None.
     request_serializers: A dict from group-method pair to request object
       serialization behavior.
     response_deserializers: A dict from group-method pair to response object
@@ -391,4 +401,5 @@ def invocation_link(channel, host, request_serializers, response_deserializers):
     An InvocationLink.
   """
   return _InvocationLink(
-      channel, host, request_serializers, response_deserializers)
+      channel, host, metadata_transformer, request_serializers,
+      response_deserializers)
diff --git a/src/python/grpcio/grpc/_links/service.py b/src/python/grpcio/grpc/_links/service.py
index bbfe9bcd55..94e7cfc716 100644
--- a/src/python/grpcio/grpc/_links/service.py
+++ b/src/python/grpcio/grpc/_links/service.py
@@ -40,6 +40,8 @@ from grpc.framework.foundation import logging_pool
 from grpc.framework.foundation import relay
 from grpc.framework.interfaces.links import links
 
+_IDENTITY = lambda x: x
+
 _TERMINATION_KIND_TO_CODE = {
     links.Ticket.Termination.COMPLETION: _intermediary_low.Code.OK,
     links.Ticket.Termination.CANCELLATION: _intermediary_low.Code.CANCELLED,
@@ -154,12 +156,10 @@ class _Kernel(object):
     except ValueError:
       logging.info('Illegal path "%s"!', service_acceptance.method)
       return
-    request_deserializer = self._request_deserializers.get((group, method))
-    response_serializer = self._response_serializers.get((group, method))
-    if request_deserializer is None or response_serializer is None:
-      # TODO(nathaniel): Terminate the RPC with code NOT_FOUND.
-      call.cancel()
-      return
+    request_deserializer = self._request_deserializers.get(
+        (group, method), _IDENTITY)
+    response_serializer = self._response_serializers.get(
+        (group, method), _IDENTITY)
 
     call.read(call)
     self._rpc_states[call] = _RPCState(
@@ -433,7 +433,9 @@ class _ServiceLink(ServiceLink):
   def __init__(self, request_deserializers, response_serializers):
     self._relay = relay.relay(None)
     self._kernel = _Kernel(
-        request_deserializers, response_serializers, self._relay)
+        {} if request_deserializers is None else request_deserializers,
+        {} if response_serializers is None else response_serializers,
+        self._relay)
 
   def accept_ticket(self, ticket):
     self._kernel.add_ticket(ticket)
diff --git a/src/python/grpcio/grpc/beta/_stub.py b/src/python/grpcio/grpc/beta/_stub.py
index 178f06d21e..cfbecb852b 100644
--- a/src/python/grpcio/grpc/beta/_stub.py
+++ b/src/python/grpcio/grpc/beta/_stub.py
@@ -54,11 +54,12 @@ class _AutoIntermediary(object):
 
 
 def _assemble(
-    channel, host, request_serializers, response_deserializers, thread_pool,
-    thread_pool_size):
+    channel, host, metadata_transformer, request_serializers,
+    response_deserializers, thread_pool, thread_pool_size):
   end_link = _core_implementations.invocation_end_link()
   grpc_link = invocation.invocation_link(
-      channel, host, request_serializers, response_deserializers)
+      channel, host, metadata_transformer, request_serializers,
+      response_deserializers)
   if thread_pool is None:
     invocation_pool = logging_pool.pool(
         _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size)
@@ -89,21 +90,22 @@ def _wrap_assembly(stub, end_link, grpc_link, assembly_pool):
 
 
 def generic_stub(
-    channel, host, request_serializers, response_deserializers, thread_pool,
-    thread_pool_size):
+    channel, host, metadata_transformer, request_serializers,
+    response_deserializers, thread_pool, thread_pool_size):
   end_link, grpc_link, invocation_pool, assembly_pool = _assemble(
-      channel, host, request_serializers, response_deserializers, thread_pool,
-      thread_pool_size)
+      channel, host, metadata_transformer, request_serializers,
+      response_deserializers, thread_pool, thread_pool_size)
   stub = _crust_implementations.generic_stub(end_link, invocation_pool)
   return _wrap_assembly(stub, end_link, grpc_link, assembly_pool)
 
 
 def dynamic_stub(
-    channel, host, service, cardinalities, request_serializers,
-    response_deserializers, thread_pool, thread_pool_size):
+    channel, host, service, cardinalities, metadata_transformer,
+    request_serializers, response_deserializers, thread_pool,
+    thread_pool_size):
   end_link, grpc_link, invocation_pool, assembly_pool = _assemble(
-      channel, host, request_serializers, response_deserializers, thread_pool,
-      thread_pool_size)
+      channel, host, metadata_transformer, request_serializers,
+      response_deserializers, thread_pool, thread_pool_size)
   stub = _crust_implementations.dynamic_stub(
       end_link, service, cardinalities, invocation_pool)
   return _wrap_assembly(stub, end_link, grpc_link, assembly_pool)
diff --git a/src/python/grpcio/grpc/beta/beta.py b/src/python/grpcio/grpc/beta/beta.py
index 640e4eb86b..b3a161087f 100644
--- a/src/python/grpcio/grpc/beta/beta.py
+++ b/src/python/grpcio/grpc/beta/beta.py
@@ -238,6 +238,7 @@ def generic_stub(channel, options=None):
   effective_options = _EMPTY_STUB_OPTIONS if options is None else options
   return _stub.generic_stub(
       channel._intermediary_low_channel, effective_options.host,  # pylint: disable=protected-access
+      effective_options.metadata_transformer,
       effective_options.request_serializers,
       effective_options.response_deserializers, effective_options.thread_pool,
       effective_options.thread_pool_size)
@@ -260,7 +261,8 @@ def dynamic_stub(channel, service, cardinalities, options=None):
   effective_options = StubOptions() if options is None else options
   return _stub.dynamic_stub(
       channel._intermediary_low_channel, effective_options.host, service,  # pylint: disable=protected-access
-      cardinalities, effective_options.request_serializers,
+      cardinalities, effective_options.metadata_transformer,
+      effective_options.request_serializers,
       effective_options.response_deserializers, effective_options.thread_pool,
       effective_options.thread_pool_size)
 
diff --git a/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py b/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py
index 9112c34190..f0bd989ea6 100644
--- a/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py
+++ b/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py
@@ -94,7 +94,7 @@ class _Implementation(test_interfaces.Implementation):
     port = service_grpc_link.add_port('[::]:0', None)
     channel = _intermediary_low.Channel('localhost:%d' % port, None)
     invocation_grpc_link = invocation.invocation_link(
-        channel, b'localhost',
+        channel, b'localhost', None,
         serialization_behaviors.request_serializers,
         serialization_behaviors.response_deserializers)
 
diff --git a/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py b/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py
index 1401536503..28c0619f7c 100644
--- a/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py
+++ b/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py
@@ -87,7 +87,7 @@ class _Implementation(test_interfaces.Implementation):
     port = service_grpc_link.add_port('[::]:0', None)
     channel = _intermediary_low.Channel('localhost:%d' % port, None)
     invocation_grpc_link = invocation.invocation_link(
-        channel, b'localhost',
+        channel, b'localhost', None,
         serialization_behaviors.request_serializers,
         serialization_behaviors.response_deserializers)
 
diff --git a/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py b/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py
index 373a2b2a1f..8e12e8cc22 100644
--- a/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py
+++ b/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py
@@ -45,7 +45,8 @@ class LonelyInvocationLinkTest(unittest.TestCase):
 
   def testUpAndDown(self):
     channel = _intermediary_low.Channel('nonexistent:54321', None)
-    invocation_link = invocation.invocation_link(channel, 'nonexistent', {}, {})
+    invocation_link = invocation.invocation_link(
+        channel, 'nonexistent', None, {}, {})
 
     invocation_link.start()
     invocation_link.stop()
@@ -58,8 +59,7 @@ class LonelyInvocationLinkTest(unittest.TestCase):
 
     channel = _intermediary_low.Channel('nonexistent:54321', None)
     invocation_link = invocation.invocation_link(
-        channel, 'nonexistent', {(test_group, test_method): _NULL_BEHAVIOR},
-        {(test_group, test_method): _NULL_BEHAVIOR})
+        channel, 'nonexistent', None, {}, {})
     invocation_link.join_link(invocation_link_mate)
     invocation_link.start()
 
diff --git a/src/python/grpcio_test/grpc_test/_links/_transmission_test.py b/src/python/grpcio_test/grpc_test/_links/_transmission_test.py
index c114cef6a6..716323cc20 100644
--- a/src/python/grpcio_test/grpc_test/_links/_transmission_test.py
+++ b/src/python/grpcio_test/grpc_test/_links/_transmission_test.py
@@ -54,7 +54,7 @@ class TransmissionTest(test_cases.TransmissionTest, unittest.TestCase):
     service_link.start()
     channel = _intermediary_low.Channel('localhost:%d' % port, None)
     invocation_link = invocation.invocation_link(
-        channel, 'localhost',
+        channel, 'localhost', None,
         {self.group_and_method(): self.serialize_request},
         {self.group_and_method(): self.deserialize_response})
     invocation_link.start()
@@ -121,7 +121,7 @@ class RoundTripTest(unittest.TestCase):
     service_link.start()
     channel = _intermediary_low.Channel('localhost:%d' % port, None)
     invocation_link = invocation.invocation_link(
-        channel, 'localhost', identity_transformation, identity_transformation)
+        channel, None, None, identity_transformation, identity_transformation)
     invocation_mate = test_utilities.RecordingLink()
     invocation_link.join_link(invocation_mate)
     invocation_link.start()
@@ -166,7 +166,7 @@ class RoundTripTest(unittest.TestCase):
     service_link.start()
     channel = _intermediary_low.Channel('localhost:%d' % port, None)
     invocation_link = invocation.invocation_link(
-        channel, 'localhost',
+        channel, 'localhost', None,
         {(test_group, test_method): scenario.serialize_request},
         {(test_group, test_method): scenario.deserialize_response})
     invocation_mate = test_utilities.RecordingLink()
-- 
GitLab