From ed7428526cea1f7b54a126ef9cdb2c1456c47f2c Mon Sep 17 00:00:00 2001
From: Ken Payson <kpayson@google.com>
Date: Fri, 10 Jun 2016 13:19:31 -0700
Subject: [PATCH] Added cleanup to the server thread's join method.

When the Python Interpreter exits, it first attempts to join any
outstanding threads.  This is problematic if a server is created
as a top-level variable or referenced by a reference cycle, as join()
will hang.  This adds cleanup behavior to the server thread's join().
---
 src/python/grpcio/grpc/_common.py             |  42 +++++++
 src/python/grpcio/grpc/_server.py             |  31 +++--
 src/python/grpcio/tests/tests.json            |   1 +
 src/python/grpcio/tests/unit/_rpc_test.py     |   7 --
 .../grpcio/tests/unit/_thread_cleanup_test.py | 117 ++++++++++++++++++
 5 files changed, 180 insertions(+), 18 deletions(-)
 create mode 100644 src/python/grpcio/tests/unit/_thread_cleanup_test.py

diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py
index b8688a0149..1fd1704f18 100644
--- a/src/python/grpcio/grpc/_common.py
+++ b/src/python/grpcio/grpc/_common.py
@@ -30,6 +30,8 @@
 """Shared implementation."""
 
 import logging
+import threading
+import time
 
 import six
 
@@ -110,3 +112,43 @@ def fully_qualified_method(group, method):
   group = _encode(group)
   method = _encode(method)
   return b'/' + group + b'/' + method
+
+
+class CleanupThread(threading.Thread):
+  """A threading.Thread subclass supporting custom behavior on join().
+
+  On Python Interpreter exit, Python will attempt to join outstanding threads
+  prior to garbage collection.  We may need to do additional cleanup, and
+  we accomplish this by overriding the join() method.
+  """
+
+  def __init__(self, behavior, group=None, target=None, name=None,
+               args=(), kwargs={}):
+    """Constructor.
+
+    Args:
+      behavior (function): Function called on join() with a single
+          argument, timeout, indicating the maximum duration of
+          `behavior`, or None indicating `behavior` has no deadline.
+          `behavior` must be idempotent.
+      group (None): should be None.  Reserved for future extensions
+          when ThreadGroup is implemented.
+      target (function): The function to invoke when this thread is
+          run.  Defaults to None.
+      name (str): The name of this thread.  Defaults to None.
+      args (tuple[object]): A tuple of arguments to pass to `target`.
+      kwargs (dict[str,object]): A dictionary of keyword arguments to
+          pass to `target`.
+    """
+    super(CleanupThread, self).__init__(group=group, target=target,
+                                        name=name, args=args, kwargs=kwargs)
+    self._behavior = behavior
+
+  def join(self, timeout=None):
+    start_time = time.time()
+    self._behavior(timeout)
+    end_time = time.time()
+    if timeout is not None:
+      timeout -= end_time - start_time
+      timeout = max(timeout, 0)
+    super(CleanupThread, self).join(timeout)
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py
index f4f6720497..2d9b96afbf 100644
--- a/src/python/grpcio/grpc/_server.py
+++ b/src/python/grpcio/grpc/_server.py
@@ -60,6 +60,8 @@ _CANCELLED = 'cancelled'
 _EMPTY_FLAGS = 0
 _EMPTY_METADATA = cygrpc.Metadata(())
 
+_UNEXPECTED_EXIT_SERVER_GRACE = 1.0
+
 
 def _serialized_request(request_event):
   return request_event.batch_operations[0].received_message.bytes()
@@ -670,17 +672,6 @@ def _serve(state):
             return
 
 
-def _start(state):
-  with state.lock:
-    if state.stage is not _ServerStage.STOPPED:
-      raise ValueError('Cannot start already-started server!')
-    state.server.start()
-    state.stage = _ServerStage.STARTED
-    _request_call(state)
-    thread = threading.Thread(target=_serve, args=(state,))
-    thread.start()
-
-
 def _stop(state, grace):
   with state.lock:
     if state.stage is _ServerStage.STOPPED:
@@ -719,6 +710,24 @@ def _stop(state, grace):
   return shutdown_event
 
 
+def _start(state):
+  with state.lock:
+    if state.stage is not _ServerStage.STOPPED:
+      raise ValueError('Cannot start already-started server!')
+    state.server.start()
+    state.stage = _ServerStage.STARTED
+    _request_call(state)    
+    def cleanup_server(timeout):
+      if timeout is None:
+        _stop(state, _UNEXPECTED_EXIT_SERVER_GRACE).wait()
+      else:
+        _stop(state, timeout).wait()
+
+    thread = _common.CleanupThread(
+        cleanup_server, target=_serve, args=(state,))
+    thread.start()
+
+
 class Server(grpc.Server):
 
   def __init__(self, generic_handlers, thread_pool):
