Skip to content
Snippets Groups Projects
Commit b7d49b6b authored by Masood Malekghassemi's avatar Masood Malekghassemi
Browse files

Merge pull request #3328 from nathanielmanistaatgoogle/future-callbacks

Test coverage for callbacks added to Face futures
parents 2449d312 0c617767
No related branches found
No related tags found
No related merge requests found
...@@ -72,6 +72,36 @@ class _PauseableIterator(object): ...@@ -72,6 +72,36 @@ class _PauseableIterator(object):
return next(self._upstream) return next(self._upstream)
class _Callback(object):
def __init__(self):
self._condition = threading.Condition()
self._called = False
self._passed_future = None
self._passed_other_stuff = None
def __call__(self, *args, **kwargs):
with self._condition:
self._called = True
if args:
self._passed_future = args[0]
if 1 < len(args) or kwargs:
self._passed_other_stuff = tuple(args[1:]), dict(kwargs)
self._condition.notify_all()
def future(self):
with self._condition:
while True:
if self._passed_other_stuff is not None:
raise ValueError(
'Test callback passed unexpected values: %s',
self._passed_other_stuff)
elif self._called:
return self._passed_future
else:
self._condition.wait()
class TestCase(test_coverage.Coverage, unittest.TestCase): class TestCase(test_coverage.Coverage, unittest.TestCase):
"""A test of the Face layer of RPC Framework. """A test of the Face layer of RPC Framework.
...@@ -112,12 +142,15 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): ...@@ -112,12 +142,15 @@ class TestCase(test_coverage.Coverage, unittest.TestCase):
self._digest.unary_unary_messages_sequences.iteritems()): self._digest.unary_unary_messages_sequences.iteritems()):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
callback = _Callback()
response_future = self._invoker.future(group, method)( response_future = self._invoker.future(group, method)(
request, test_constants.LONG_TIMEOUT) request, test_constants.LONG_TIMEOUT)
response_future.add_done_callback(callback)
response = response_future.result() response = response_future.result()
test_messages.verify(request, response, self) test_messages.verify(request, response, self)
self.assertIs(callback.future(), response_future)
def testSuccessfulUnaryRequestStreamResponse(self): def testSuccessfulUnaryRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (
...@@ -137,15 +170,19 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): ...@@ -137,15 +170,19 @@ class TestCase(test_coverage.Coverage, unittest.TestCase):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
request_iterator = _PauseableIterator(iter(requests)) request_iterator = _PauseableIterator(iter(requests))
callback = _Callback()
# Use of a paused iterator of requests allows us to test that control is # Use of a paused iterator of requests allows us to test that control is
# returned to calling code before the iterator yields any requests. # returned to calling code before the iterator yields any requests.
with request_iterator.pause(): with request_iterator.pause():
response_future = self._invoker.future(group, method)( response_future = self._invoker.future(group, method)(
request_iterator, test_constants.LONG_TIMEOUT) request_iterator, test_constants.LONG_TIMEOUT)
response = response_future.result() response_future.add_done_callback(callback)
future_passed_to_callback = callback.future()
response = future_passed_to_callback.result()
test_messages.verify(requests, response, self) test_messages.verify(requests, response, self)
self.assertIs(future_passed_to_callback, response_future)
def testSuccessfulStreamRequestStreamResponse(self): def testSuccessfulStreamRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (
...@@ -208,12 +245,15 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): ...@@ -208,12 +245,15 @@ class TestCase(test_coverage.Coverage, unittest.TestCase):
self._digest.unary_unary_messages_sequences.iteritems()): self._digest.unary_unary_messages_sequences.iteritems()):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
callback = _Callback()
with self._control.pause(): with self._control.pause():
response_future = self._invoker.future(group, method)( response_future = self._invoker.future(group, method)(
request, test_constants.LONG_TIMEOUT) request, test_constants.LONG_TIMEOUT)
response_future.add_done_callback(callback)
cancel_method_return_value = response_future.cancel() cancel_method_return_value = response_future.cancel()
self.assertIs(callback.future(), response_future)
self.assertFalse(cancel_method_return_value) self.assertFalse(cancel_method_return_value)
self.assertTrue(response_future.cancelled()) self.assertTrue(response_future.cancelled())
...@@ -236,12 +276,15 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): ...@@ -236,12 +276,15 @@ class TestCase(test_coverage.Coverage, unittest.TestCase):
self._digest.stream_unary_messages_sequences.iteritems()): self._digest.stream_unary_messages_sequences.iteritems()):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
callback = _Callback()
with self._control.pause(): with self._control.pause():
response_future = self._invoker.future(group, method)( response_future = self._invoker.future(group, method)(
iter(requests), test_constants.LONG_TIMEOUT) iter(requests), test_constants.LONG_TIMEOUT)
response_future.add_done_callback(callback)
cancel_method_return_value = response_future.cancel() cancel_method_return_value = response_future.cancel()
self.assertIs(callback.future(), response_future)
self.assertFalse(cancel_method_return_value) self.assertFalse(cancel_method_return_value)
self.assertTrue(response_future.cancelled()) self.assertTrue(response_future.cancelled())
...@@ -264,10 +307,13 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): ...@@ -264,10 +307,13 @@ class TestCase(test_coverage.Coverage, unittest.TestCase):
self._digest.unary_unary_messages_sequences.iteritems()): self._digest.unary_unary_messages_sequences.iteritems()):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
callback = _Callback()
with self._control.pause(): with self._control.pause():
response_future = self._invoker.future( response_future = self._invoker.future(
group, method)(request, _3069_test_constant.REALLY_SHORT_TIMEOUT) group, method)(request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
response_future.add_done_callback(callback)
self.assertIs(callback.future(), response_future)
self.assertIsInstance( self.assertIsInstance(
response_future.exception(), face.ExpirationError) response_future.exception(), face.ExpirationError)
with self.assertRaises(face.ExpirationError): with self.assertRaises(face.ExpirationError):
...@@ -290,10 +336,13 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): ...@@ -290,10 +336,13 @@ class TestCase(test_coverage.Coverage, unittest.TestCase):
self._digest.stream_unary_messages_sequences.iteritems()): self._digest.stream_unary_messages_sequences.iteritems()):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
callback = _Callback()
with self._control.pause(): with self._control.pause():
response_future = self._invoker.future(group, method)( response_future = self._invoker.future(group, method)(
iter(requests), _3069_test_constant.REALLY_SHORT_TIMEOUT) iter(requests), _3069_test_constant.REALLY_SHORT_TIMEOUT)
response_future.add_done_callback(callback)
self.assertIs(callback.future(), response_future)
self.assertIsInstance( self.assertIsInstance(
response_future.exception(), face.ExpirationError) response_future.exception(), face.ExpirationError)
with self.assertRaises(face.ExpirationError): with self.assertRaises(face.ExpirationError):
...@@ -316,11 +365,14 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): ...@@ -316,11 +365,14 @@ class TestCase(test_coverage.Coverage, unittest.TestCase):
self._digest.unary_unary_messages_sequences.iteritems()): self._digest.unary_unary_messages_sequences.iteritems()):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
callback = _Callback()
with self._control.fail(): with self._control.fail():
response_future = self._invoker.future(group, method)( response_future = self._invoker.future(group, method)(
request, _3069_test_constant.REALLY_SHORT_TIMEOUT) request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
response_future.add_done_callback(callback)
self.assertIs(callback.future(), response_future)
# Because the servicer fails outside of the thread from which the # Because the servicer fails outside of the thread from which the
# servicer-side runtime called into it its failure is # servicer-side runtime called into it its failure is
# indistinguishable from simply not having called its # indistinguishable from simply not having called its
...@@ -350,11 +402,14 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): ...@@ -350,11 +402,14 @@ class TestCase(test_coverage.Coverage, unittest.TestCase):
self._digest.stream_unary_messages_sequences.iteritems()): self._digest.stream_unary_messages_sequences.iteritems()):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
callback = _Callback()
with self._control.fail(): with self._control.fail():
response_future = self._invoker.future(group, method)( response_future = self._invoker.future(group, method)(
iter(requests), _3069_test_constant.REALLY_SHORT_TIMEOUT) iter(requests), _3069_test_constant.REALLY_SHORT_TIMEOUT)
response_future.add_done_callback(callback)
self.assertIs(callback.future(), response_future)
# Because the servicer fails outside of the thread from which the # Because the servicer fails outside of the thread from which the
# servicer-side runtime called into it its failure is # servicer-side runtime called into it its failure is
# indistinguishable from simply not having called its # indistinguishable from simply not having called its
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment