diff --git a/include/grpc++/client_context.h b/include/grpc++/client_context.h index 5e10875260474ba13e33a7efa07a469ad64a61cf..88954e227b5c921e00e48081c09dbac67b60cd15 100644 --- a/include/grpc++/client_context.h +++ b/include/grpc++/client_context.h @@ -38,6 +38,7 @@ #include <memory> #include <string> +#include <grpc/compression.h> #include <grpc/support/log.h> #include <grpc/support/time.h> #include <grpc++/config.h> @@ -107,6 +108,17 @@ class ClientContext { creds_ = creds; } + grpc_compression_level get_compression_level() const { + return compression_level_; + } + void set_compression_level(grpc_compression_level level); + + grpc_compression_algorithm get_compression_algorithm() const { + return compression_algorithm_; + } + void set_compression_algorithm(grpc_compression_algorithm algorithm); + + void TryCancel(); private: @@ -157,6 +169,9 @@ class ClientContext { std::multimap<grpc::string, grpc::string> send_initial_metadata_; std::multimap<grpc::string, grpc::string> recv_initial_metadata_; std::multimap<grpc::string, grpc::string> trailing_metadata_; + + grpc_compression_level compression_level_; + grpc_compression_algorithm compression_algorithm_; }; } // namespace grpc diff --git a/include/grpc++/server_context.h b/include/grpc++/server_context.h index 326b6a125ce4110ac4dcb9efd893ee5e6c0f9ec4..a2f0a2f990a48b2e0d376c6fa706953f6669532a 100644 --- a/include/grpc++/server_context.h +++ b/include/grpc++/server_context.h @@ -36,6 +36,7 @@ #include <map> +#include <grpc/compression.h> #include <grpc/support/time.h> #include <grpc++/config.h> #include <grpc++/time.h> @@ -97,6 +98,16 @@ class ServerContext { return client_metadata_; } + grpc_compression_level get_compression_level() const { + return compression_level_; + } + void set_compression_level(grpc_compression_level level); + + grpc_compression_algorithm get_compression_algorithm() const { + return compression_algorithm_; + } + void set_compression_algorithm(grpc_compression_algorithm algorithm); + private: friend class ::grpc::Server; template <class W, class R> @@ -142,6 +153,9 @@ class ServerContext { std::multimap<grpc::string, grpc::string> client_metadata_; std::multimap<grpc::string, grpc::string> initial_metadata_; std::multimap<grpc::string, grpc::string> trailing_metadata_; + + grpc_compression_level compression_level_; + grpc_compression_algorithm compression_algorithm_; }; } // namespace grpc diff --git a/include/grpc/compression.h b/include/grpc/compression.h index 1cff5d2d7e1be756005e70b6e75b07abd8fd7b81..dd7e1d0a125bd9e0e05b525abc44be981e5f5681 100644 --- a/include/grpc/compression.h +++ b/include/grpc/compression.h @@ -34,6 +34,10 @@ #ifndef GRPC_COMPRESSION_H #define GRPC_COMPRESSION_H +#ifdef __cplusplus +extern "C" { +#endif + /** To be used in channel arguments */ #define GRPC_COMPRESSION_LEVEL_ARG "grpc.compression_level" @@ -76,4 +80,8 @@ grpc_compression_level grpc_compression_level_for_algorithm( grpc_compression_algorithm grpc_compression_algorithm_for_level( grpc_compression_level level); +#ifdef __cplusplus +} +#endif + #endif /* GRPC_COMPRESSION_H */ diff --git a/src/core/channel/compress_filter.c b/src/core/channel/compress_filter.c index f5fe87d6b83d226f5b70bd2b13a8a92ba7cf7360..6100a90668e5bb14ed153dfbade1aa64f9119572 100644 --- a/src/core/channel/compress_filter.c +++ b/src/core/channel/compress_filter.c @@ -50,7 +50,8 @@ typedef struct call_data { } call_data; typedef struct channel_data { - grpc_mdstr *mdstr_compression_algorithm_key; + grpc_mdstr *mdstr_request_compression_algorithm_key; + grpc_mdstr *mdstr_outgoing_compression_algorithm_key; grpc_mdelem *mdelem_compression_algorithms[GRPC_COMPRESS_ALGORITHMS_COUNT]; grpc_compression_algorithm default_compression_algorithm; } channel_data; @@ -72,14 +73,14 @@ static int compress_send_sb(grpc_compression_algorithm algorithm, } /** For each \a md element from the incoming metadata, filter out the entry for - * "grpc-compression-algorithm", using its value to populate the call data's + * "grpc-encoding", using its value to populate the call data's * compression_algorithm field. */ static grpc_mdelem* compression_md_filter(void *user_data, grpc_mdelem *md) { grpc_call_element *elem = user_data; call_data *calld = elem->call_data; channel_data *channeld = elem->channel_data; - if (md->key == channeld->mdstr_compression_algorithm_key) { + if (md->key == channeld->mdstr_request_compression_algorithm_key) { const char *md_c_str = grpc_mdstr_as_c_string(md->value); if (!grpc_compression_algorithm_parse(md_c_str, &calld->compression_algorithm)) { @@ -184,7 +185,6 @@ static void process_send_ops(grpc_call_element *elem, break; case GRPC_OP_SLICE: if (did_compress) { - gpr_slice_unref(sop->data.slice); if (j < calld->slices.count) { sop->data.slice = gpr_slice_ref(calld->slices.slices[j++]); } @@ -259,7 +259,10 @@ static void init_channel_elem(grpc_channel_element *elem, channeld->default_compression_algorithm = grpc_compression_algorithm_for_level(clevel); - channeld->mdstr_compression_algorithm_key = + channeld->mdstr_request_compression_algorithm_key = + grpc_mdstr_from_string(mdctx, GRPC_COMPRESS_REQUEST_ALGORITHM_KEY); + + channeld->mdstr_outgoing_compression_algorithm_key = grpc_mdstr_from_string(mdctx, "grpc-encoding"); for (algo_idx = 0; algo_idx < GRPC_COMPRESS_ALGORITHMS_COUNT; ++algo_idx) { @@ -267,7 +270,8 @@ static void init_channel_elem(grpc_channel_element *elem, GPR_ASSERT(grpc_compression_algorithm_name(algo_idx, &algorith_name) != 0); channeld->mdelem_compression_algorithms[algo_idx] = grpc_mdelem_from_metadata_strings( - mdctx, grpc_mdstr_ref(channeld->mdstr_compression_algorithm_key), + mdctx, + grpc_mdstr_ref(channeld->mdstr_outgoing_compression_algorithm_key), grpc_mdstr_from_string(mdctx, algorith_name)); } @@ -283,7 +287,8 @@ static void destroy_channel_elem(grpc_channel_element *elem) { channel_data *channeld = elem->channel_data; grpc_compression_algorithm algo_idx; - grpc_mdstr_unref(channeld->mdstr_compression_algorithm_key); + grpc_mdstr_unref(channeld->mdstr_request_compression_algorithm_key); + grpc_mdstr_unref(channeld->mdstr_outgoing_compression_algorithm_key); for (algo_idx = 0; algo_idx < GRPC_COMPRESS_ALGORITHMS_COUNT; ++algo_idx) { grpc_mdelem_unref(channeld->mdelem_compression_algorithms[algo_idx]); diff --git a/src/core/channel/compress_filter.h b/src/core/channel/compress_filter.h index ea667969e15ef365ceb8c60fab155d2797665dd2..3a196eb7bf0e3a3abd490a3a203c823d58279b3a 100644 --- a/src/core/channel/compress_filter.h +++ b/src/core/channel/compress_filter.h @@ -36,6 +36,8 @@ #include "src/core/channel/channel_stack.h" +#define GRPC_COMPRESS_REQUEST_ALGORITHM_KEY "internal:grpc-encoding-request" + /** Message-level compression filter. * * See <grpc/compression.h> for the available compression levels. diff --git a/src/core/surface/call.c b/src/core/surface/call.c index 37dadecb35bd669f4fdfe4b730148841b9d6c885..5f489c0f4ebbcd29fbe32d18b2ad29cd6c928c3c 100644 --- a/src/core/surface/call.c +++ b/src/core/surface/call.c @@ -1243,7 +1243,7 @@ static void recv_metadata(grpc_call *call, grpc_metadata_batch *md) { } else if (key == grpc_channel_get_message_string(call->channel)) { set_status_details(call, STATUS_FROM_WIRE, grpc_mdstr_ref(md->value)); } else if (key == - grpc_channel_get_compresssion_algorithm_string(call->channel)) { + grpc_channel_get_compression_algorithm_string(call->channel)) { set_compression_algorithm(call, decode_compression(md)); } else { dest = &call->buffered_metadata[is_trailing]; diff --git a/src/core/surface/channel.c b/src/core/surface/channel.c index d3dcb2255f8ea48881ea11b55f75348cc1b7bb04..cab99e71d38ae7118c806312d856d3ccdd49b46d 100644 --- a/src/core/surface/channel.c +++ b/src/core/surface/channel.c @@ -273,7 +273,7 @@ grpc_mdstr *grpc_channel_get_status_string(grpc_channel *channel) { return channel->grpc_status_string; } -grpc_mdstr *grpc_channel_get_compresssion_algorithm_string( +grpc_mdstr *grpc_channel_get_compression_algorithm_string( grpc_channel *channel) { return channel->grpc_compression_algorithm_string; } diff --git a/src/core/surface/channel.h b/src/core/surface/channel.h index 8d0fe812ce084061a9e9a243cb91ec208e424f4e..66924ad72c755dcf44e93f9852ae3b75b06f3be8 100644 --- a/src/core/surface/channel.h +++ b/src/core/surface/channel.h @@ -53,7 +53,7 @@ grpc_mdctx *grpc_channel_get_metadata_context(grpc_channel *channel); grpc_mdelem *grpc_channel_get_reffed_status_elem(grpc_channel *channel, int status_code); grpc_mdstr *grpc_channel_get_status_string(grpc_channel *channel); -grpc_mdstr *grpc_channel_get_compresssion_algorithm_string( +grpc_mdstr *grpc_channel_get_compression_algorithm_string( grpc_channel *channel); grpc_mdstr *grpc_channel_get_message_string(grpc_channel *channel); gpr_uint32 grpc_channel_get_max_message_length(grpc_channel *channel); diff --git a/src/core/surface/secure_channel_create.c b/src/core/surface/secure_channel_create.c index be46c544274620ad88a53a29d0f398a79f096d43..cfa869ec716fd0e9912e3101338c135d5a35b9f4 100644 --- a/src/core/surface/secure_channel_create.c +++ b/src/core/surface/secure_channel_create.c @@ -244,10 +244,7 @@ grpc_channel *grpc_secure_channel_create(grpc_credentials *creds, if (grpc_channel_args_is_census_enabled(args)) { filters[n++] = &grpc_client_census_filter; } */ - if (grpc_channel_args_get_compression_level(args) > - GRPC_COMPRESS_LEVEL_NONE) { - filters[n++] = &grpc_compress_filter; - } + filters[n++] = &grpc_compress_filter; filters[n++] = &grpc_client_channel_filter; GPR_ASSERT(n <= MAX_FILTERS); channel = grpc_channel_create_from_filters(filters, n, args_copy, mdctx, 1); diff --git a/src/cpp/client/client_context.cc b/src/cpp/client/client_context.cc index 72cdd49d195f99616f865c40b1b94ff1c819ab35..0eba554e334c4e9f6311b06c362a272215599727 100644 --- a/src/cpp/client/client_context.cc +++ b/src/cpp/client/client_context.cc @@ -34,9 +34,12 @@ #include <grpc++/client_context.h> #include <grpc/grpc.h> +#include <grpc/support/string_util.h> #include <grpc++/credentials.h> #include <grpc++/time.h> +#include "src/core/channel/compress_filter.h" + namespace grpc { ClientContext::ClientContext() @@ -75,6 +78,24 @@ void ClientContext::set_call(grpc_call* call, } } +void ClientContext::set_compression_level(grpc_compression_level level) { + const grpc_compression_algorithm algorithm_for_level = + grpc_compression_algorithm_for_level(level); + set_compression_algorithm(algorithm_for_level); +} + +void ClientContext::set_compression_algorithm( + grpc_compression_algorithm algorithm) { + char* algorithm_name = NULL; + if (!grpc_compression_algorithm_name(algorithm, &algorithm_name)) { + gpr_log(GPR_ERROR, "Name for compression algorithm '%d' unknown.", + algorithm); + abort(); + } + GPR_ASSERT(algorithm_name != NULL); + AddMetadata(GRPC_COMPRESS_REQUEST_ALGORITHM_KEY, algorithm_name); +} + void ClientContext::TryCancel() { if (call_) { grpc_call_cancel(call_); diff --git a/src/cpp/proto/proto_utils.cc b/src/cpp/proto/proto_utils.cc index 268e4f6d1fe92735ed2895ecd2c2d396df3e89d2..337f6801292fdd93a7c51dcb13d9fab8f39b6129 100644 --- a/src/cpp/proto/proto_utils.cc +++ b/src/cpp/proto/proto_utils.cc @@ -103,7 +103,9 @@ class GrpcBufferReader GRPC_FINAL : byte_count_(0), backup_count_(0) { grpc_byte_buffer_reader_init(&reader_, buffer); } - ~GrpcBufferReader() GRPC_OVERRIDE {} + ~GrpcBufferReader() GRPC_OVERRIDE { + grpc_byte_buffer_reader_destroy(&reader_); + } bool Next(const void** data, int* size) GRPC_OVERRIDE { if (backup_count_ > 0) { diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index 699895a3cfcf7e253b9dfe274e99b6a7b1f17fcf..087e28d33ab5725e243fd050f16af17fb2ad1bac 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -39,6 +39,8 @@ #include <grpc++/impl/sync.h> #include <grpc++/time.h> +#include "src/core/channel/compress_filter.h" + namespace grpc { // CompletionOp @@ -146,4 +148,22 @@ bool ServerContext::IsCancelled() { return completion_op_ && completion_op_->CheckCancelled(cq_); } +void ServerContext::set_compression_level(grpc_compression_level level) { + const grpc_compression_algorithm algorithm_for_level = + grpc_compression_algorithm_for_level(level); + set_compression_algorithm(algorithm_for_level); +} + +void ServerContext::set_compression_algorithm( + grpc_compression_algorithm algorithm) { + char* algorithm_name = NULL; + if (!grpc_compression_algorithm_name(algorithm, &algorithm_name)) { + gpr_log(GPR_ERROR, "Name for compression algorithm '%d' unknown.", + algorithm); + abort(); + } + GPR_ASSERT(algorithm_name != NULL); + AddInitialMetadata(GRPC_COMPRESS_REQUEST_ALGORITHM_KEY, algorithm_name); +} + } // namespace grpc diff --git a/test/core/end2end/tests/request_with_compressed_payload.c b/test/core/end2end/tests/request_with_compressed_payload.c index ca16bc7d521d9f2f9e7364c027690ad60bede54a..a6057457c4a5b40d983e5e9f232b7127781ed107 100644 --- a/test/core/end2end/tests/request_with_compressed_payload.c +++ b/test/core/end2end/tests/request_with_compressed_payload.c @@ -45,6 +45,7 @@ #include "test/core/end2end/cq_verifier.h" #include "src/core/channel/channel_args.h" +#include "src/core/channel/compress_filter.h" enum { TIMEOUT = 200000 }; @@ -240,6 +241,7 @@ static void request_with_payload_template( cq_verifier_destroy(cqv); + gpr_slice_unref(request_payload_slice); grpc_byte_buffer_destroy(request_payload); grpc_byte_buffer_destroy(request_payload_recv); @@ -279,13 +281,13 @@ static void test_invoke_request_with_compressed_payload_md_override( grpc_metadata gzip_compression_override; grpc_metadata none_compression_override; - gzip_compression_override.key = "grpc-encoding"; + gzip_compression_override.key = GRPC_COMPRESS_REQUEST_ALGORITHM_KEY; gzip_compression_override.value = "gzip"; gzip_compression_override.value_length = 4; memset(&gzip_compression_override.internal_data, 0, sizeof(gzip_compression_override.internal_data)); - none_compression_override.key = "grpc-encoding"; + none_compression_override.key = GRPC_COMPRESS_REQUEST_ALGORITHM_KEY; none_compression_override.value = "none"; none_compression_override.value_length = 4; memset(&none_compression_override.internal_data, 0, diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index 45ba8b0878b4f64d1bee61a65aaa57dc31a44b59..49070a7df1a7d7af8820313c380fbb1b5134cf36 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -226,10 +226,11 @@ static void SendRpc(grpc::cpp::test::util::TestService::Stub* stub, int num_rpcs) { EchoRequest request; EchoResponse response; - request.set_message("Hello"); + request.set_message("Hello hello hello hello"); for (int i = 0; i < num_rpcs; ++i) { ClientContext context; + context.set_compression_level(GRPC_COMPRESS_LEVEL_HIGH); Status s = stub->Echo(&context, request, &response); EXPECT_EQ(response.message(), request.message()); EXPECT_TRUE(s.ok()); diff --git a/test/cpp/end2end/generic_end2end_test.cc b/test/cpp/end2end/generic_end2end_test.cc index b9d47b32de53131e2b50ee4f1f5c00af4497f983..e9d86cc9f758975f7e361daa52adeb9ad3f68a7a 100644 --- a/test/cpp/end2end/generic_end2end_test.cc +++ b/test/cpp/end2end/generic_end2end_test.cc @@ -227,6 +227,7 @@ TEST_F(GenericEnd2endTest, SimpleBidiStreaming) { GenericServerContext srv_ctx; GenericServerAsyncReaderWriter srv_stream(&srv_ctx); + cli_ctx.set_compression_level(GRPC_COMPRESS_LEVEL_HIGH); send_request.set_message("Hello"); std::unique_ptr<GenericClientAsyncReaderWriter> cli_stream = generic_stub_->Call(&cli_ctx, kMethodName, &cli_cq_, tag(1));