From 93cc06a48441a43c97fe7136839f70eab6355e99 Mon Sep 17 00:00:00 2001
From: Masood Malekghassemi <atash@google.com>
Date: Fri, 13 May 2016 14:25:35 -0700
Subject: [PATCH] Add compression support to Cython layers

---
 src/python/grpcio/grpc/_adapter/_low.py       | 20 +++--
 src/python/grpcio/grpc/_adapter/_types.py     |  4 +-
 .../grpcio/grpc/_cython/_cygrpc/grpc.pxi      | 38 ++++++++++
 .../grpc/_cython/_cygrpc/records.pxd.pxi      |  4 +
 .../grpc/_cython/_cygrpc/records.pyx.pxi      | 74 +++++++++++++++++--
 .../grpcio/grpc/_cython/imports.generated.h   |  9 ++-
 .../grpcio/tests/unit/_cython/cygrpc_test.py  | 58 +++++++++------
 .../grpc/_cython/imports.generated.h.template |  9 ++-
 8 files changed, 167 insertions(+), 49 deletions(-)

diff --git a/src/python/grpcio/grpc/_adapter/_low.py b/src/python/grpcio/grpc/_adapter/_low.py
index b13d8dd9dd..00788bd4cf 100644
--- a/src/python/grpcio/grpc/_adapter/_low.py
+++ b/src/python/grpcio/grpc/_adapter/_low.py
@@ -195,26 +195,30 @@ class Call(_types.Call):
         translated_op = cygrpc.operation_send_initial_metadata(
             cygrpc.Metadata(
                 cygrpc.Metadatum(key, value)
-                for key, value in op.initial_metadata))
+                for key, value in op.initial_metadata),
+            op.flags)
       elif op.type == _types.OpType.SEND_MESSAGE:
-        translated_op = cygrpc.operation_send_message(op.message)
+        translated_op = cygrpc.operation_send_message(op.message, op.flags)
       elif op.type == _types.OpType.SEND_CLOSE_FROM_CLIENT:
-        translated_op = cygrpc.operation_send_close_from_client()
+        translated_op = cygrpc.operation_send_close_from_client(op.flags)
       elif op.type == _types.OpType.SEND_STATUS_FROM_SERVER:
         translated_op = cygrpc.operation_send_status_from_server(
             cygrpc.Metadata(
                 cygrpc.Metadatum(key, value)
                 for key, value in op.trailing_metadata),
             op.status.code,
-            op.status.details)
+            op.status.details,
+            op.flags)
       elif op.type == _types.OpType.RECV_INITIAL_METADATA:
-        translated_op = cygrpc.operation_receive_initial_metadata()
+        translated_op = cygrpc.operation_receive_initial_metadata(
+            op.flags)
       elif op.type == _types.OpType.RECV_MESSAGE:
-        translated_op = cygrpc.operation_receive_message()
+        translated_op = cygrpc.operation_receive_message(op.flags)
       elif op.type == _types.OpType.RECV_STATUS_ON_CLIENT:
-        translated_op = cygrpc.operation_receive_status_on_client()
+        translated_op = cygrpc.operation_receive_status_on_client(
+            op.flags)
       elif op.type == _types.OpType.RECV_CLOSE_ON_SERVER:
-        translated_op = cygrpc.operation_receive_close_on_server()
+        translated_op = cygrpc.operation_receive_close_on_server(op.flags)
       else:
         raise ValueError('unexpected operation type {}'.format(op.type))
       translated_ops.append(translated_op)
diff --git a/src/python/grpcio/grpc/_adapter/_types.py b/src/python/grpcio/grpc/_adapter/_types.py
index 8ca7ff4b60..f8405949d4 100644
--- a/src/python/grpcio/grpc/_adapter/_types.py
+++ b/src/python/grpcio/grpc/_adapter/_types.py
@@ -152,7 +152,7 @@ class OpArgs(collections.namedtuple(
         'trailing_metadata',
         'message',
         'status',
-        'write_flags',
+        'flags',
     ])):
   """Arguments passed into a GRPC operation.
 
@@ -165,7 +165,7 @@ class OpArgs(collections.namedtuple(
     message (bytes): Only valid if type == OpType.SEND_MESSAGE, else is None.
     status (Status): Only valid if type == OpType.SEND_STATUS_FROM_SERVER, else
       is None.
-    write_flags (int): a bit OR'ing of 0 or more OpWriteFlags values.
+    flags (int): a bitwise OR'ing of 0 or more OpWriteFlags values.
   """
 
   @staticmethod
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
index 3d158a7707..66e6e6b549 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
@@ -140,6 +140,9 @@ cdef extern from "grpc/_cython/loader.h":
   const char *GRPC_ARG_PRIMARY_USER_AGENT_STRING
   const char *GRPC_ARG_SECONDARY_USER_AGENT_STRING
   const char *GRPC_SSL_TARGET_NAME_OVERRIDE_ARG
