From 75d0c420f298384ce23ab025b2ab67860512b51d Mon Sep 17 00:00:00 2001
From: "Mark D. Roth" <roth@google.com>
Date: Wed, 30 Nov 2016 11:40:57 -0800
Subject: [PATCH] Fix shutdown semantics for security handshaker.

---
 .../security/transport/security_handshaker.c  | 174 +++++++++++-------
 1 file changed, 103 insertions(+), 71 deletions(-)

diff --git a/src/core/lib/security/transport/security_handshaker.c b/src/core/lib/security/transport/security_handshaker.c
index 681826a287..964a692ae7 100644
--- a/src/core/lib/security/transport/security_handshaker.c
+++ b/src/core/lib/security/transport/security_handshaker.c
@@ -50,14 +50,23 @@
 
 typedef struct {
   grpc_handshaker base;
-  // args will be NULL when either there is no handshake in progress or
-  // when the handshaker is shutting down.
-  grpc_handshaker_args* args;
-  grpc_closure* on_handshake_done;
-  grpc_security_connector *connector;
+
+  // State set at creation time.
   tsi_handshaker *handshaker;
+  grpc_security_connector *connector;
+
   gpr_mu mu;
   gpr_refcount refs;
+
+  bool shutdown;
+  // Endpoint and read buffer to destroy after a shutdown.
+  grpc_endpoint *endpoint_to_destroy;
+  grpc_slice_buffer *read_buffer_to_destroy;
+
+  // State saved while performing the handshake.
+  grpc_handshaker_args* args;
+  grpc_closure* on_handshake_done;
+
   unsigned char *handshake_buffer;
   size_t handshake_buffer_size;
   grpc_slice_buffer left_overs;
@@ -68,18 +77,19 @@ typedef struct {
   grpc_auth_context *auth_context;
 } security_handshaker;
 
-static void on_handshake_data_received_from_peer(grpc_exec_ctx *exec_ctx,
-                                                 void *setup,
-                                                 grpc_error *error);
-
-static void on_handshake_data_sent_to_peer(grpc_exec_ctx *exec_ctx, void *setup,
-                                           grpc_error *error);
-
-static void unref_handshake(security_handshaker *h) {
+static void security_handshaker_unref(grpc_exec_ctx *exec_ctx,
+                                      security_handshaker *h) {
   if (gpr_unref(&h->refs)) {
-    if (h->handshaker != NULL) tsi_handshaker_destroy(h->handshaker);
     gpr_mu_destroy(&h->mu);
+    if (h->handshaker != NULL) tsi_handshaker_destroy(h->handshaker);
     if (h->handshake_buffer != NULL) gpr_free(h->handshake_buffer);
+    if (h->endpoint_to_destroy != NULL) {
+      grpc_endpoint_destroy(exec_ctx, h->endpoint_to_destroy);
+    }
+    if (h->read_buffer_to_destroy != NULL) {
+      grpc_slice_buffer_destroy(h->read_buffer_to_destroy);
+      gpr_free(h->read_buffer_to_destroy);
+    }
     grpc_slice_buffer_destroy(&h->left_overs);
     grpc_slice_buffer_destroy(&h->outgoing);
     GRPC_AUTH_CONTEXT_UNREF(h->auth_context, "handshake");
@@ -88,27 +98,41 @@ static void unref_handshake(security_handshaker *h) {
   }
 }
 
-static void security_handshake_done_locked(grpc_exec_ctx *exec_ctx,
-                                           security_handshaker *h,
-                                           grpc_error *error) {
+// Set args fields to NULL, saving the endpoint and read buffer for
+// later destruction.  The channel args are destroyed immediately.
+static void cleanup_args_for_failure_locked(security_handshaker *h) {
+  h->endpoint_to_destroy = h->args->endpoint;
+  h->args->endpoint = NULL;
+  h->read_buffer_to_destroy = h->args->read_buffer;
+  h->args->read_buffer = NULL;
+  grpc_channel_args_destroy(h->args->args);
+  h->args->args = NULL;
+}
+
+// If the handshake failed or we're shutting down, clean up and invoke the
+// callback with the error.
+static void security_handshake_failed_locked(grpc_exec_ctx *exec_ctx,
+                                             security_handshaker *h,
+                                             grpc_error *error) {
   if (error == GRPC_ERROR_NONE) {
-    grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context);
-    grpc_channel_args* tmp_args = h->args->args;
-    h->args->args =
-        grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1);
-    grpc_channel_args_destroy(tmp_args);
-  } else {
-    const char *msg = grpc_error_string(error);
-    gpr_log(GPR_DEBUG, "Security handshake failed: %s", msg);
-    grpc_error_free_string(msg);
+    // If we were shut down after the handshake succeeded but before an
+    // endpoint callback was invoked, we need to generate our own error.
+    error = GRPC_ERROR_CREATE("Handshaker shutdown");
+  }
+  const char *msg = grpc_error_string(error);
+  gpr_log(GPR_DEBUG, "Security handshake failed: %s", msg);
+  grpc_error_free_string(msg);
+  if (!h->shutdown) {
+    // TODO(ctiller): It is currently necessary to shut down endpoints
+    // before destroying them, even if we know that there are no
+    // pending read/write callbacks.  This should be fixed, at which
+    // point this call can be removed.
     grpc_endpoint_shutdown(exec_ctx, h->args->endpoint);
-// FIXME: clarify who should destroy...
-    //grpc_endpoint_destroy(exec_ctx, h->args->endpoint);
+    // Not shutting down, so the handshake itself failed.  Clean up
+    // before invoking the callback.
+    cleanup_args_for_failure_locked(h);
   }
-  // Clear out the read buffer before it gets passed to the transport,
-  // since any excess bytes were already copied to h->left_overs.
-  grpc_slice_buffer_reset_and_unref(h->args->read_buffer);
-  h->args = NULL;
+  // Invoke callback.
   grpc_exec_ctx_sched(exec_ctx, h->on_handshake_done, error, NULL);
 }
 
@@ -116,9 +140,8 @@ static void on_peer_checked(grpc_exec_ctx *exec_ctx, void *arg,
                             grpc_error *error) {
   security_handshaker *h = arg;
   gpr_mu_lock(&h->mu);
-  if (error != GRPC_ERROR_NONE) {
-    // Take a new ref to pass to security_handshake_done_locked().
-    GRPC_ERROR_REF(error);
+  if (error != GRPC_ERROR_NONE || h->shutdown) {
+    security_handshake_failed_locked(exec_ctx, h, GRPC_ERROR_REF(error));
     goto done;
   }
   // Get frame protector.
@@ -128,17 +151,30 @@ static void on_peer_checked(grpc_exec_ctx *exec_ctx, void *arg,
   if (result != TSI_OK) {
     error = grpc_set_tsi_error_result(
         GRPC_ERROR_CREATE("Frame protector creation failed"), result);
+    security_handshake_failed_locked(exec_ctx, h, error);
     goto done;
   }
+  // Success.
+  // Create secure endpoint.
   h->args->endpoint =
       grpc_secure_endpoint_create(protector, h->args->endpoint,
                                   h->left_overs.slices, h->left_overs.count);
   h->left_overs.count = 0;
   h->left_overs.length = 0;
+  // Clear out the read buffer before it gets passed to the transport,
+  // since any excess bytes were already copied to h->left_overs.
+  grpc_slice_buffer_reset_and_unref(h->args->read_buffer);
+  // Add auth context to channel args.
+  grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context);
+  grpc_channel_args* tmp_args = h->args->args;
+  h->args->args =
+      grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1);
+  grpc_channel_args_destroy(tmp_args);
+  // Invoke callback.
+  grpc_exec_ctx_sched(exec_ctx, h->on_handshake_done, GRPC_ERROR_NONE, NULL);
 done:
-  security_handshake_done_locked(exec_ctx, h, error);
   gpr_mu_unlock(&h->mu);
-  unref_handshake(h);
+  security_handshaker_unref(exec_ctx, h);
 }
 
 static grpc_error* check_peer_locked(grpc_exec_ctx *exec_ctx,
@@ -185,16 +221,15 @@ static grpc_error* send_handshake_bytes_to_peer_locked(grpc_exec_ctx *exec_ctx,
 }
 
 static void on_handshake_data_received_from_peer(grpc_exec_ctx *exec_ctx,
-                                                 void *handshake,
-                                                 grpc_error *error) {
-  security_handshaker *h = handshake;
+                                                 void *arg, grpc_error *error) {
+  security_handshaker *h = arg;
   gpr_mu_lock(&h->mu);
-  if (error != GRPC_ERROR_NONE) {
-    security_handshake_done_locked(
+  if (error != GRPC_ERROR_NONE || h->shutdown) {
+    security_handshake_failed_locked(
         exec_ctx, h,
         GRPC_ERROR_CREATE_REFERENCING("Handshake read failed", &error, 1));
     gpr_mu_unlock(&h->mu);
-    unref_handshake(h);
+    security_handshaker_unref(exec_ctx, h);
     return;
   }
   // Process received data.
@@ -217,21 +252,21 @@ static void on_handshake_data_received_from_peer(grpc_exec_ctx *exec_ctx,
     } else {
       error = send_handshake_bytes_to_peer_locked(exec_ctx, h);
       if (error != GRPC_ERROR_NONE) {
-        security_handshake_done_locked(exec_ctx, h, error);
+        security_handshake_failed_locked(exec_ctx, h, error);
         gpr_mu_unlock(&h->mu);
-        unref_handshake(h);
+        security_handshaker_unref(exec_ctx, h);
         return;
       }
       goto done;
     }
   }
   if (result != TSI_OK) {
-    security_handshake_done_locked(
+    security_handshake_failed_locked(
         exec_ctx, h,
         grpc_set_tsi_error_result(GRPC_ERROR_CREATE("Handshake failed"),
                                   result));
     gpr_mu_unlock(&h->mu);
-    unref_handshake(h);
+    security_handshaker_unref(exec_ctx, h);
     return;
   }
   /* Handshake is done and successful this point. */
@@ -258,42 +293,37 @@ static void on_handshake_data_received_from_peer(grpc_exec_ctx *exec_ctx,
   // Check peer.
   error = check_peer_locked(exec_ctx, h);
   if (error != GRPC_ERROR_NONE) {
-    security_handshake_done_locked(exec_ctx, h, error);
+    security_handshake_failed_locked(exec_ctx, h, error);
     gpr_mu_unlock(&h->mu);
-    unref_handshake(h);
+    security_handshaker_unref(exec_ctx, h);
     return;
   }
 done:
   gpr_mu_unlock(&h->mu);
 }
 
-/* If handshake is NULL, the handshake is done. */
 static void on_handshake_data_sent_to_peer(grpc_exec_ctx *exec_ctx,
-                                           void *handshake, grpc_error *error) {
-  security_handshaker *h = handshake;
-  /* Make sure that write is OK. */
-  if (error != GRPC_ERROR_NONE) {
-    if (handshake != NULL) {
-      gpr_mu_lock(&h->mu);
-      security_handshake_done_locked(
-          exec_ctx, h,
-          GRPC_ERROR_CREATE_REFERENCING("Handshake write failed", &error, 1));
-      gpr_mu_unlock(&h->mu);
-      unref_handshake(h);
-    }
+                                           void *arg, grpc_error *error) {
+  security_handshaker *h = arg;
+  gpr_mu_lock(&h->mu);
+  if (error != GRPC_ERROR_NONE || h->shutdown) {
+    security_handshake_failed_locked(
+        exec_ctx, h,
+        GRPC_ERROR_CREATE_REFERENCING("Handshake write failed", &error, 1));
+    gpr_mu_unlock(&h->mu);
+    security_handshaker_unref(exec_ctx, h);
     return;
   }
   /* We may be done. */
-  gpr_mu_lock(&h->mu);
   if (tsi_handshaker_is_in_progress(h->handshaker)) {
     grpc_endpoint_read(exec_ctx, h->args->endpoint, h->args->read_buffer,
                        &h->on_handshake_data_received_from_peer);
   } else {
     error = check_peer_locked(exec_ctx, h);
     if (error != GRPC_ERROR_NONE) {
-      security_handshake_done_locked(exec_ctx, h, error);
+      security_handshake_failed_locked(exec_ctx, h, error);
       gpr_mu_unlock(&h->mu);
-      unref_handshake(h);
+      security_handshaker_unref(exec_ctx, h);
       return;
     }
   }
@@ -305,17 +335,19 @@ static void on_handshake_data_sent_to_peer(grpc_exec_ctx *exec_ctx,
 //
 
 static void security_handshaker_destroy(grpc_exec_ctx* exec_ctx,
-                                            grpc_handshaker* handshaker) {
+                                        grpc_handshaker* handshaker) {
   security_handshaker* h = (security_handshaker*)handshaker;
-  unref_handshake(h);
+  security_handshaker_unref(exec_ctx, h);
 }
 
 static void security_handshaker_shutdown(grpc_exec_ctx* exec_ctx,
-                                             grpc_handshaker* handshaker) {
+                                         grpc_handshaker* handshaker) {
   security_handshaker *h = (security_handshaker*)handshaker;
   gpr_mu_lock(&h->mu);
-  if (h->args != NULL) {
+  if (!h->shutdown) {
+    h->shutdown = true;
     grpc_endpoint_shutdown(exec_ctx, h->args->endpoint);
+    cleanup_args_for_failure_locked(h);
   }
   gpr_mu_unlock(&h->mu);
 }
@@ -331,9 +363,9 @@ static void security_handshaker_do_handshake(
   gpr_ref(&h->refs);
   grpc_error* error = send_handshake_bytes_to_peer_locked(exec_ctx, h);
   if (error != GRPC_ERROR_NONE) {
-    security_handshake_done_locked(exec_ctx, h, error);
+    security_handshake_failed_locked(exec_ctx, h, error);
     gpr_mu_unlock(&h->mu);
-    unref_handshake(h);
+    security_handshaker_unref(exec_ctx, h);
     return;
   }
   gpr_mu_unlock(&h->mu);
-- 
GitLab