From 3045a379aa76ce9ee930f427daa4ee799b0162aa Mon Sep 17 00:00:00 2001
From: Ken Payson <kpayson@google.com>
Date: Wed, 9 Nov 2016 17:56:33 -0800
Subject: [PATCH] Add configurable exit grace periods and shutdown handlers

The server cleanup method is untested.

The join() function that exposes it is only called by the internals of threading.py, and we don't hold a reference to the server thread to explicitly join() it, and I'm not sure we should add a reference just for this purpose.

Moreover, the threading.py only calls join(None), the code path in question isn't even exercised by the internals of threading.py. Its just there to make sure we properly follow the join(timeout) semantics.
---
 src/python/grpcio/grpc/__init__.py            | 30 ++++++-
 src/python/grpcio/grpc/_server.py             | 88 +++++++++++--------
 src/python/grpcio_tests/tests/tests.json      |  1 +
 .../grpcio_tests/tests/unit/_exit_test.py     | 23 ++++-
 4 files changed, 102 insertions(+), 40 deletions(-)

diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py
index 4e4062bafc..6087276d51 100644
--- a/src/python/grpcio/grpc/__init__.py
+++ b/src/python/grpcio/grpc/__init__.py
@@ -904,6 +904,21 @@ class Server(six.with_metaclass(abc.ABCMeta)):
     """
     raise NotImplementedError()
 
+  @abc.abstractmethod
+  def add_shutdown_handler(self, shutdown_handler):
+    """Adds a handler to be called on server shutdown.
+
+    Shutdown handlers are run on server stop() or in the event that a running
+    server is destroyed unexpectedly.  The handlers are run in series before
+    the stop grace period.
+
+    Args:
+      shutdown_handler:  A function taking a single arg, a time in seconds
+      within which the handler should complete.  None indicates the handler can
+      run for any duration.
+    """
+    raise NotImplementedError()
+
   @abc.abstractmethod
   def start(self):
     """Starts this Server's service of RPCs.
@@ -914,7 +929,7 @@ class Server(six.with_metaclass(abc.ABCMeta)):
     raise NotImplementedError()
 
   @abc.abstractmethod
-  def stop(self, grace):
+  def stop(self, grace, shutdown_handler_grace=None):
     """Stops this Server's service of RPCs.
 
     All calls to this method immediately stop service of new RPCs. When existing
@@ -937,6 +952,8 @@ class Server(six.with_metaclass(abc.ABCMeta)):
         aborted by this Server's stopping. If None, all RPCs will be aborted
         immediately and this method will block until this Server is completely
         stopped.
+      shutdown_handler_grace:  A duration of time in seconds or None.  This
+        value is passed to all shutdown handlers.
 
     Returns:
       A threading.Event that will be set when this Server has completely
@@ -1231,7 +1248,8 @@ def secure_channel(target, credentials, options=None):
                           credentials._credentials)
 
 
-def server(thread_pool, handlers=None, options=None):
+def server(thread_pool, handlers=None, options=None, exit_grace=None,
+           exit_shutdown_handler_grace=None):
   """Creates a Server with which RPCs can be serviced.
 
   Args:
@@ -1244,13 +1262,19 @@ def server(thread_pool, handlers=None, options=None):
       returned Server is started.
     options: A sequence of string-value pairs according to which to configure
       the created server.