+  const char *GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM
+  const char *GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL
+  const char *GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET
 
   const int GRPC_WRITE_BUFFER_HINT
   const int GRPC_WRITE_NO_COMPRESS
@@ -425,3 +428,38 @@ cdef extern from "grpc/_cython/loader.h":
 
   grpc_call_credentials *grpc_metadata_credentials_create_from_plugin(
       grpc_metadata_credentials_plugin plugin, void *reserved) nogil
+
+  ctypedef enum grpc_compression_algorithm:
+    GRPC_COMPRESS_NONE
+    GRPC_COMPRESS_DEFLATE
+    GRPC_COMPRESS_GZIP
+    GRPC_COMPRESS_ALGORITHMS_COUNT
+
+  ctypedef enum grpc_compression_level:
+    GRPC_COMPRESS_LEVEL_NONE
+    GRPC_COMPRESS_LEVEL_LOW
+    GRPC_COMPRESS_LEVEL_MED
+    GRPC_COMPRESS_LEVEL_HIGH
+    GRPC_COMPRESS_LEVEL_COUNT
+
+  ctypedef struct grpc_compression_options:
+    uint32_t enabled_algorithms_bitset
+    grpc_compression_algorithm default_compression_algorithm
+
+  int grpc_compression_algorithm_parse(
+      const char *name, size_t name_length,
+      grpc_compression_algorithm *algorithm) nogil
+  int grpc_compression_algorithm_name(grpc_compression_algorithm algorithm,
+                                      char **name) nogil
+  grpc_compression_algorithm grpc_compression_algorithm_for_level(
+      grpc_compression_level level, uint32_t accepted_encodings) nogil
+  void grpc_compression_options_init(grpc_compression_options *opts) nogil
+  void grpc_compression_options_enable_algorithm(
+      grpc_compression_options *opts,
+      grpc_compression_algorithm algorithm) nogil
+  void grpc_compression_options_disable_algorithm(
+      grpc_compression_options *opts,
+      grpc_compression_algorithm algorithm) nogil
+  int grpc_compression_options_is_algorithm_enabled(
+      const grpc_compression_options *opts,
+      grpc_compression_algorithm algorithm) nogil
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi
index 30397818a1..0474697af8 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi
@@ -124,3 +124,7 @@ cdef class Operations:
   cdef size_t c_nops
   cdef list operations
 
+
+cdef class CompressionOptions:
+
+  cdef grpc_compression_options c_options
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
index c2202bdab2..c7539f0d49 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
@@ -103,6 +103,19 @@ class OperationType:
   receive_close_on_server = GRPC_OP_RECV_CLOSE_ON_SERVER
 
 
+class CompressionAlgorithm:
+  none = GRPC_COMPRESS_NONE
+  deflate = GRPC_COMPRESS_DEFLATE
+  gzip = GRPC_COMPRESS_GZIP
+
+
+class CompressionLevel:
+  none = GRPC_COMPRESS_LEVEL_NONE
+  low = GRPC_COMPRESS_LEVEL_LOW
+  medium = GRPC_COMPRESS_LEVEL_MED
+  high = GRPC_COMPRESS_LEVEL_HIGH
+
+
 cdef class Timespec:
 
   def __cinit__(self, time):
@@ -472,6 +485,10 @@ cdef class Operation:
   def type(self):
     return self.c_op.type
 
+  @property
+  def flags(self):
+    return self.c_op.flags
+
   @property
   def has_status(self):
     return self.c_op.type == GRPC_OP_RECV_STATUS_ON_CLIENT
