diff --git a/include/grpc++/server_context.h b/include/grpc++/server_context.h index 64091a4505ddb3413fab0d433086ed1a034bc952..423ebf2337737dec05bb7d3f31541219c23bb04b 100644 --- a/include/grpc++/server_context.h +++ b/include/grpc++/server_context.h @@ -45,7 +45,7 @@ struct grpc_call; namespace grpc { -template <class R> +template <class W, class R> class ServerAsyncReader; template <class W> class ServerAsyncWriter; diff --git a/include/grpc++/stream.h b/include/grpc++/stream.h index ecc28f62160f986832d1851ec142af816909e880..6ee550bd644d9e8b82a4b7b2e16bdf862f641870 100644 --- a/include/grpc++/stream.h +++ b/include/grpc++/stream.h @@ -615,7 +615,7 @@ class ServerAsyncResponseWriter final : public ServerAsyncStreamingInterface { CallOpBuffer finish_buf_; }; -template <class R> +template <class W, class R> class ServerAsyncReader : public ServerAsyncStreamingInterface, public AsyncReaderInterface<R> { public: @@ -637,18 +637,34 @@ class ServerAsyncReader : public ServerAsyncStreamingInterface, call_.PerformOps(&read_buf_); } - void Finish(const Status& status, void* tag) { + void Finish(const W& msg, const Status& status, void* tag) { finish_buf_.Reset(tag); if (!ctx_->sent_initial_metadata_) { finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_); ctx_->sent_initial_metadata_ = true; } + // The response is dropped if the status is not OK. + if (status.IsOk()) { + finish_buf_.AddSendMessage(msg); + } bool cancelled = false; finish_buf_.AddServerRecvClose(&cancelled); finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } + void FinishWithError(const Status& status, void* tag) { + GPR_ASSERT(!status.IsOk()); + finish_buf_.Reset(tag); + if (!ctx_->sent_initial_metadata_) { + finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_); + ctx_->sent_initial_metadata_ = true; + } + bool cancelled = false; + finish_buf_.AddServerRecvClose(&cancelled); + finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status); + call_.PerformOps(&finish_buf_); + } private: void BindCall(Call *call) override { call_ = *call; } diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc index a34aa4e568ea19c6c8ac24734bdba7ba6801c76e..2a9895e43c4e04906926a01748595a7dac277037 100644 --- a/src/compiler/cpp_generator.cc +++ b/src/compiler/cpp_generator.cc @@ -133,7 +133,7 @@ std::string GetHeaderIncludes(const google::protobuf::FileDescriptor *file) { temp.append("template <class OutMessage> class ClientWriter;\n"); temp.append("template <class InMessage> class ServerReader;\n"); temp.append("template <class OutMessage> class ClientAsyncWriter;\n"); - temp.append("template <class InMessage> class ServerAsyncReader;\n"); + temp.append("template <class OutMessage, class InMessage> class ServerAsyncReader;\n"); } if (HasServerOnlyStreaming(file)) { temp.append("template <class InMessage> class ClientReader;\n"); @@ -267,7 +267,7 @@ void PrintHeaderServerMethodAsync( printer->Print(*vars, "void Request$Method$(" "::grpc::ServerContext* context, " - "::grpc::ServerAsyncReader< $Request$>* reader, " + "::grpc::ServerAsyncReader< $Response$, $Request$>* reader, " "::grpc::CompletionQueue* cq, void *tag);\n"); } else if (ServerOnlyStreaming(method)) { printer->Print(*vars, @@ -538,7 +538,7 @@ void PrintSourceServerAsyncMethod( printer->Print(*vars, "void $Service$::AsyncService::Request$Method$(" "::grpc::ServerContext* context, " - "::grpc::ServerAsyncReader< $Request$>* reader, " + "::grpc::ServerAsyncReader< $Response$, $Request$>* reader, " "::grpc::CompletionQueue* cq, void* tag) {\n"); printer->Print( *vars, diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc index 62c7e40ed2ad892fcd7a4f80f80dc8b7e5c874e4..b85aabf09e11b41f23f17a0f6e7e1b28c0763a06 100644 --- a/test/cpp/end2end/async_end2end_test.cc +++ b/test/cpp/end2end/async_end2end_test.cc @@ -110,6 +110,7 @@ class End2endTest : public ::testing::Test { void client_fail(int i) { verify_ok(&cli_cq_, i, false); } + CompletionQueue cli_cq_; CompletionQueue srv_cq_; std::unique_ptr<grpc::cpp::test::util::TestService::Stub> stub_; @@ -151,6 +152,59 @@ TEST_F(End2endTest, SimpleRpc) { EXPECT_TRUE(recv_status.IsOk()); } +TEST_F(End2endTest, SimpleClientStreaming) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + send_request.set_message("Hello"); + ClientAsyncWriter<EchoRequest>* cli_stream = + stub_->RequestStream(&cli_ctx, &recv_response, &cli_cq_, tag(1)); + + service_.RequestRequestStream( + &srv_ctx, &srv_stream, &srv_cq_, tag(2)); + + server_ok(2); + client_ok(1); + + cli_stream->Write(send_request, tag(3)); + client_ok(3); + + srv_stream.Read(&recv_request, tag(4)); + server_ok(4); + EXPECT_EQ(send_request.message(), recv_request.message()); + + cli_stream->Write(send_request, tag(5)); + client_ok(5); + + srv_stream.Read(&recv_request, tag(6)); + server_ok(6); + + EXPECT_EQ(send_request.message(), recv_request.message()); + cli_stream->WritesDone(tag(7)); + client_ok(7); + + srv_stream.Read(&recv_request, tag(8)); + server_fail(8); + + send_response.set_message(recv_request.message()); + srv_stream.Finish(send_response, Status::OK, tag(9)); + server_ok(9); + + cli_stream->Finish(&recv_status, tag(10)); + client_ok(10); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.IsOk()); +} + TEST_F(End2endTest, SimpleBidiStreaming) { ResetStub();