diff --git a/include/grpc++/server_context.h b/include/grpc++/server_context.h index cf2732b33d250df2e8cf7b999b84b8a2e1d0b433..03d2f0d128d0c652e98e38da4803c77289a41bf8 100644 --- a/include/grpc++/server_context.h +++ b/include/grpc++/server_context.h @@ -125,6 +125,11 @@ class ServerContext { const struct census_context* census_context() const; + // Async only. Has to be called before the rpc starts. + // Returns the tag in completion queue when the rpc finishes. + // IsCancelled() can then be called to check whether the rpc was cancelled. + void AsyncNotifyWhenDone(void* tag) { async_notify_when_done_tag_ = tag; } + private: friend class ::grpc::testing::InteropContextInspector; friend class ::grpc::Server; @@ -165,6 +170,7 @@ class ServerContext { void set_call(grpc_call* call); CompletionOp* completion_op_; + void* async_notify_when_done_tag_; gpr_timespec deadline_; grpc_call* call_; diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index cf19556e7a08ed974f9e2f3cc094572dd493bbb3..0d09519b28d27671d2df3664b5b60b9eb7d8abc8 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -57,9 +57,12 @@ class ServerContext::CompletionOp GRPC_FINAL : public CallOpSetInterface { bool CheckCancelled(CompletionQueue* cq); + void set_tag(void* tag) { tag_ = tag; } + void Unref(); private: + void* tag_; grpc::mutex mu_; int refs_; bool finalized_; @@ -90,18 +93,24 @@ void ServerContext::CompletionOp::FillOps(grpc_op* ops, size_t* nops) { bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { grpc::unique_lock<grpc::mutex> lock(mu_); finalized_ = true; + bool ret = false; + if (tag_) { + *tag = tag_; + ret = true; + } if (!*status) cancelled_ = 1; if (--refs_ == 0) { lock.unlock(); delete this; } - return false; + return ret; } // ServerContext body ServerContext::ServerContext() : completion_op_(nullptr), + async_notify_when_done_tag_(nullptr), call_(nullptr), cq_(nullptr), sent_initial_metadata_(false) {} @@ -109,6 +118,7 @@ ServerContext::ServerContext() ServerContext::ServerContext(gpr_timespec deadline, grpc_metadata* metadata, size_t metadata_count) : completion_op_(nullptr), + async_notify_when_done_tag_(nullptr), deadline_(deadline), call_(nullptr), cq_(nullptr), @@ -133,6 +143,7 @@ ServerContext::~ServerContext() { void ServerContext::BeginCompletionOp(Call* call) { GPR_ASSERT(!completion_op_); completion_op_ = new CompletionOp(); + completion_op_->set_tag(async_notify_when_done_tag_); call->PerformOps(completion_op_); } diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc index b95bdf6b9b458c703a4cc3997a8b0ea93f52e749..9b53bdc9990c99e58e6097fc775075e413cab01e 100644 --- a/test/cpp/end2end/async_end2end_test.cc +++ b/test/cpp/end2end/async_end2end_test.cc @@ -592,6 +592,80 @@ TEST_F(AsyncEnd2endTest, MetadataRpc) { EXPECT_EQ(meta6.second, server_trailing_metadata.find(meta6.first)->second); EXPECT_GE(server_trailing_metadata.size(), static_cast<size_t>(2)); } + +// Server uses AsyncNotifyWhenDone API to check for cancellation +TEST_F(AsyncEnd2endTest, ServerCheckCancellation) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message("Hello"); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + + srv_ctx.AsyncNotifyWhenDone(tag(5)); + service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + cli_ctx.TryCancel(); + Verifier().Expect(5, true).Verify(cq_.get()); + EXPECT_TRUE(srv_ctx.IsCancelled()); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + Verifier().Expect(4, false).Verify(cq_.get()); + + EXPECT_EQ(StatusCode::CANCELLED, recv_status.error_code()); +} + +// Server uses AsyncNotifyWhenDone API to check for normal finish +TEST_F(AsyncEnd2endTest, ServerCheckDone) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message("Hello"); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse> > response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + + srv_ctx.AsyncNotifyWhenDone(tag(5)); + service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Verify(cq_.get()); + Verifier().Expect(5, true).Verify(cq_.get()); + EXPECT_FALSE(srv_ctx.IsCancelled()); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + Verifier().Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + } // namespace } // namespace testing } // namespace grpc