@@ -553,9 +570,10 @@ cdef class Operation:
       with nogil:
         gpr_free(self._received_status_details)
 
-def operation_send_initial_metadata(Metadata metadata):
+def operation_send_initial_metadata(Metadata metadata, int flags):
   cdef Operation op = Operation()
   op.c_op.type = GRPC_OP_SEND_INITIAL_METADATA
+  op.c_op.flags = flags
   op.c_op.data.send_initial_metadata.count = metadata.c_metadata_array.count
   op.c_op.data.send_initial_metadata.metadata = (
       metadata.c_metadata_array.metadata)
@@ -563,23 +581,25 @@ def operation_send_initial_metadata(Metadata metadata):
   op.is_valid = True
   return op
 
-def operation_send_message(data):
+def operation_send_message(data, int flags):
   cdef Operation op = Operation()
   op.c_op.type = GRPC_OP_SEND_MESSAGE
+  op.c_op.flags = flags
   byte_buffer = ByteBuffer(data)
   op.c_op.data.send_message = byte_buffer.c_byte_buffer
   op.references.append(byte_buffer)
   op.is_valid = True
   return op
 
-def operation_send_close_from_client():
+def operation_send_close_from_client(int flags):
   cdef Operation op = Operation()
   op.c_op.type = GRPC_OP_SEND_CLOSE_FROM_CLIENT
+  op.c_op.flags = flags
   op.is_valid = True
   return op
 
 def operation_send_status_from_server(
-    Metadata metadata, grpc_status_code code, details):
+    Metadata metadata, grpc_status_code code, details, int flags):
   if isinstance(details, bytes):
     pass
   elif isinstance(details, basestring):
