From cc4ef5919f81dc0c1be14172785da0f4d24d1e21 Mon Sep 17 00:00:00 2001
From: "Mark D. Roth" <roth@google.com>
Date: Thu, 29 Jun 2017 08:11:57 -0700
Subject: [PATCH] Improvements to C++ filter API: - Make sure all C-core
 parameters are passed into C++ methods. - Add Destroy() methods for
 ChannelData and CallData. - Use C++-style casts. - Add 'extern "C"' to
 iomgr/closure.h, which is used in C++ filters.

---
 src/core/lib/iomgr/closure.h           |  8 ++++
 src/cpp/common/channel_filter.h        | 54 +++++++++++++++-----------
 test/cpp/common/channel_filter_test.cc |  4 +-
 3 files changed, 42 insertions(+), 24 deletions(-)

diff --git a/src/core/lib/iomgr/closure.h b/src/core/lib/iomgr/closure.h
index 2560bf4527..cd32a4ba38 100644
--- a/src/core/lib/iomgr/closure.h
+++ b/src/core/lib/iomgr/closure.h
@@ -26,6 +26,10 @@
 #include "src/core/lib/iomgr/error.h"
 #include "src/core/lib/support/mpscq.h"
 
+#ifdef __cplusplus
+extern "C" {
+#endif
+
 struct grpc_closure;
 typedef struct grpc_closure grpc_closure;
 
@@ -197,4 +201,8 @@ void grpc_closure_list_sched(grpc_exec_ctx *exec_ctx,
   grpc_closure_list_sched(exec_ctx, closure_list)
 #endif
 
+#ifdef __cplusplus
+}
+#endif
+
 #endif /* GRPC_CORE_LIB_IOMGR_CLOSURE_H */
diff --git a/src/cpp/common/channel_filter.h b/src/cpp/common/channel_filter.h
index 1b6ace6b13..5d629f7c14 100644
--- a/src/cpp/common/channel_filter.h
+++ b/src/cpp/common/channel_filter.h
@@ -208,38 +208,45 @@ class TransportStreamOpBatch {
 /// Represents channel data.
 class ChannelData {
  public:
+  ChannelData() {}
   virtual ~ChannelData() {}
 
-  /// Initializes the call data.
-  virtual grpc_error *Init(grpc_exec_ctx *exec_ctx,
+  // TODO(roth): Come up with a more C++-like API for the channel element.
+
+  /// Initializes the channel data.
+  virtual grpc_error *Init(grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
                            grpc_channel_element_args *args) {
     return GRPC_ERROR_NONE;
   }
 
-  // TODO(roth): Find a way to avoid passing elem into these methods.
+  // Called before destruction.
+  virtual void Destroy(grpc_exec_ctx *exec_ctx, grpc_channel_element *elem) {}
 
   virtual void StartTransportOp(grpc_exec_ctx *exec_ctx,
                                 grpc_channel_element *elem, TransportOp *op);
 
   virtual void GetInfo(grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
                        const grpc_channel_info *channel_info);
-
- protected:
-  ChannelData() {}
 };
 
 /// Represents call data.
 class CallData {
  public:
+  CallData() {}
   virtual ~CallData() {}
 
+  // TODO(roth): Come up with a more C++-like API for the call element.
+
   /// Initializes the call data.
-  virtual grpc_error *Init(grpc_exec_ctx *exec_ctx, ChannelData *channel_data,
+  virtual grpc_error *Init(grpc_exec_ctx *exec_ctx, grpc_call_element *elem,
                            const grpc_call_element_args *args) {
     return GRPC_ERROR_NONE;
   }
 
-  // TODO(roth): Find a way to avoid passing elem into these methods.
+  // Called before destruction.
+  virtual void Destroy(grpc_exec_ctx *exec_ctx, grpc_call_element *elem,
+                       const grpc_call_final_info *final_info,
+                       grpc_closure *then_call_closure) {}
 
   /// Starts a new stream operation.
   virtual void StartTransportStreamOpBatch(grpc_exec_ctx *exec_ctx,
@@ -253,9 +260,6 @@ class CallData {
 
   /// Gets the peer name.
   virtual char *GetPeer(grpc_exec_ctx *exec_ctx, grpc_call_element *elem);
-
- protected:
-  CallData() {}
 };
 
 namespace internal {
@@ -271,19 +275,24 @@ class ChannelFilter final {
   static grpc_error *InitChannelElement(grpc_exec_ctx *exec_ctx,
                                         grpc_channel_element *elem,
                                         grpc_channel_element_args *args) {
+    // Construct the object in the already-allocated memory.
     ChannelDataType *channel_data = new (elem->channel_data) ChannelDataType();
-    return channel_data->Init(exec_ctx, args);
+    return channel_data->Init(exec_ctx, elem, args);
   }
 
   static void DestroyChannelElement(grpc_exec_ctx *exec_ctx,
                                     grpc_channel_element *elem) {
-    reinterpret_cast<ChannelDataType *>(elem->channel_data)->~ChannelDataType();
+    ChannelDataType *channel_data =
+        reinterpret_cast<ChannelDataType *>(elem->channel_data);
+    channel_data->Destroy(exec_ctx, elem);
+    channel_data->~ChannelDataType();
   }
 
   static void StartTransportOp(grpc_exec_ctx *exec_ctx,
                                grpc_channel_element *elem,
                                grpc_transport_op *op) {
-    ChannelDataType *channel_data = (ChannelDataType *)elem->channel_data;
+    ChannelDataType *channel_data =
+        reinterpret_cast<ChannelDataType *>(elem->channel_data);
     TransportOp op_wrapper(op);
     channel_data->StartTransportOp(exec_ctx, elem, &op_wrapper);
   }
@@ -291,7 +300,8 @@ class ChannelFilter final {
   static void GetChannelInfo(grpc_exec_ctx *exec_ctx,
                              grpc_channel_element *elem,
                              const grpc_channel_info *channel_info) {
-    ChannelDataType *channel_data = (ChannelDataType *)elem->channel_data;
+    ChannelDataType *channel_data =
+        reinterpret_cast<ChannelDataType *>(elem->channel_data);
     channel_data->GetInfo(exec_ctx, elem, channel_info);
   }
 
@@ -300,24 +310,24 @@ class ChannelFilter final {
   static grpc_error *InitCallElement(grpc_exec_ctx *exec_ctx,
                                      grpc_call_element *elem,
                                      const grpc_call_element_args *args) {
-    ChannelDataType *channel_data = (ChannelDataType *)elem->channel_data;
     // Construct the object in the already-allocated memory.
     CallDataType *call_data = new (elem->call_data) CallDataType();
-    return call_data->Init(exec_ctx, channel_data, args);
+    return call_data->Init(exec_ctx, elem, args);
   }
 
   static void DestroyCallElement(grpc_exec_ctx *exec_ctx,
                                  grpc_call_element *elem,
                                  const grpc_call_final_info *final_info,
                                  grpc_closure *then_call_closure) {
-    GPR_ASSERT(then_call_closure == NULL);
-    reinterpret_cast<CallDataType *>(elem->call_data)->~CallDataType();
+    CallDataType *call_data = reinterpret_cast<CallDataType *>(elem->call_data);
+    call_data->Destroy(exec_ctx, elem, final_info, then_call_closure);
+    call_data->~CallDataType();
   }
 
   static void StartTransportStreamOpBatch(grpc_exec_ctx *exec_ctx,
                                           grpc_call_element *elem,
                                           grpc_transport_stream_op_batch *op) {
-    CallDataType *call_data = (CallDataType *)elem->call_data;
+    CallDataType *call_data = reinterpret_cast<CallDataType *>(elem->call_data);
     TransportStreamOpBatch op_wrapper(op);
     call_data->StartTransportStreamOpBatch(exec_ctx, elem, &op_wrapper);
   }
@@ -325,12 +335,12 @@ class ChannelFilter final {
   static void SetPollsetOrPollsetSet(grpc_exec_ctx *exec_ctx,
                                      grpc_call_element *elem,
                                      grpc_polling_entity *pollent) {
-    CallDataType *call_data = (CallDataType *)elem->call_data;
+    CallDataType *call_data = reinterpret_cast<CallDataType *>(elem->call_data);
     call_data->SetPollsetOrPollsetSet(exec_ctx, elem, pollent);
   }
 
   static char *GetPeer(grpc_exec_ctx *exec_ctx, grpc_call_element *elem) {
-    CallDataType *call_data = (CallDataType *)elem->call_data;
+    CallDataType *call_data = reinterpret_cast<CallDataType *>(elem->call_data);
     return call_data->GetPeer(exec_ctx, elem);
   }
 };
diff --git a/test/cpp/common/channel_filter_test.cc b/test/cpp/common/channel_filter_test.cc
index e747e633a0..638518107b 100644
--- a/test/cpp/common/channel_filter_test.cc
+++ b/test/cpp/common/channel_filter_test.cc
@@ -28,7 +28,7 @@ class MyChannelData : public ChannelData {
  public:
   MyChannelData() {}
 
-  grpc_error* Init(grpc_exec_ctx* exec_ctx,
+  grpc_error* Init(grpc_exec_ctx* exec_ctx, grpc_channel_element* elem,
                    grpc_channel_element_args* args) override {
     (void)args->channel_args;  // Make sure field is available.
     return GRPC_ERROR_NONE;
@@ -39,7 +39,7 @@ class MyCallData : public CallData {
  public:
   MyCallData() {}
 
-  grpc_error* Init(grpc_exec_ctx* exec_ctx, ChannelData* channel_data,
+  grpc_error* Init(grpc_exec_ctx* exec_ctx, grpc_call_element* elem,
                    const grpc_call_element_args* args) override {
     (void)args->path;  // Make sure field is available.
     return GRPC_ERROR_NONE;
-- 
GitLab