diff --git a/include/grpc++/server.h b/include/grpc++/server.h index 4784bace1c688f37c928e69277afa80c2d5473ae..a6883e24e68150cf06478160fc8ec6047feb810e 100644 --- a/include/grpc++/server.h +++ b/include/grpc++/server.h @@ -101,40 +101,82 @@ class Server GRPC_FINAL : public GrpcLibrary, class BaseAsyncRequest : public CompletionQueueTag { public: - BaseAsyncRequest(Server* server, - ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - ServerCompletionQueue* notification_cq, void* tag); - - private: + BaseAsyncRequest(Server* server, ServerContext* context, + ServerAsyncStreamingInterface* stream, + CompletionQueue* call_cq, + void* tag); + virtual ~BaseAsyncRequest(); + + bool FinalizeResult(void** tag, bool* status) GRPC_OVERRIDE; + + protected: + void FinalizeMetadata(ServerContext* context); + + Server* const server_; + ServerContext* const context_; + ServerAsyncStreamingInterface* const stream_; + CompletionQueue* const call_cq_; + grpc_call* call_; + grpc_metadata_array initial_metadata_array_; }; class RegisteredAsyncRequest : public BaseAsyncRequest { public: RegisteredAsyncRequest(Server* server, ServerContext* context, - ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - ServerCompletionQueue* notification_cq, void* tag) - : BaseAsyncRequest(server, stream, call_cq, notification_cq, tag) {} + ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, void* tag); + + // uses BaseAsyncRequest::FinalizeResult + + protected: + void IssueRequest(void* registered_method, grpc_byte_buffer** payload, ServerCompletionQueue *notification_cq); }; - class NoPayloadAsyncRequest : public RegisteredAsyncRequest { + class NoPayloadAsyncRequest GRPC_FINAL : public RegisteredAsyncRequest { public: - NoPayloadAsyncRequest(Server* server, ServerContext* context, + NoPayloadAsyncRequest(void* registered_method, Server* server, ServerContext* context, ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) - : RegisteredAsyncRequest(server, context, stream, call_cq, notification_cq, tag) { + : RegisteredAsyncRequest(server, context, stream, call_cq, tag) { + IssueRequest(registered_method, nullptr, notification_cq); } + + // uses RegisteredAsyncRequest::FinalizeResult }; template <class Message> - class PayloadAsyncRequest : public RegisteredAsyncRequest { - PayloadAsyncRequest(Server* server, ServerContext* context, + class PayloadAsyncRequest GRPC_FINAL : public RegisteredAsyncRequest { + public: + PayloadAsyncRequest(void* registered_method, Server* server, ServerContext* context, ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - ServerCompletionQueue* notification_cq, void* tag) - : RegisteredAsyncRequest(server, context, stream, call_cq, notification_cq, tag) { + ServerCompletionQueue* notification_cq, void* tag, Message* request) + : RegisteredAsyncRequest(server, context, stream, call_cq, tag), request_(request) { + IssueRequest(registered_method, &payload_, notification_cq); + } + + bool FinalizeResult(void** tag, bool* status) GRPC_OVERRIDE { + bool serialization_status = *status && payload_ && SerializationTraits<Message>::Deserialize(payload_, request_, server_->max_message_size_).IsOk(); + bool ret = RegisteredAsyncRequest::FinalizeResult(tag, status); + *status = serialization_status && *status; + return ret; } + + private: + grpc_byte_buffer* payload_; + Message* const request_; }; - class GenericAsyncRequest : public BaseAsyncRequest { + class GenericAsyncRequest GRPC_FINAL : public BaseAsyncRequest { + public: + GenericAsyncRequest(Server* server, GenericServerContext* context, + ServerAsyncStreamingInterface* stream, + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, + void* tag); + + bool FinalizeResult(void** tag, bool* status) GRPC_OVERRIDE; + + private: + grpc_call_details call_details_; }; template <class Message> @@ -142,19 +184,25 @@ class Server GRPC_FINAL : public GrpcLibrary, ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, - void* tag, Message *message); + void* tag, Message *message) { + new PayloadAsyncRequest<Message>(registered_method, this, context, stream, call_cq, notification_cq, tag, message); + } void RequestAsyncCall(void* registered_method, ServerContext* context, ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, - void* tag); + void* tag) { + new NoPayloadAsyncRequest(registered_method, this, context, stream, call_cq, notification_cq, tag); + } void RequestAsyncGenericCall(GenericServerContext* context, ServerAsyncStreamingInterface* stream, - CompletionQueue* cq, + CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, - void* tag); + void* tag) { + new GenericAsyncRequest(this, context, stream, call_cq, notification_cq, tag); + } const int max_message_size_; diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc index c08506c97ffcef671d6ac081bde4d2a3cc9570f3..bd97d707a7909a1380eb89e15af5508dee0e8456 100644 --- a/src/cpp/server/server.cc +++ b/src/cpp/server/server.cc @@ -422,6 +422,69 @@ void Server::RequestAsyncGenericCall(GenericServerContext* context, } #endif +Server::BaseAsyncRequest::BaseAsyncRequest(Server* server, ServerContext* context, + ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, void* tag) +: server_(server), context_(context), stream_(stream), call_cq_(call_cq), call_(nullptr) { + memset(&initial_metadata_array_, 0, sizeof(initial_metadata_array_)); +} + +Server::BaseAsyncRequest::~BaseAsyncRequest() { +} + +bool Server::BaseAsyncRequest::FinalizeResult(void** tag, bool* status) { + if (*status) { + for (size_t i = 0; i < initial_metadata_array_.count; i++) { + context_->client_metadata_.insert(std::make_pair( + grpc::string(initial_metadata_array_.metadata[i].key), + grpc::string( + initial_metadata_array_.metadata[i].value, + initial_metadata_array_.metadata[i].value + initial_metadata_array_.metadata[i].value_length))); + } + } + context_->call_ = call_; + context_->cq_ = call_cq_; + Call call(call_, server_, call_cq_, server_->max_message_size_); + if (*status && call_) { + context_->BeginCompletionOp(&call); + } + // just the pointers inside call are copied here + stream_->BindCall(&call); + delete this; + return true; +} + +Server::RegisteredAsyncRequest::RegisteredAsyncRequest(Server* server, ServerContext* context, + ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, void* tag) + : BaseAsyncRequest(server, context, stream, call_cq, tag) {} + + +void Server::RegisteredAsyncRequest::IssueRequest(void* registered_method, grpc_byte_buffer** payload, ServerCompletionQueue *notification_cq) { + grpc_server_request_registered_call( + server_->server_, registered_method, &call_, &context_->deadline_, &initial_metadata_array_, payload, call_cq_->cq(), notification_cq->cq(), this); +} + +Server::GenericAsyncRequest::GenericAsyncRequest(Server* server, GenericServerContext* context, + ServerAsyncStreamingInterface* stream, + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, + void* tag) +: BaseAsyncRequest(server, context, stream, call_cq, tag) { + grpc_call_details_init(&call_details_); + GPR_ASSERT(notification_cq); + GPR_ASSERT(call_cq); + grpc_server_request_call(server->server_, &call_, &call_details_, &initial_metadata_array_, + call_cq->cq(), notification_cq->cq(), this); +} + +bool Server::GenericAsyncRequest::FinalizeResult(void** tag, bool* status) { + // TODO(yangg) remove the copy here. + static_cast<GenericServerContext*>(context_)->method_ = call_details_.method; + static_cast<GenericServerContext*>(context_)->host_ = call_details_.host; + gpr_free(call_details_.method); + gpr_free(call_details_.host); + return BaseAsyncRequest::FinalizeResult(tag, status); +} + void Server::ScheduleCallback() { { grpc::unique_lock<grpc::mutex> lock(mu_);