@@ -588,6 +608,7 @@ def operation_send_status_from_server(
     raise TypeError("expected a str or bytes object for details")
   cdef Operation op = Operation()
   op.c_op.type = GRPC_OP_SEND_STATUS_FROM_SERVER
+  op.c_op.flags = flags
   op.c_op.data.send_status_from_server.trailing_metadata_count = (
       metadata.c_metadata_array.count)
   op.c_op.data.send_status_from_server.trailing_metadata = (
@@ -599,18 +620,20 @@ def operation_send_status_from_server(
   op.is_valid = True
   return op
 
-def operation_receive_initial_metadata():
+def operation_receive_initial_metadata(int flags):
   cdef Operation op = Operation()
   op.c_op.type = GRPC_OP_RECV_INITIAL_METADATA
+  op.c_op.flags = flags
   op._received_metadata = Metadata([])
   op.c_op.data.receive_initial_metadata = (
       &op._received_metadata.c_metadata_array)
   op.is_valid = True
   return op
 
-def operation_receive_message():
+def operation_receive_message(int flags):
   cdef Operation op = Operation()
   op.c_op.type = GRPC_OP_RECV_MESSAGE
+  op.c_op.flags = flags
   op._received_message = ByteBuffer(None)
   # n.b. the c_op.data.receive_message field needs to be deleted by us,
   # anyway, so we just let that be handled by the ByteBuffer() we allocated
@@ -619,9 +642,10 @@ def operation_receive_message():
   op.is_valid = True
   return op
 
-def operation_receive_status_on_client():
+def operation_receive_status_on_client(int flags):
   cdef Operation op = Operation()
   op.c_op.type = GRPC_OP_RECV_STATUS_ON_CLIENT
+  op.c_op.flags = flags
   op._received_metadata = Metadata([])
   op.c_op.data.receive_status_on_client.trailing_metadata = (
       &op._received_metadata.c_metadata_array)
@@ -634,9 +658,10 @@ def operation_receive_status_on_client():
   op.is_valid = True
   return op
 
-def operation_receive_close_on_server():
+def operation_receive_close_on_server(int flags):
   cdef Operation op = Operation()
   op.c_op.type = GRPC_OP_RECV_CLOSE_ON_SERVER
+  op.c_op.flags = flags
   op.c_op.data.receive_close_on_server.cancelled = &op._received_cancelled
   op.is_valid = True
   return op
@@ -692,3 +717,36 @@ cdef class Operations:
   def __iter__(self):
     return _OperationsIterator(self)
 
+
+cdef class CompressionOptions:
+
+  def __cinit__(self):
+    with nogil:
+      grpc_compression_options_init(&self.c_options)
+
+  def enable_algorithm(self, grpc_compression_algorithm algorithm):
+    with nogil:
+      grpc_compression_options_enable_algorithm(&self.c_options, algorithm)
+
+  def disable_algorithm(self, grpc_compression_algorithm algorithm):
+    with nogil:
+      grpc_compression_options_disable_algorithm(&self.c_options, algorithm)
+
+  def is_algorithm_enabled(self, grpc_compression_algorithm algorithm):
+    cdef int result
+    with nogil:
+      result = grpc_compression_options_is_algorithm_enabled(
+          &self.c_options, algorithm)
+    return result
+
+  def to_channel_arg(self):
+    return ChannelArg(GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET,
+                      self.c_options.enabled_algorithms_bitset)
+
+
+def compression_algorithm_name(grpc_compression_algorithm algorithm):
+  cdef char* name
+  with nogil:
+    grpc_compression_algorithm_name(algorithm, &name)
+  # Let Cython do the right thing with string casting
+  return name
diff --git a/src/python/grpcio/grpc/_cython/imports.generated.h b/src/python/grpcio/grpc/_cython/imports.generated.h
index 54c8aaad13..6de295414a 100644
--- a/src/python/grpcio/grpc/_cython/imports.generated.h
+++ b/src/python/grpcio/grpc/_cython/imports.generated.h
@@ -870,14 +870,15 @@ void pygrpc_load_imports(HMODULE library);
 
 #else /* !GPR_WIN32 */
 
-#include <grpc/support/alloc.h>
-#include <grpc/support/slice.h>
-#include <grpc/support/time.h>
-#include <grpc/status.h>
 #include <grpc/byte_buffer.h>
 #include <grpc/byte_buffer_reader.h>
+#include <grpc/compression.h>
 #include <grpc/grpc.h>
 #include <grpc/grpc_security.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/slice.h>
+#include <grpc/support/time.h>
+#include <grpc/status.h>
 
 #endif /* !GPR_WIN32 */
 
diff --git a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
index 876da88de9..0a511101f0 100644
--- a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
+++ b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
@@ -40,6 +40,7 @@ from tests.unit import resources
 _SSL_HOST_OVERRIDE = 'foo.test.google.fr'
 _CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
 _CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
+_EMPTY_FLAGS = 0
 
 def _metadata_plugin_callback(context, callback):
   callback(cygrpc.Metadata(
@@ -76,7 +77,7 @@ class TypeSmokeTest(unittest.TestCase):
 
   def testOperationsIteration(self):
     operations = cygrpc.Operations([
-        cygrpc.operation_send_message('asdf')])
+        cygrpc.operation_send_message('asdf', _EMPTY_FLAGS)])
     iterator = iter(operations)
     operation = next(iterator)
     self.assertIsInstance(operation, cygrpc.Operation)
@@ -85,6 +86,11 @@ class TypeSmokeTest(unittest.TestCase):
     with self.assertRaises(StopIteration):
       next(iterator)
 
+  def testOperationFlags(self):
+    operation = cygrpc.operation_send_message('asdf',
+                                              cygrpc.WriteFlag.no_compress)
+    self.assertEqual(cygrpc.WriteFlag.no_compress, operation.flags)
+
   def testTimespec(self):
     now = time.time()
     timespec = cygrpc.Timespec(now)
@@ -188,12 +194,13 @@ class InsecureServerInsecureClient(unittest.TestCase):
                          CLIENT_METADATA_ASCII_VALUE),
         cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
     client_start_batch_result = client_call.start_batch(cygrpc.Operations([
-        cygrpc.operation_send_initial_metadata(client_initial_metadata),
-        cygrpc.operation_send_message(REQUEST),
-        cygrpc.operation_send_close_from_client(),
-        cygrpc.operation_receive_initial_metadata(),
-        cygrpc.operation_receive_message(),
-        cygrpc.operation_receive_status_on_client()
+        cygrpc.operation_send_initial_metadata(client_initial_metadata,
+                                               _EMPTY_FLAGS),
+        cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS),
+        cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+        cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+        cygrpc.operation_receive_message(_EMPTY_FLAGS),
+        cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
     ]), client_call_tag)
     self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
     client_event_future = test_utilities.CompletionQueuePollFuture(
@@ -223,12 +230,14 @@ class InsecureServerInsecureClient(unittest.TestCase):
         cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
                          SERVER_TRAILING_METADATA_VALUE)])
     server_start_batch_result = server_call.start_batch([
-        cygrpc.operation_send_initial_metadata(server_initial_metadata),
-        cygrpc.operation_receive_message(),
-        cygrpc.operation_send_message(RESPONSE),
-        cygrpc.operation_receive_close_on_server(),
+        cygrpc.operation_send_initial_metadata(server_initial_metadata,
+                                               _EMPTY_FLAGS),
+        cygrpc.operation_receive_message(_EMPTY_FLAGS),
+        cygrpc.operation_send_message(RESPONSE, _EMPTY_FLAGS),
+        cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
         cygrpc.operation_send_status_from_server(
-            server_trailing_metadata, SERVER_STATUS_CODE, SERVER_STATUS_DETAILS)
+            server_trailing_metadata, SERVER_STATUS_CODE,
+            SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
     ], server_call_tag)
     self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)
 
