diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 7dae90c89e884346e43d81be9a2640266132f278..e14db7906da00f236cc16f7c4e4fd31df5af61f8 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -192,6 +192,9 @@ class Future(six.with_metaclass(abc.ABCMeta)): If the computation has already completed, the callback will be called immediately. + Exceptions raised in the callback will be logged at ERROR level, but + will not terminate any threads of execution. + Args: fn: A callable taking this Future object as its single parameter. """ diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 0bf8e03b5ce2f83bc3825ff19230947f67096745..b19c64d3a6e9095c6e82a989329b9a696c352974 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -159,7 +159,14 @@ def _event_handler(state, response_deserializer): state.condition.notify_all() done = not state.due for callback in callbacks: - callback() + # TODO(gnossen): Are these *only* user callbacks? + try: + callback() + except Exception as e: # pylint: disable=broad-except + # NOTE(rbellevi): We suppress but log errors here so as not to + # kill the channel spin thread. + logging.error('Exception in callback %s: %s', repr( + callback.func), repr(e)) return done and state.fork_epoch >= cygrpc.get_fork_epoch() return handle_event @@ -338,7 +345,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too def add_done_callback(self, fn): with self._state.condition: if self._state.code is None: - self._state.callbacks.append(lambda: fn(self)) + self._state.callbacks.append(functools.partial(fn, self)) return fn(self) diff --git a/src/python/grpcio_tests/tests/unit/_channel_close_test.py b/src/python/grpcio_tests/tests/unit/_channel_close_test.py index 82fa1657109b660589231717a4016013a71120ec..571504c6e3f3158f666a25e26e4d75a443933231 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_close_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_close_test.py @@ -27,8 +27,11 @@ _BEAT = 0.5 _SOME_TIME = 5 _MORE_TIME = 10 +_STREAM_URI = 'Meffod' +_UNARY_URI = 'MeffodMan' -class _MethodHandler(grpc.RpcMethodHandler): + +class _StreamingMethodHandler(grpc.RpcMethodHandler): request_streaming = True response_streaming = True @@ -40,13 +43,28 @@ class _MethodHandler(grpc.RpcMethodHandler): yield request * 2 -_METHOD_HANDLER = _MethodHandler() +class _UnaryMethodHandler(grpc.RpcMethodHandler): + + request_streaming = False + response_streaming = False + request_deserializer = None + response_serializer = None + + def unary_unary(self, request, servicer_context): + return request * 2 + + +_STREAMING_METHOD_HANDLER = _StreamingMethodHandler() +_UNARY_METHOD_HANDLER = _UnaryMethodHandler() class _GenericHandler(grpc.GenericRpcHandler): def service(self, handler_call_details): - return _METHOD_HANDLER + if handler_call_details.method == _STREAM_URI: + return _STREAMING_METHOD_HANDLER + else: + return _UNARY_METHOD_HANDLER _GENERIC_HANDLER = _GenericHandler() @@ -94,6 +112,24 @@ class _Pipe(object): self.close() +class EndlessIterator(object): + + def __init__(self, msg): + self._msg = msg + + def __iter__(self): + return self + + def _next(self): + return self._msg + + def __next__(self): + return self._next() + + def next(self): + return self._next() + + class ChannelCloseTest(unittest.TestCase): def setUp(self): @@ -108,7 +144,7 @@ class ChannelCloseTest(unittest.TestCase): def test_close_immediately_after_call_invocation(self): channel = grpc.insecure_channel('localhost:{}'.format(self._port)) - multi_callable = channel.stream_stream('Meffod') + multi_callable = channel.stream_stream(_STREAM_URI) request_iterator = _Pipe(()) response_iterator = multi_callable(request_iterator) channel.close() @@ -118,7 +154,7 @@ class ChannelCloseTest(unittest.TestCase): def test_close_while_call_active(self): channel = grpc.insecure_channel('localhost:{}'.format(self._port)) - multi_callable = channel.stream_stream('Meffod') + multi_callable = channel.stream_stream(_STREAM_URI) request_iterator = _Pipe((b'abc',)) response_iterator = multi_callable(request_iterator) next(response_iterator) @@ -130,7 +166,7 @@ class ChannelCloseTest(unittest.TestCase): def test_context_manager_close_while_call_active(self): with grpc.insecure_channel('localhost:{}'.format( self._port)) as channel: # pylint: disable=bad-continuation - multi_callable = channel.stream_stream('Meffod') + multi_callable = channel.stream_stream(_STREAM_URI) request_iterator = _Pipe((b'abc',)) response_iterator = multi_callable(request_iterator) next(response_iterator) @@ -141,7 +177,7 @@ class ChannelCloseTest(unittest.TestCase): def test_context_manager_close_while_many_calls_active(self): with grpc.insecure_channel('localhost:{}'.format( self._port)) as channel: # pylint: disable=bad-continuation - multi_callable = channel.stream_stream('Meffod') + multi_callable = channel.stream_stream(_STREAM_URI) request_iterators = tuple( _Pipe((b'abc',)) for _ in range(test_constants.THREAD_CONCURRENCY)) @@ -158,7 +194,7 @@ class ChannelCloseTest(unittest.TestCase): def test_many_concurrent_closes(self): channel = grpc.insecure_channel('localhost:{}'.format(self._port)) - multi_callable = channel.stream_stream('Meffod') + multi_callable = channel.stream_stream(_STREAM_URI) request_iterator = _Pipe((b'abc',)) response_iterator = multi_callable(request_iterator) next(response_iterator) @@ -181,6 +217,21 @@ class ChannelCloseTest(unittest.TestCase): self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + def test_exception_in_callback(self): + with grpc.insecure_channel('localhost:{}'.format( + self._port)) as channel: + stream_multi_callable = channel.stream_stream(_STREAM_URI) + request_iterator = (str(i).encode('ascii') for i in range(9999)) + endless_iterator = EndlessIterator(b'abc') + stream_response_iterator = stream_multi_callable(endless_iterator) + future = channel.unary_unary(_UNARY_URI).future(b'abc') + + def on_done_callback(future): + raise Exception("This should not cause a deadlock.") + + future.add_done_callback(on_done_callback) + future.result() + if __name__ == '__main__': logging.basicConfig()