diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc index 0616cc07eeb6b882ebe906a0fa4ce679b9396426..0179894ec097ad0790c4f236cea8880071e6911b 100644 --- a/test/cpp/end2end/async_end2end_test.cc +++ b/test/cpp/end2end/async_end2end_test.cc @@ -32,6 +32,7 @@ */ #include <memory> +#include <thread> #include <grpc++/channel.h> #include <grpc++/client_context.h> @@ -104,7 +105,10 @@ class Verifier : public PollingCheckRegion { expectations_[tag(i)] = expect_ok; return *this; } - void Verify(CompletionQueue* cq) { + + void Verify(CompletionQueue* cq) { Verify(cq, false); } + + void Verify(CompletionQueue* cq, bool ignore_ok) { GPR_ASSERT(!expectations_.empty()); while (!expectations_.empty()) { bool ok; @@ -122,7 +126,9 @@ class Verifier : public PollingCheckRegion { } auto it = expectations_.find(got_tag); EXPECT_TRUE(it != expectations_.end()); - EXPECT_EQ(it->second, ok); + if (!ignore_ok) { + EXPECT_EQ(it->second, ok); + } expectations_.erase(it); } } @@ -217,7 +223,7 @@ class AsyncEnd2endTest : public ::testing::TestWithParam<bool> { grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); send_request.set_message("Hello"); - std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader( + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), @@ -270,7 +276,7 @@ TEST_P(AsyncEnd2endTest, AsyncNextRpc) { grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); send_request.set_message("Hello"); - std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader( + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); std::chrono::system_clock::time_point time_now( @@ -315,7 +321,7 @@ TEST_P(AsyncEnd2endTest, SimpleClientStreaming) { ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx); send_request.set_message("Hello"); - std::unique_ptr<ClientAsyncWriter<EchoRequest> > cli_stream( + std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream( stub_->AsyncRequestStream(&cli_ctx, &recv_response, cq_.get(), tag(1))); service_.RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), @@ -368,7 +374,7 @@ TEST_P(AsyncEnd2endTest, SimpleServerStreaming) { ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx); send_request.set_message("Hello"); - std::unique_ptr<ClientAsyncReader<EchoResponse> > cli_stream( + std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); service_.RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, @@ -418,7 +424,7 @@ TEST_P(AsyncEnd2endTest, SimpleBidiStreaming) { ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx); send_request.set_message("Hello"); - std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse> > + std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); service_.RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), @@ -476,7 +482,7 @@ TEST_P(AsyncEnd2endTest, ClientInitialMetadataRpc) { cli_ctx.AddMetadata(meta1.first, meta1.second); cli_ctx.AddMetadata(meta2.first, meta2.second); - std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader( + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), @@ -519,7 +525,7 @@ TEST_P(AsyncEnd2endTest, ServerInitialMetadataRpc) { std::pair<grpc::string, grpc::string> meta1("key1", "val1"); std::pair<grpc::string, grpc::string> meta2("key2", "val2"); - std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader( + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), @@ -568,7 +574,7 @@ TEST_P(AsyncEnd2endTest, ServerTrailingMetadataRpc) { std::pair<grpc::string, grpc::string> meta1("key1", "val1"); std::pair<grpc::string, grpc::string> meta2("key2", "val2"); - std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader( + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), @@ -629,7 +635,7 @@ TEST_P(AsyncEnd2endTest, MetadataRpc) { cli_ctx.AddMetadata(meta1.first, meta1.second); cli_ctx.AddMetadata(meta2.first, meta2.second); - std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader( + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), @@ -690,7 +696,7 @@ TEST_P(AsyncEnd2endTest, ServerCheckCancellation) { grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); send_request.set_message("Hello"); - std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader( + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); srv_ctx.AsyncNotifyWhenDone(tag(5)); @@ -725,7 +731,7 @@ TEST_P(AsyncEnd2endTest, ServerCheckDone) { grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); send_request.set_message("Hello"); - std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader( + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); srv_ctx.AsyncNotifyWhenDone(tag(5)); @@ -759,7 +765,7 @@ TEST_P(AsyncEnd2endTest, UnimplementedRpc) { ClientContext cli_ctx; send_request.set_message("Hello"); - std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader( + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( stub->AsyncUnimplemented(&cli_ctx, send_request, cq_.get())); response_reader->Finish(&recv_response, &recv_status, tag(4)); @@ -769,8 +775,339 @@ TEST_P(AsyncEnd2endTest, UnimplementedRpc) { EXPECT_EQ("", recv_status.error_message()); } +class AsyncEnd2endServerTryCancelTest : public AsyncEnd2endTest { + protected: + typedef enum { + DO_NOT_CANCEL = 0, + CANCEL_BEFORE_PROCESSING, + CANCEL_DURING_PROCESSING, + CANCEL_AFTER_PROCESSING + } ServerTryCancelRequestPhase; + + void ServerTryCancel(ServerContext* context) { + EXPECT_FALSE(context->IsCancelled()); + context->TryCancel(); + gpr_log(GPR_INFO, "Server called TryCancel()"); + EXPECT_TRUE(context->IsCancelled()); + } + + void TestClientStreamingServerCancel( + ServerTryCancelRequestPhase server_try_cancel) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + CompletionQueue cli_cq; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + // Initiate the 'RequestStream' call on client + std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream( + stub_->AsyncRequestStream(&cli_ctx, &recv_response, &cli_cq, tag(1))); + Verifier(GetParam()).Expect(1, true).Verify(&cli_cq); + + // On the server, request to be notified of 'RequestStream' calls + // and receive the 'RequestStream' call just made by the client + service_.RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + Verifier(GetParam()).Expect(2, true).Verify(cq_.get()); + + // Client sends 3 messages (tags 3, 4 and 5) + for (int tag_idx = 3; tag_idx <= 5; tag_idx++) { + send_request.set_message("Ping " + std::to_string(tag_idx)); + cli_stream->Write(send_request, tag(tag_idx)); + Verifier(GetParam()).Expect(tag_idx, true).Verify(&cli_cq); + } + cli_stream->WritesDone(tag(6)); + Verifier(GetParam()).Expect(6, true).Verify(&cli_cq); + + bool expected_server_cq_result = true; + bool ignore_cq_result = false; + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + ServerTryCancel(&srv_ctx); + + // Since cancellation is done before server reads any results, we know + // for sure that all cq results will return false from this point forward + expected_server_cq_result = false; + } + + std::thread* server_try_cancel_thd = NULL; + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = new std::thread( + &AsyncEnd2endServerTryCancelTest::ServerTryCancel, this, &srv_ctx); + // Server will cancel the RPC in a parallel thread while reading the + // requests from the client. Since the cancellation can happen at anytime, + // some of the cq results (i.e those until cancellation) might be true but + // its non deterministic. So better to ignore the cq results + ignore_cq_result = true; + } + + // Server reads 3 messages (tags 6, 7 and 8) + for (int tag_idx = 6; tag_idx <= 8; tag_idx++) { + srv_stream.Read(&recv_request, tag(tag_idx)); + Verifier(GetParam()) + .Expect(tag_idx, expected_server_cq_result) + .Verify(cq_.get(), ignore_cq_result); + } + + if (server_try_cancel_thd != NULL) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + ServerTryCancel(&srv_ctx); + } + + // Note: The RPC has been cancelled at this point for sure. So, from this + // point forward, we know that cq results are supposed to return false on + // server. + + send_response.set_message("Pong"); + srv_stream.Finish(send_response, Status::CANCELLED, tag(9)); + Verifier(GetParam()).Expect(9, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(10)); + // TODO: sreek: The expectation here should be true. This seems like a bug. + // Investigating + Verifier(GetParam()).Expect(10, false).Verify(&cli_cq); + EXPECT_FALSE(recv_status.ok()); + EXPECT_EQ(::grpc::StatusCode::CANCELLED, recv_status.error_code()); + } + + void TestServerStreamingServerCancel( + ServerTryCancelRequestPhase server_try_cancel) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + CompletionQueue cli_cq; + ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx); + + send_request.set_message("Ping"); + // Initiate the 'ResponseStream' call on the client + std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, &cli_cq, tag(1))); + Verifier(GetParam()).Expect(1, true).Verify(&cli_cq); + + // On the server, request to be notified of 'ResponseStream' calls and + // receive the call just made by the client + service_.RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + Verifier(GetParam()).Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + bool expected_cq_result = true; + bool ignore_cq_result = false; + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + ServerTryCancel(&srv_ctx); + + // We know for sure that all cq results will be false from this point + // since the server cancelled the RPC + expected_cq_result = false; + } + + std::thread* server_try_cancel_thd = NULL; + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = new std::thread( + &AsyncEnd2endServerTryCancelTest::ServerTryCancel, this, &srv_ctx); + + // Server will cancel the RPC in a parallel thread while writing responses + // to the client. Since the cancellation can happen at anytime, some of + // the cq results (i.e those until cancellation) might be true but + // its non deterministic. So better to ignore the cq results + ignore_cq_result = true; + } + + // Server sends three messages (tags 3, 4 and 5) + for (int tag_idx = 3; tag_idx <= 5; tag_idx++) { + send_response.set_message("Pong " + std::to_string(tag_idx)); + srv_stream.Write(send_response, tag(tag_idx)); + Verifier(GetParam()) + .Expect(tag_idx, expected_cq_result) + .Verify(cq_.get(), ignore_cq_result); + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + ServerTryCancel(&srv_ctx); + } + + // Client attemts to read the three messages + for (int tag_idx = 6; tag_idx <= 8; tag_idx++) { + cli_stream->Read(&recv_response, tag(tag_idx)); + Verifier(GetParam()) + .Expect(tag_idx, expected_cq_result) + .Verify(&cli_cq, ignore_cq_result); + } + + if (server_try_cancel_thd != NULL) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + } + + // Note: At this point, we know that server has cancelled the request for + // sure. + + // Server finishes the stream + srv_stream.Finish(Status::CANCELLED, tag(9)); + Verifier(GetParam()).Expect(9, false).Verify(cq_.get()); + + // Client receives the cancellation + cli_stream->Finish(&recv_status, tag(10)); + Verifier(GetParam()).Expect(10, true).Verify(&cli_cq); + EXPECT_FALSE(recv_status.ok()); + EXPECT_EQ(::grpc::StatusCode::CANCELLED, recv_status.error_code()); + } + + void TestBidiStreamingServerCancel( + ServerTryCancelRequestPhase server_try_cancel) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + CompletionQueue cli_cq; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + // Initiate the call from the client side + std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> + cli_stream(stub_->AsyncBidiStream(&cli_ctx, &cli_cq, tag(1))); + Verifier(GetParam()).Expect(1, true).Verify(&cli_cq); + + // On the server, request to be notified of the 'BidiStream' call and + // receive the call just made by the client + service_.RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + Verifier(GetParam()).Expect(2, true).Verify(cq_.get()); + + send_request.set_message("Ping"); + cli_stream->Write(send_request, tag(3)); + Verifier(GetParam()).Expect(3, true).Verify(&cli_cq); + + bool expected_cq_result = true; + bool ignore_cq_result = false; + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + ServerTryCancel(&srv_ctx); + + // We know for sure that all cq results will be false from this point + // since the server cancelled the RPC + expected_cq_result = false; + } + + std::thread* server_try_cancel_thd = NULL; + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = new std::thread( + &AsyncEnd2endServerTryCancelTest::ServerTryCancel, this, &srv_ctx); + + // Since server is going to cancel the RPC in a parallel thread, some of + // the cq results (i.e those until the cancellation) might be true. Since + // that number is non-deterministic, it is better to ignore the cq results + ignore_cq_result = true; + } + + srv_stream.Read(&recv_request, tag(4)); + Verifier(GetParam()) + .Expect(4, expected_cq_result) + .Verify(cq_.get(), ignore_cq_result); + + send_response.set_message("Pong"); + srv_stream.Write(send_response, tag(5)); + Verifier(GetParam()) + .Expect(5, expected_cq_result) + .Verify(cq_.get(), ignore_cq_result); + + cli_stream->Read(&recv_response, tag(6)); + Verifier(GetParam()) + .Expect(6, expected_cq_result) + .Verify(&cli_cq, ignore_cq_result); + + // This is expected to succeed in all cases + cli_stream->WritesDone(tag(7)); + Verifier(GetParam()).Expect(7, true).Verify(&cli_cq); + + // This is expected to fail in all cases (Either there are no more msgs from + // the client or the RPC is cancelled on the server) + srv_stream.Read(&recv_request, tag(8)); + Verifier(GetParam()).Expect(8, false).Verify(cq_.get()); + + if (server_try_cancel_thd != NULL) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + ServerTryCancel(&srv_ctx); + } + + // At this point, we know that the server cancelled the request for sure + + srv_stream.Finish(Status::CANCELLED, tag(9)); + Verifier(GetParam()).Expect(9, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(10)); + Verifier(GetParam()).Expect(10, true).Verify(&cli_cq); + EXPECT_FALSE(recv_status.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, recv_status.error_code()); + } +}; + +TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelBefore) { + TestClientStreamingServerCancel(CANCEL_BEFORE_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelDuring) { + TestClientStreamingServerCancel(CANCEL_DURING_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelAfter) { + TestClientStreamingServerCancel(CANCEL_AFTER_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelBefore) { + TestServerStreamingServerCancel(CANCEL_BEFORE_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelDuring) { + TestServerStreamingServerCancel(CANCEL_DURING_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelAfter) { + TestServerStreamingServerCancel(CANCEL_AFTER_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelBefore) { + TestBidiStreamingServerCancel(CANCEL_BEFORE_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelDuring) { + TestBidiStreamingServerCancel(CANCEL_DURING_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelAfter) { + TestBidiStreamingServerCancel(CANCEL_AFTER_PROCESSING); +} + INSTANTIATE_TEST_CASE_P(AsyncEnd2end, AsyncEnd2endTest, ::testing::Values(false, true)); +INSTANTIATE_TEST_CASE_P(AsyncEnd2endServerTryCancel, + AsyncEnd2endServerTryCancelTest, + ::testing::Values(false)); } // namespace } // namespace testing diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index 5a414ebc86c114317824906d6babf03d05c4f48c..48838901d27bf11dd27900911462a0fe3689e37e 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -65,6 +65,14 @@ namespace testing { namespace { const char* kServerCancelAfterReads = "cancel_after_reads"; +const char* kServerTryCancelRequest = "server_try_cancel"; +typedef enum { + DO_NOT_CANCEL = 0, + CANCEL_BEFORE_PROCESSING, + CANCEL_DURING_PROCESSING, + CANCEL_AFTER_PROCESSING +} ServerTryCancelRequestPhase; +const int kNumResponseStreamsMsgs = 3; // When echo_deadline is requested, deadline seen in the ServerContext is set in // the response in seconds. @@ -218,8 +226,37 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { explicit TestServiceImpl(const grpc::string& host) : signal_client_(false), host_(new grpc::string(host)) {} + int GetIntValueFromMetadata( + const char* key, + const std::multimap<grpc::string_ref, grpc::string_ref>& metadata, + int default_value) { + if (metadata.find(key) != metadata.end()) { + std::istringstream iss(ToString(metadata.find(key)->second)); + iss >> default_value; + gpr_log(GPR_INFO, "%s : %d", key, default_value); + } + + return default_value; + } + + void ServerTryCancel(ServerContext* context) { + EXPECT_FALSE(context->IsCancelled()); + context->TryCancel(); + gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request"); + EXPECT_TRUE(context->IsCancelled()); + } + Status Echo(ServerContext* context, const EchoRequest* request, EchoResponse* response) GRPC_OVERRIDE { + int server_try_cancel = GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + if (server_try_cancel > DO_NOT_CANCEL) { + // For unary RPC, the actual value of server_try_cancel does not matter + // (as long as it is greater than DO_NOT_CANCEL) + ServerTryCancel(context); + return Status::CANCELLED; + } + response->set_message(request->message()); MaybeEchoDeadline(context, request, response); if (host_) { @@ -283,17 +320,25 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { EchoResponse* response) GRPC_OVERRIDE { EchoRequest request; response->set_message(""); - int cancel_after_reads = 0; - const std::multimap<grpc::string_ref, grpc::string_ref>& - client_initial_metadata = context->client_metadata(); - if (client_initial_metadata.find(kServerCancelAfterReads) != - client_initial_metadata.end()) { - std::istringstream iss(ToString( - client_initial_metadata.find(kServerCancelAfterReads)->second)); - iss >> cancel_after_reads; - gpr_log(GPR_INFO, "cancel_after_reads %d", cancel_after_reads); + int cancel_after_reads = GetIntValueFromMetadata( + kServerCancelAfterReads, context->client_metadata(), 0); + int server_try_cancel = GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + ServerTryCancel(context); + return Status::CANCELLED; + } + + std::thread* server_try_cancel_thd = NULL; + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = + new std::thread(&TestServiceImpl::ServerTryCancel, this, context); } + + int num_msgs_read = 0; while (reader->Read(&request)) { + num_msgs_read++; if (cancel_after_reads == 1) { gpr_log(GPR_INFO, "return cancel status"); return Status::CANCELLED; @@ -302,20 +347,56 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { } response->mutable_message()->append(request.message()); } + gpr_log(GPR_INFO, "Read: %d messages", num_msgs_read); + + if (server_try_cancel_thd != NULL) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + return Status::CANCELLED; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + ServerTryCancel(context); + return Status::CANCELLED; + } + return Status::OK; } - // Return 3 messages. + // Return 'kNumResponseStreamMsgs' messages. // TODO(yangg) make it generic by adding a parameter into EchoRequest Status ResponseStream(ServerContext* context, const EchoRequest* request, ServerWriter<EchoResponse>* writer) GRPC_OVERRIDE { + int server_try_cancel = GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + ServerTryCancel(context); + return Status::CANCELLED; + } + EchoResponse response; - response.set_message(request->message() + "0"); - writer->Write(response); - response.set_message(request->message() + "1"); - writer->Write(response); - response.set_message(request->message() + "2"); - writer->Write(response); + std::thread* server_try_cancel_thd = NULL; + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = + new std::thread(&TestServiceImpl::ServerTryCancel, this, context); + } + + for (int i = 0; i < kNumResponseStreamsMsgs; i++) { + response.set_message(request->message() + std::to_string(i)); + writer->Write(response); + } + + if (server_try_cancel_thd != NULL) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + return Status::CANCELLED; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + ServerTryCancel(context); + return Status::CANCELLED; + } return Status::OK; } @@ -325,11 +406,38 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { GRPC_OVERRIDE { EchoRequest request; EchoResponse response; + + int server_try_cancel = GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + ServerTryCancel(context); + return Status::CANCELLED; + } + + std::thread* server_try_cancel_thd = NULL; + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = + new std::thread(&TestServiceImpl::ServerTryCancel, this, context); + } + while (stream->Read(&request)) { gpr_log(GPR_INFO, "recv msg %s", request.message().c_str()); response.set_message(request.message()); stream->Write(response); } + + if (server_try_cancel_thd != NULL) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + return Status::CANCELLED; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + ServerTryCancel(context); + return Status::CANCELLED; + } + return Status::OK; } @@ -466,6 +574,231 @@ static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs) { } } +// == Tests for cancelling RPC from server side == + +class End2endServerTryCancelTest : public End2endTest { + protected: + // Tests for Client streaming + void TestRequestStreamServerCancel( + ServerTryCancelRequestPhase server_try_cancel, int num_msgs_to_send) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.AddMetadata(kServerTryCancelRequest, + std::to_string(server_try_cancel)); + + auto stream = stub_->RequestStream(&context, &response); + int num_msgs_sent = 0; + while (num_msgs_sent < num_msgs_to_send) { + request.set_message("hello"); + if (!stream->Write(request)) { + break; + } + num_msgs_sent++; + } + gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent); + stream->WritesDone(); + Status s = stream->Finish(); + + switch (server_try_cancel) { + case CANCEL_BEFORE_PROCESSING: + case CANCEL_DURING_PROCESSING: + EXPECT_LE(num_msgs_sent, num_msgs_to_send); + break; + case CANCEL_AFTER_PROCESSING: + EXPECT_EQ(num_msgs_sent, num_msgs_to_send); + break; + default: + gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d", + server_try_cancel); + EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL && + server_try_cancel <= CANCEL_AFTER_PROCESSING); + break; + } + + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + } + + // Test for server streaming + void TestResponseStreamServerCancel( + ServerTryCancelRequestPhase server_try_cancel) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.AddMetadata(kServerTryCancelRequest, + std::to_string(server_try_cancel)); + request.set_message("hello"); + auto stream = stub_->ResponseStream(&context, request); + + int num_msgs_read = 0; + while (num_msgs_read < kNumResponseStreamsMsgs) { + if (!stream->Read(&response)) { + break; + } + EXPECT_EQ(response.message(), + request.message() + std::to_string(num_msgs_read)); + num_msgs_read++; + } + gpr_log(GPR_INFO, "Read %d messages", num_msgs_read); + + Status s = stream->Finish(); + + switch (server_try_cancel) { + case CANCEL_BEFORE_PROCESSING: { + EXPECT_EQ(num_msgs_read, 0); + break; + } + case CANCEL_DURING_PROCESSING: { + EXPECT_LE(num_msgs_read, kNumResponseStreamsMsgs); + break; + } + case CANCEL_AFTER_PROCESSING: { + EXPECT_EQ(num_msgs_read, kNumResponseStreamsMsgs); + } + default: { + gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d", + server_try_cancel); + EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL && + server_try_cancel <= CANCEL_AFTER_PROCESSING); + break; + } + } + + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + } + + void TestBidiStreamServerCancel(ServerTryCancelRequestPhase server_try_cancel, + int num_messages) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.AddMetadata(kServerTryCancelRequest, + std::to_string(server_try_cancel)); + + auto stream = stub_->BidiStream(&context); + + int num_msgs_read = 0; + int num_msgs_sent = 0; + while (num_msgs_sent < num_messages) { + request.set_message("hello " + std::to_string(num_msgs_sent)); + if (!stream->Write(request)) { + break; + } + num_msgs_sent++; + + if (!stream->Read(&response)) { + break; + } + num_msgs_read++; + + EXPECT_EQ(response.message(), request.message()); + } + gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent); + gpr_log(GPR_INFO, "Read %d messages", num_msgs_read); + + stream->WritesDone(); + Status s = stream->Finish(); + + switch (server_try_cancel) { + case CANCEL_BEFORE_PROCESSING: { + EXPECT_EQ(num_msgs_read, 0); + break; + } + case CANCEL_DURING_PROCESSING: { + EXPECT_LE(num_msgs_sent, num_messages); + EXPECT_LE(num_msgs_read, num_msgs_sent); + break; + } + case CANCEL_AFTER_PROCESSING: { + EXPECT_EQ(num_msgs_sent, num_messages); + EXPECT_EQ(num_msgs_read, num_msgs_sent); + } + default: { + gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d", + server_try_cancel); + EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL && + server_try_cancel <= CANCEL_AFTER_PROCESSING); + break; + } + } + + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + } +}; + +TEST_P(End2endServerTryCancelTest, RequestEchoServerCancel) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.AddMetadata(kServerTryCancelRequest, + std::to_string(CANCEL_BEFORE_PROCESSING)); + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); +} + +// Server to cancel before doing reading the request +TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelBeforeReads) { + TestRequestStreamServerCancel(CANCEL_BEFORE_PROCESSING, 1); +} + +// Server to cancel while reading a request from the stream in parallel +TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelDuringRead) { + TestRequestStreamServerCancel(CANCEL_DURING_PROCESSING, 10); +} + +// Server to cancel after reading all the requests but before returning to the +// client +TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelAfterReads) { + TestRequestStreamServerCancel(CANCEL_AFTER_PROCESSING, 4); +} + +// Server to cancel before sending any response messages +TEST_P(End2endServerTryCancelTest, ResponseStreamServerCancelBefore) { + TestResponseStreamServerCancel(CANCEL_BEFORE_PROCESSING); +} + +// Server to cancel while writing a response to the stream in parallel +TEST_P(End2endServerTryCancelTest, ResponseStreamServerCancelDuring) { + TestResponseStreamServerCancel(CANCEL_DURING_PROCESSING); +} + +// Server to cancel after writing all the respones to the stream but before +// returning to the client +TEST_P(End2endServerTryCancelTest, ResponseStreamServerCancelAfter) { + TestResponseStreamServerCancel(CANCEL_AFTER_PROCESSING); +} + +// Server to cancel before reading/writing any requests/responses on the stream +TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelBefore) { + TestBidiStreamServerCancel(CANCEL_BEFORE_PROCESSING, 2); +} + +// Server to cancel while reading/writing requests/responses on the stream in +// parallel +TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelDuring) { + TestBidiStreamServerCancel(CANCEL_DURING_PROCESSING, 10); +} + +// Server to cancel after reading/writing all requests/responses on the stream +// but before returning to the client +TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelAfter) { + TestBidiStreamServerCancel(CANCEL_AFTER_PROCESSING, 5); +} + +// ===== + TEST_P(End2endTest, RequestStreamOneRequest) { ResetStub(); EchoRequest request; @@ -1195,6 +1528,9 @@ INSTANTIATE_TEST_CASE_P(End2end, End2endTest, ::testing::Values(TestScenario(false, false), TestScenario(false, true))); +INSTANTIATE_TEST_CASE_P(End2endServerTryCancel, End2endServerTryCancelTest, + ::testing::Values(TestScenario(false, false))); + INSTANTIATE_TEST_CASE_P(ProxyEnd2end, ProxyEnd2endTest, ::testing::Values(TestScenario(false, false), TestScenario(false, true),