@@ -349,12 +358,13 @@ class SecureServerSecureClient(unittest.TestCase):
                          CLIENT_METADATA_ASCII_VALUE),
         cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
     client_start_batch_result = client_call.start_batch(cygrpc.Operations([
-        cygrpc.operation_send_initial_metadata(client_initial_metadata),
-        cygrpc.operation_send_message(REQUEST),
-        cygrpc.operation_send_close_from_client(),
-        cygrpc.operation_receive_initial_metadata(),
-        cygrpc.operation_receive_message(),
-        cygrpc.operation_receive_status_on_client()
+        cygrpc.operation_send_initial_metadata(client_initial_metadata,
+                                               _EMPTY_FLAGS),
+        cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS),
+        cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+        cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+        cygrpc.operation_receive_message(_EMPTY_FLAGS),
+        cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
     ]), client_call_tag)
     self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
     client_event_future = test_utilities.CompletionQueuePollFuture(
@@ -387,12 +397,14 @@ class SecureServerSecureClient(unittest.TestCase):
         cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
                          SERVER_TRAILING_METADATA_VALUE)])
     server_start_batch_result = server_call.start_batch([
-        cygrpc.operation_send_initial_metadata(server_initial_metadata),
-        cygrpc.operation_receive_message(),
-        cygrpc.operation_send_message(RESPONSE),
-        cygrpc.operation_receive_close_on_server(),
+        cygrpc.operation_send_initial_metadata(server_initial_metadata,
+                                               _EMPTY_FLAGS),
+        cygrpc.operation_receive_message(_EMPTY_FLAGS),
+        cygrpc.operation_send_message(RESPONSE, _EMPTY_FLAGS),
+        cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
         cygrpc.operation_send_status_from_server(
-            server_trailing_metadata, SERVER_STATUS_CODE, SERVER_STATUS_DETAILS)
+            server_trailing_metadata, SERVER_STATUS_CODE,
+            SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
     ], server_call_tag)
     self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)
 
diff --git a/templates/src/python/grpcio/grpc/_cython/imports.generated.h.template b/templates/src/python/grpcio/grpc/_cython/imports.generated.h.template
index 8e7c183180..26e717e58d 100644
--- a/templates/src/python/grpcio/grpc/_cython/imports.generated.h.template
+++ b/templates/src/python/grpcio/grpc/_cython/imports.generated.h.template
@@ -64,14 +64,15 @@
 
   #else /* !GPR_WIN32 */
 
-  #include <grpc/support/alloc.h>
-  #include <grpc/support/slice.h>
-  #include <grpc/support/time.h>
-  #include <grpc/status.h>
   #include <grpc/byte_buffer.h>
   #include <grpc/byte_buffer_reader.h>
+  #include <grpc/compression.h>
   #include <grpc/grpc.h>
   #include <grpc/grpc_security.h>
+  #include <grpc/support/alloc.h>
+  #include <grpc/support/slice.h>
+  #include <grpc/support/time.h>
+  #include <grpc/status.h>
 
   #endif /* !GPR_WIN32 */
 
-- 
GitLab