diff --git a/include/grpc++/channel_interface.h b/include/grpc++/channel_interface.h index 10fb9538bcd5f16f3d4b839980330f2c8a5cec0d..4176cded7b6baf7448e9ca6a857eba88c7905368 100644 --- a/include/grpc++/channel_interface.h +++ b/include/grpc++/channel_interface.h @@ -36,6 +36,7 @@ #include <memory> +#include <grpc/grpc.h> #include <grpc++/status.h> #include <grpc++/impl/call.h> @@ -47,7 +48,6 @@ class CallOpBuffer; class ClientContext; class CompletionQueue; class RpcMethod; -class CallInterface; class ChannelInterface : public CallHook, public std::enable_shared_from_this<ChannelInterface> { @@ -57,6 +57,34 @@ class ChannelInterface : public CallHook, virtual void* RegisterMethod(const char* method_name) = 0; virtual Call CreateCall(const RpcMethod& method, ClientContext* context, CompletionQueue* cq) = 0; + + // Get the current channel state. If the channel is in IDLE and try_to_connect + // is set to true, try to connect. + virtual grpc_connectivity_state GetState(bool try_to_connect) = 0; + + // Return the tag on cq when the channel state is changed or deadline expires. + // GetState needs to called to get the current state. + template <typename T> + void NotifyOnStateChange(grpc_connectivity_state last_observed, T deadline, + CompletionQueue* cq, void* tag) { + TimePoint<T> deadline_tp(deadline); + NotifyOnStateChangeImpl(last_observed, deadline_tp.raw_time(), cq, tag); + } + + // Blocking wait for channel state change or deadline expiration. + // GetState needs to called to get the current state. + template <typename T> + bool WaitForStateChange(grpc_connectivity_state last_observed, T deadline) { + TimePoint<T> deadline_tp(deadline); + return WaitForStateChangeImpl(last_observed, deadline_tp.raw_time()); + } + + private: + virtual void NotifyOnStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline, + CompletionQueue* cq, void* tag) = 0; + virtual bool WaitForStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline) = 0; }; } // namespace grpc diff --git a/src/cpp/client/channel.cc b/src/cpp/client/channel.cc index 5f54e7fcc17cec716733eec7a954058e977661cb..af7366eb01cab2ec974a92a1c2a8c3d94e4cf30b 100644 --- a/src/cpp/client/channel.cc +++ b/src/cpp/client/channel.cc @@ -48,6 +48,7 @@ #include <grpc++/impl/call.h> #include <grpc++/impl/rpc_method.h> #include <grpc++/status.h> +#include <grpc++/time.h> namespace grpc { @@ -95,4 +96,43 @@ void* Channel::RegisterMethod(const char* method) { host_.empty() ? NULL : host_.c_str()); } +grpc_connectivity_state Channel::GetState(bool try_to_connect) { + return grpc_channel_check_connectivity_state(c_channel_, try_to_connect); +} + +namespace { +class TagSaver GRPC_FINAL : public CompletionQueueTag { + public: + explicit TagSaver(void* tag) : tag_(tag) {} + ~TagSaver() GRPC_OVERRIDE {} + bool FinalizeResult(void** tag, bool* status) GRPC_OVERRIDE { + *tag = tag_; + delete this; + return true; + } + private: + void* tag_; +}; + +} // namespace + +void Channel::NotifyOnStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline, + CompletionQueue* cq, void* tag) { + TagSaver* tag_saver = new TagSaver(tag); + grpc_channel_watch_connectivity_state(c_channel_, last_observed, deadline, + cq->cq(), tag_saver); +} + +bool Channel::WaitForStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline) { + CompletionQueue cq; + bool ok = false; + void* tag = NULL; + NotifyOnStateChangeImpl(last_observed, deadline, &cq, NULL); + cq.Next(&tag, &ok); + GPR_ASSERT(tag == NULL); + return ok; +} + } // namespace grpc diff --git a/src/cpp/client/channel.h b/src/cpp/client/channel.h index 8660146856cc951e0c056d69b637bbb4b5058bc0..cb8e8d98d22549e868f3ce42e1e4a2c008f0ce55 100644 --- a/src/cpp/client/channel.h +++ b/src/cpp/client/channel.h @@ -56,13 +56,22 @@ class Channel GRPC_FINAL : public GrpcLibrary, public ChannelInterface { Channel(const grpc::string& host, grpc_channel* c_channel); ~Channel() GRPC_OVERRIDE; - virtual void* RegisterMethod(const char* method) GRPC_OVERRIDE; - virtual Call CreateCall(const RpcMethod& method, ClientContext* context, + void* RegisterMethod(const char* method) GRPC_OVERRIDE; + Call CreateCall(const RpcMethod& method, ClientContext* context, CompletionQueue* cq) GRPC_OVERRIDE; - virtual void PerformOpsOnCall(CallOpSetInterface* ops, + void PerformOpsOnCall(CallOpSetInterface* ops, Call* call) GRPC_OVERRIDE; + grpc_connectivity_state GetState(bool try_to_connect) GRPC_OVERRIDE; + private: + void NotifyOnStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline, CompletionQueue* cq, + void* tag) GRPC_OVERRIDE; + + bool WaitForStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline) GRPC_OVERRIDE; + const grpc::string host_; grpc_channel* const c_channel_; // owned }; diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index 24d417d9e65469a1eec57359c16fe479bb26874e..5f0749daa5c1b1d7e07181a513ff98af8da1abf4 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -869,7 +869,8 @@ TEST_P(End2endTest, HugeResponse) { } namespace { -void ReaderThreadFunc(ClientReaderWriter<EchoRequest, EchoResponse>* stream, gpr_event *ev) { +void ReaderThreadFunc(ClientReaderWriter<EchoRequest, EchoResponse>* stream, + gpr_event *ev) { EchoResponse resp; gpr_event_set(ev, (void*)1); while (stream->Read(&resp)) { @@ -908,6 +909,27 @@ TEST_P(End2endTest, Peer) { EXPECT_TRUE(CheckIsLocalhost(context.peer())); } +TEST_F(End2endTest, ChannelState) { + ResetStub(false); + // Start IDLE + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(false)); + + // Did not ask to connect, no state change. + CompletionQueue cq; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(10); + channel_->NotifyOnStateChange(GRPC_CHANNEL_IDLE, deadline, &cq, NULL); + void* tag; + bool ok = true; + cq.Next(&tag, &ok); + EXPECT_FALSE(ok); + + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(true)); + EXPECT_TRUE(channel_->WaitForStateChange( + GRPC_CHANNEL_IDLE, gpr_inf_future(GPR_CLOCK_REALTIME))); + EXPECT_EQ(GRPC_CHANNEL_CONNECTING, channel_->GetState(false)); +} + INSTANTIATE_TEST_CASE_P(End2end, End2endTest, ::testing::Values(false, true)); } // namespace testing