From 9f85fa004353de3ae85ee3d22505b136da200dd7 Mon Sep 17 00:00:00 2001
From: Craig Tiller <ctiller@google.com>
Date: Wed, 15 Jul 2015 12:35:25 -0700
Subject: [PATCH] Only validate metadata from the client that we know should
 exist

- it's allowed that other metadata may be picked up when sending a
  request
---
 src/python/src/grpc/_adapter/_low_test.py             |  5 ++++-
 src/python/src/grpc/_links/_transmission_test.py      |  9 +++++++--
 .../src/grpc/framework/interfaces/links/test_cases.py | 11 ++++++-----
 3 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/src/python/src/grpc/_adapter/_low_test.py b/src/python/src/grpc/_adapter/_low_test.py
index 268e5fe765..a49cd007bf 100644
--- a/src/python/src/grpc/_adapter/_low_test.py
+++ b/src/python/src/grpc/_adapter/_low_test.py
@@ -129,7 +129,10 @@ class InsecureServerInsecureClient(unittest.TestCase):
     self.assertIsInstance(request_event.call, _low.Call)
     self.assertIs(server_request_tag, request_event.tag)
     self.assertEquals(1, len(request_event.results))
-    self.assertEquals(dict(client_initial_metadata), dict(request_event.results[0].initial_metadata))
+    got_initial_metadata = dict(request_event.results[0].initial_metadata)
+    self.assertEquals(
+        dict(client_initial_metadata),
+        dict((x, got_initial_metadata[x]) for x in zip(*client_initial_metadata)[0]))
     self.assertEquals(METHOD, request_event.call_details.method)
     self.assertEquals(HOST, request_event.call_details.host)
     self.assertLess(abs(DEADLINE - request_event.call_details.deadline), DEADLINE_TOLERANCE)
diff --git a/src/python/src/grpc/_links/_transmission_test.py b/src/python/src/grpc/_links/_transmission_test.py
index c5ef1edb25..d551a1fd5f 100644
--- a/src/python/src/grpc/_links/_transmission_test.py
+++ b/src/python/src/grpc/_links/_transmission_test.py
@@ -93,8 +93,13 @@ class TransmissionTest(test_cases.TransmissionTest, unittest.TestCase):
   def create_service_completion(self):
     return _intermediary_low.Code.OK, 'An exuberant test "details" message!'
 
-  def assertMetadataEqual(self, original_metadata, transmitted_metadata):
-    self.assertSequenceEqual(original_metadata, transmitted_metadata)
+  def assertMetadataTransmitted(self, original_metadata, transmitted_metadata):
+    # we need to filter out any additional metadata added in transmitted_metadata
+    # since implementations are allowed to add to what is sent (in any position)
+    keys, _ = zip(*original_metadata)
+    self.assertSequenceEqual(
+        original_metadata, 
+        (x for x in transmitted_metadata if x[0] in keys))
 
 
 class RoundTripTest(unittest.TestCase):
diff --git a/src/python/src/grpc/framework/interfaces/links/test_cases.py b/src/python/src/grpc/framework/interfaces/links/test_cases.py
index 3ac212ebdf..bf1f09d99d 100644
--- a/src/python/src/grpc/framework/interfaces/links/test_cases.py
+++ b/src/python/src/grpc/framework/interfaces/links/test_cases.py
@@ -161,8 +161,8 @@ class TransmissionTest(object):
     raise NotImplementedError()
 
   @abc.abstractmethod
-  def assertMetadataEqual(self, original_metadata, transmitted_metadata):
-    """Asserts that two metadata objects are equal.
+  def assertMetadataTransmitted(self, original_metadata, transmitted_metadata):
+    """Asserts that transmitted_metadata contains original_metadata.
 
     Args:
       original_metadata: A metadata object used in this test.
@@ -170,7 +170,8 @@ class TransmissionTest(object):
         through the system under test.
 
     Raises:
-      AssertionError: if the two metadata objects are not equal.
+      AssertionError: if the transmitted_metadata object does not contain
+        original_metadata.
     """
     raise NotImplementedError()
 
@@ -239,7 +240,7 @@ class TransmissionTest(object):
         self.assertFalse(initial_metadata_seen)
         self.assertFalse(seen_payloads)
         self.assertFalse(terminal_metadata_seen)
-        self.assertMetadataEqual(initial_metadata, ticket.initial_metadata)
+        self.assertMetadataTransmitted(initial_metadata, ticket.initial_metadata)
         initial_metadata_seen = True
 
       if ticket.payload is not None:
@@ -248,7 +249,7 @@ class TransmissionTest(object):
 
       if ticket.terminal_metadata is not None:
         self.assertFalse(terminal_metadata_seen)
-        self.assertMetadataEqual(terminal_metadata, ticket.terminal_metadata)
+        self.assertMetadataTransmitted(terminal_metadata, ticket.terminal_metadata)
         terminal_metadata_seen = True
     self.assertSequenceEqual(payloads, seen_payloads)
 
-- 
GitLab