diff --git a/src/python/grpcio/tests/tests.json b/src/python/grpcio/tests/tests.json
index 8dc47bf69d..fcf2001b80 100644
--- a/src/python/grpcio/tests/tests.json
+++ b/src/python/grpcio/tests/tests.json
@@ -52,6 +52,7 @@
   "_rpc_test.RPCTest",
   "_sanity_test.Sanity", 
   "_secure_interop_test.SecureInteropTest", 
+  "_thread_cleanup_test.CleanupThreadTest",
   "_transmission_test.RoundTripTest", 
   "_transmission_test.TransmissionTest", 
   "_utilities_test.ChannelConnectivityTest", 
diff --git a/src/python/grpcio/tests/unit/_rpc_test.py b/src/python/grpcio/tests/unit/_rpc_test.py
index b33bff490c..8773e96284 100644
--- a/src/python/grpcio/tests/unit/_rpc_test.py
+++ b/src/python/grpcio/tests/unit/_rpc_test.py
@@ -193,13 +193,6 @@ class RPCTest(unittest.TestCase):
 
     self._channel = grpc.insecure_channel('localhost:%d' % port)
 
-  # TODO(nathaniel): Why is this necessary, and only in some development
-  # environments?
-  def tearDown(self):
-    del self._channel
-    del self._server
-    del self._server_pool
-
   def testUnrecognizedMethod(self):
     request = b'abc'
 
diff --git a/src/python/grpcio/tests/unit/_thread_cleanup_test.py b/src/python/grpcio/tests/unit/_thread_cleanup_test.py
new file mode 100644
index 0000000000..3e4f317edc
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_thread_cleanup_test.py
@@ -0,0 +1,117 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+#     * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+#     * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Tests for CleanupThread."""
+
+import threading
+import time
+import unittest
+
+from grpc import _common
+
+_SHORT_TIME = 0.5
+_LONG_TIME = 2.0
+_EPSILON = 0.1
+
+
+def cleanup(timeout):
+  if timeout is not None:
+    time.sleep(timeout)
+  else:
+    time.sleep(_LONG_TIME)
+
+
+def slow_cleanup(timeout):
+  # Don't respect timeout
+  time.sleep(_LONG_TIME)
+
+
+class CleanupThreadTest(unittest.TestCase):
+
+  def testTargetInvocation(self):
+    event = threading.Event()
+    def target(arg1, arg2, arg3=None):
+      self.assertEqual('arg1', arg1)
+      self.assertEqual('arg2', arg2)
+      self.assertEqual('arg3', arg3)
+      event.set()
+
+    cleanup_thread = _common.CleanupThread(behavior=lambda x: None,
+                              target=target, name='test-name',
+                              args=('arg1', 'arg2'), kwargs={'arg3': 'arg3'})
+    cleanup_thread.start()
+    cleanup_thread.join()
+    self.assertEqual(cleanup_thread.name, 'test-name')
+    self.assertTrue(event.is_set())
+
+  def testJoinNoTimeout(self):
+    cleanup_thread = _common.CleanupThread(behavior=cleanup)
+    cleanup_thread.start()
+    start_time = time.time()
+    cleanup_thread.join()
+    end_time = time.time()
+    self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON)
+
+  def testJoinTimeout(self):
+    cleanup_thread = _common.CleanupThread(behavior=cleanup)
+    cleanup_thread.start()
+    start_time = time.time()
+    cleanup_thread.join(_SHORT_TIME)
+    end_time = time.time()
+    self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON)
+
+  def testJoinTimeoutSlowBehavior(self):
+    cleanup_thread = _common.CleanupThread(behavior=slow_cleanup)
+    cleanup_thread.start()
+    start_time = time.time()
+    cleanup_thread.join(_SHORT_TIME)
+    end_time = time.time()
+    self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON)
+
+  def testJoinTimeoutSlowTarget(self):
+    event = threading.Event()
+    def target():
+      event.wait(_LONG_TIME)
+    cleanup_thread = _common.CleanupThread(behavior=cleanup, target=target)
+    cleanup_thread.start()
+    start_time = time.time()
+    cleanup_thread.join(_SHORT_TIME)
+    end_time = time.time()
+    self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON)
+    event.set()
+
+  def testJoinZeroTimeout(self):
+    cleanup_thread = _common.CleanupThread(behavior=cleanup)
+    cleanup_thread.start()
+    start_time = time.time()
+    cleanup_thread.join(0)
+    end_time = time.time()
+    self.assertAlmostEqual(0, end_time - start_time, delta=_EPSILON)
+
+if __name__ == '__main__':
+  unittest.main(verbosity=2)
-- 
GitLab