+    exit_grace:  The grace period to use when terminating
+      running servers at interpreter exit.  None indicates unspecified.
+    exit_shutdown_handler_grace:  The shutdown handler grace to use when
+      terminating running servers at interpreter exit.  None indicates
+      unspecified.
 
   Returns:
     A Server with which RPCs can be serviced.
   """
   from grpc import _server
   return _server.Server(thread_pool, () if handlers is None else handlers,
-                        () if options is None else options)
+                        () if options is None else options, exit_grace,
+                        exit_shutdown_handler_grace)
 
 
 ###################################  __all__  #################################
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py
index 5223712dfa..d83a2e6ded 100644
--- a/src/python/grpcio/grpc/_server.py
+++ b/src/python/grpcio/grpc/_server.py
@@ -60,7 +60,8 @@ _CANCELLED = 'cancelled'
 _EMPTY_FLAGS = 0
 _EMPTY_METADATA = cygrpc.Metadata(())
 
-_UNEXPECTED_EXIT_SERVER_GRACE = 1.0
+_DEFAULT_EXIT_GRACE = 1.0
+_DEFAULT_EXIT_SHUTDOWN_HANDLER_GRACE = 5.0
 
 
 def _serialized_request(request_event):
@@ -595,14 +596,18 @@ class _ServerStage(enum.Enum):
 
 class _ServerState(object):
 
-  def __init__(self, completion_queue, server, generic_handlers, thread_pool):
+  def __init__(self, completion_queue, server, generic_handlers, thread_pool,
+               exit_grace, exit_shutdown_handler_grace):
     self.lock = threading.Lock()
     self.completion_queue = completion_queue
     self.server = server
     self.generic_handlers = list(generic_handlers)
     self.thread_pool = thread_pool
+    self.exit_grace = exit_grace
+    self.exit_shutdown_handler_grace = exit_shutdown_handler_grace
     self.stage = _ServerStage.STOPPED
     self.shutdown_events = None
+    self.shutdown_handlers = []
 
     # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
     self.rpc_states = set()
@@ -672,41 +677,45 @@ def _serve(state):
             return
 
 
-def _stop(state, grace):
-  with state.lock:
-    if state.stage is _ServerStage.STOPPED:
-      shutdown_event = threading.Event()
-      shutdown_event.set()
-      return shutdown_event
-    else:
-      if state.stage is _ServerStage.STARTED:
-        state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
+def _stop(state, grace, shutdown_handler_grace):
+  shutdown_event = threading.Event()
+
+  def cancel_all_calls_after_grace():
+    with state.lock:
+      if state.stage is _ServerStage.STOPPED:
+        shutdown_event.set()
+        return
+      elif state.stage is _ServerStage.STARTED:
+        do_shutdown = True
         state.stage = _ServerStage.GRACE
         state.shutdown_events = []
-        state.due.add(_SHUTDOWN_TAG)
-      shutdown_event = threading.Event()
+      else:
+        do_shutdown = False
       state.shutdown_events.append(shutdown_event)
-      if grace is None:
+
+    if do_shutdown:
+      # Run Shutdown Handlers without the lock
+      for handler in state.shutdown_handlers:
+        handler(shutdown_handler_grace)
+      with state.lock:
+        state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
+        state.stage = _ServerStage.GRACE
+        state.due.add(_SHUTDOWN_TAG)
+
+    if not shutdown_event.wait(timeout=grace):
+      with state.lock:
         state.server.cancel_all_calls()
         # TODO(https://github.com/grpc/grpc/issues/6597): delete this loop.
         for rpc_state in state.rpc_states:
           with rpc_state.condition:
             rpc_state.client = _CANCELLED
             rpc_state.condition.notify_all()
-      else:
-        def cancel_all_calls_after_grace():
-          shutdown_event.wait(timeout=grace)
-          with state.lock:
-            state.server.cancel_all_calls()
-            # TODO(https://github.com/grpc/grpc/issues/6597): delete this loop.
-            for rpc_state in state.rpc_states:
-              with rpc_state.condition:
-                rpc_state.client = _CANCELLED
-                rpc_state.condition.notify_all()
-        thread = threading.Thread(target=cancel_all_calls_after_grace)
-        thread.start()
-        return shutdown_event
-  shutdown_event.wait()
+
+  if grace is None:
+    cancel_all_calls_after_grace()
+  else:
+    threading.Thread(target=cancel_all_calls_after_grace).start()
+
   return shutdown_event
 
 
@@ -716,12 +725,12 @@ def _start(state):
       raise ValueError('Cannot start already-started server!')
     state.server.start()
     state.stage = _ServerStage.STARTED
-    _request_call(state)    
+    _request_call(state)
     def cleanup_server(timeout):
       if timeout is None:
-        _stop(state, _UNEXPECTED_EXIT_SERVER_GRACE).wait()
+        _stop(state, state.exit_grace, state.exit_shutdown_handler_grace).wait()
       else:
-        _stop(state, timeout).wait()
+        _stop(state, timeout, 0).wait()
 
     thread = _common.CleanupThread(
         cleanup_server, target=_serve, args=(state,))
@@ -729,12 +738,16 @@ def _start(state):
 
 class Server(grpc.Server):
 
-  def __init__(self, thread_pool, generic_handlers, options):
+  def __init__(self, thread_pool, generic_handlers, options, exit_grace,
+               exit_shutdown_handler_grace):
     completion_queue = cygrpc.CompletionQueue()
     server = cygrpc.Server(_common.channel_args(options))
     server.register_completion_queue(completion_queue)
     self._state = _ServerState(
-        completion_queue, server, generic_handlers, thread_pool)
+        completion_queue, server, generic_handlers, thread_pool,
+        _DEFAULT_EXIT_GRACE if exit_grace is None else exit_grace,
+        _DEFAULT_EXIT_SHUTDOWN_HANDLER_GRACE if exit_shutdown_handler_grace
+        is None else exit_shutdown_handler_grace)
 
   def add_generic_rpc_handlers(self, generic_rpc_handlers):
     _add_generic_handlers(self._state, generic_rpc_handlers)
@@ -745,11 +758,14 @@ class Server(grpc.Server):
   def add_secure_port(self, address, server_credentials):
     return _add_secure_port(self._state, _common.encode(address), server_credentials)
 
+  def add_shutdown_handler(self, handler):
+    self._state.shutdown_handlers.append(handler)
+
   def start(self):
     _start(self._state)
 
-  def stop(self, grace):
-    return _stop(self._state, grace)
+  def stop(self, grace, shutdown_handler_grace=None):
+    return _stop(self._state, grace, shutdown_handler_grace)
 
   def __del__(self):
-    _stop(self._state, None)
+    _stop(self._state, None, None)
diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json
index dd4a0257f5..04a2e44178 100644
--- a/src/python/grpcio_tests/tests/tests.json
+++ b/src/python/grpcio_tests/tests/tests.json
@@ -27,6 +27,7 @@
   "unit._cython.cygrpc_test.TypeSmokeTest",
   "unit._empty_message_test.EmptyMessageTest",
   "unit._exit_test.ExitTest",
+  "unit._exit_test.ShutdownHandlerTest",
   "unit._metadata_code_details_test.MetadataCodeDetailsTest",
   "unit._metadata_test.MetadataTest",
   "unit._rpc_test.RPCTest",
diff --git a/src/python/grpcio_tests/tests/unit/_exit_test.py b/src/python/grpcio_tests/tests/unit/_exit_test.py
index 5a4a32887c..342f5fcc10 100644
--- a/src/python/grpcio_tests/tests/unit/_exit_test.py
+++ b/src/python/grpcio_tests/tests/unit/_exit_test.py
@@ -43,6 +43,8 @@ import threading
 import time
 import unittest
 
+import grpc
+from grpc.framework.foundation import logging_pool
 from tests.unit import _exit_scenarios
 
 SCENARIO_FILE = os.path.abspath(os.path.join(
@@ -52,7 +54,7 @@ BASE_COMMAND = [INTERPRETER, SCENARIO_FILE]
 BASE_SIGTERM_COMMAND = BASE_COMMAND + ['--wait_for_interrupt']
 
 INIT_TIME = 1.0
-
+SHUTDOWN_GRACE = 5.0
 
 processes = []
 process_lock = threading.Lock()
@@ -182,5 +184,24 @@ class ExitTest(unittest.TestCase):
     interrupt_and_wait(process)
 
 
+class _ShutDownHandler(object):
+
+  def __init__(self):
+    self.seen_handler_grace = None
+
+  def shutdown_handler(self, handler_grace):
+    self.seen_handler_grace = handler_grace
+
+  
+class ShutdownHandlerTest(unittest.TestCase):
+
+  def test_shutdown_handler(self):
+    server = grpc.server(logging_pool.pool(1))
+    handler = _ShutDownHandler()
+    server.add_shutdown_handler(handler.shutdown_handler)
+    server.start()
+    server.stop(0, shutdown_handler_grace=SHUTDOWN_GRACE).wait()
+    self.assertEqual(SHUTDOWN_GRACE, handler.seen_handler_grace)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
-- 
GitLab