Skip to content
Snippets Groups Projects
Commit a831651a authored by Vijay Pai's avatar Vijay Pai
Browse files

Unify and make consistent the per-thread shutdown process

parent f782465f
No related branches found
No related tags found
No related merge requests found
...@@ -174,6 +174,7 @@ class AsyncClient : public ClientImpl<StubType, RequestType> { ...@@ -174,6 +174,7 @@ class AsyncClient : public ClientImpl<StubType, RequestType> {
for (int i = 0; i < num_async_threads_; i++) { for (int i = 0; i < num_async_threads_; i++) {
cli_cqs_.emplace_back(new CompletionQueue); cli_cqs_.emplace_back(new CompletionQueue);
next_issuers_.emplace_back(NextIssuer(i)); next_issuers_.emplace_back(NextIssuer(i));
shutdown_state_.emplace_back(new PerThreadShutdownState());
} }
using namespace std::placeholders; using namespace std::placeholders;
...@@ -189,7 +190,21 @@ class AsyncClient : public ClientImpl<StubType, RequestType> { ...@@ -189,7 +190,21 @@ class AsyncClient : public ClientImpl<StubType, RequestType> {
} }
} }
// Centralized per-thread shutdown: flag every worker, wake it via CQ
// shutdown, join, then drain and free any events still in flight.
virtual ~AsyncClient() {
  // Mark each worker's shutdown flag under that worker's own lock so a
  // thread mid-iteration observes the flag before touching its context.
  for (auto& state : shutdown_state_) {
    std::lock_guard<std::mutex> lock(state->mutex);
    state->shutdown = true;
  }
  // Shutting down the completion queues unblocks threads parked in Next().
  for (auto& cq : cli_cqs_) {
    cq->Shutdown();
  }
  this->EndThreads();  // Need "this->" for resolution
  // With all threads joined, drain leftover completions and reclaim the
  // RPC contexts they carry.
  for (auto& cq : cli_cqs_) {
    void* got_tag;
    bool ok;
    while (cq->Next(&got_tag, &ok)) {
      delete ClientRpcContext::detag(got_tag);
    }
  }
}
bool ThreadFunc(HistogramEntry* entry, bool ThreadFunc(HistogramEntry* entry,
...@@ -200,7 +215,12 @@ class AsyncClient : public ClientImpl<StubType, RequestType> { ...@@ -200,7 +215,12 @@ class AsyncClient : public ClientImpl<StubType, RequestType> {
if (cli_cqs_[thread_idx]->Next(&got_tag, &ok)) { if (cli_cqs_[thread_idx]->Next(&got_tag, &ok)) {
// Got a regular event, so process it // Got a regular event, so process it
ClientRpcContext* ctx = ClientRpcContext::detag(got_tag); ClientRpcContext* ctx = ClientRpcContext::detag(got_tag);
if (!ctx->RunNextState(ok, entry)) { // Proceed while holding a lock to make sure that
// this thread isn't supposed to shut down
std::lock_guard<std::mutex> l(shutdown_state_[thread_idx]->mutex);
if (shutdown_state_[thread_idx]->shutdown) {
return true;
} else if (!ctx->RunNextState(ok, entry)) {
// The RPC and callback are done, so clone the ctx // The RPC and callback are done, so clone the ctx
// and kickstart the new one // and kickstart the new one
auto clone = ctx->StartNewClone(); auto clone = ctx->StartNewClone();
...@@ -217,22 +237,13 @@ class AsyncClient : public ClientImpl<StubType, RequestType> { ...@@ -217,22 +237,13 @@ class AsyncClient : public ClientImpl<StubType, RequestType> {
protected: protected:
const int num_async_threads_; const int num_async_threads_;
void ShutdownCQs() {
for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) {
(*cq)->Shutdown();
}
}
void FinalShutdownCQs() {
for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) {
void* got_tag;
bool ok;
while ((*cq)->Next(&got_tag, &ok)) {
delete ClientRpcContext::detag(got_tag);
}
}
}
private: private:
// Per-worker flag telling an async client thread to stop processing.
// The flag is always read and written while holding `mutex`.
struct PerThreadShutdownState {
  mutable std::mutex mutex;  // mutable: lockable through a const reference
  bool shutdown = false;     // becomes true once teardown has begun
};
int NumThreads(const ClientConfig& config) { int NumThreads(const ClientConfig& config) {
int num_threads = config.async_client_threads(); int num_threads = config.async_client_threads();
if (num_threads <= 0) { // Use dynamic sizing if (num_threads <= 0) { // Use dynamic sizing
...@@ -241,9 +252,9 @@ class AsyncClient : public ClientImpl<StubType, RequestType> { ...@@ -241,9 +252,9 @@ class AsyncClient : public ClientImpl<StubType, RequestType> {
} }
return num_threads; return num_threads;
} }
std::vector<std::unique_ptr<CompletionQueue>> cli_cqs_; std::vector<std::unique_ptr<CompletionQueue>> cli_cqs_;
std::vector<std::function<gpr_timespec()>> next_issuers_; std::vector<std::function<gpr_timespec()>> next_issuers_;
std::vector<std::unique_ptr<PerThreadShutdownState>> shutdown_state_;
}; };
static std::unique_ptr<BenchmarkService::Stub> BenchmarkStubCreator( static std::unique_ptr<BenchmarkService::Stub> BenchmarkStubCreator(
...@@ -259,10 +270,7 @@ class AsyncUnaryClient GRPC_FINAL ...@@ -259,10 +270,7 @@ class AsyncUnaryClient GRPC_FINAL
config, SetupCtx, BenchmarkStubCreator) { config, SetupCtx, BenchmarkStubCreator) {
StartThreads(num_async_threads_); StartThreads(num_async_threads_);
} }
// Empty: CQ shutdown, thread joining, and event draining are all
// performed by the ~AsyncClient() base destructor.
~AsyncUnaryClient() GRPC_OVERRIDE {}
private: private:
static void CheckDone(grpc::Status s, SimpleResponse* response) {} static void CheckDone(grpc::Status s, SimpleResponse* response) {}
...@@ -391,10 +399,7 @@ class AsyncStreamingClient GRPC_FINAL ...@@ -391,10 +399,7 @@ class AsyncStreamingClient GRPC_FINAL
StartThreads(num_async_threads_); StartThreads(num_async_threads_);
} }
// Empty: the unified shutdown sequence lives in ~AsyncClient(), so this
// subclass has nothing left to tear down.
~AsyncStreamingClient() GRPC_OVERRIDE {}
private: private:
static void CheckDone(grpc::Status s, SimpleResponse* response) {} static void CheckDone(grpc::Status s, SimpleResponse* response) {}
...@@ -530,10 +535,7 @@ class GenericAsyncStreamingClient GRPC_FINAL ...@@ -530,10 +535,7 @@ class GenericAsyncStreamingClient GRPC_FINAL
StartThreads(num_async_threads_); StartThreads(num_async_threads_);
} }
// Empty: teardown is handled once, in the ~AsyncClient() base destructor.
~GenericAsyncStreamingClient() GRPC_OVERRIDE {}
private: private:
static void CheckDone(grpc::Status s, ByteBuffer* response) {} static void CheckDone(grpc::Status s, ByteBuffer* response) {}
......
...@@ -123,21 +123,22 @@ class AsyncQpsServerTest : public Server { ...@@ -123,21 +123,22 @@ class AsyncQpsServerTest : public Server {
for (int i = 0; i < num_threads; i++) { for (int i = 0; i < num_threads; i++) {
shutdown_state_.emplace_back(new PerThreadShutdownState()); shutdown_state_.emplace_back(new PerThreadShutdownState());
}
for (int i = 0; i < num_threads; i++) {
threads_.emplace_back(&AsyncQpsServerTest::ThreadFunc, this, i); threads_.emplace_back(&AsyncQpsServerTest::ThreadFunc, this, i);
} }
} }
~AsyncQpsServerTest() { ~AsyncQpsServerTest() {
for (auto ss = shutdown_state_.begin(); ss != shutdown_state_.end(); ++ss) { for (auto ss = shutdown_state_.begin(); ss != shutdown_state_.end(); ++ss) {
(*ss)->set_shutdown(); std::lock_guard<std::mutex> lock((*ss)->mutex);
(*ss)->shutdown = true;
} }
server_->Shutdown(); server_->Shutdown();
for (auto cq = srv_cqs_.begin(); cq != srv_cqs_.end(); ++cq) {
(*cq)->Shutdown();
}
for (auto thr = threads_.begin(); thr != threads_.end(); thr++) { for (auto thr = threads_.begin(); thr != threads_.end(); thr++) {
thr->join(); thr->join();
} }
for (auto cq = srv_cqs_.begin(); cq != srv_cqs_.end(); ++cq) { for (auto cq = srv_cqs_.begin(); cq != srv_cqs_.end(); ++cq) {
(*cq)->Shutdown();
bool ok; bool ok;
void *got_tag; void *got_tag;
while ((*cq)->Next(&got_tag, &ok)) while ((*cq)->Next(&got_tag, &ok))
...@@ -150,21 +151,21 @@ class AsyncQpsServerTest : public Server { ...@@ -150,21 +151,21 @@ class AsyncQpsServerTest : public Server {
} }
private: private:
void ThreadFunc(int rank) { void ThreadFunc(int thread_idx) {
// Wait until work is available or we are shutting down // Wait until work is available or we are shutting down
bool ok; bool ok;
void *got_tag; void *got_tag;
while (srv_cqs_[rank]->Next(&got_tag, &ok)) { while (srv_cqs_[thread_idx]->Next(&got_tag, &ok)) {
ServerRpcContext *ctx = detag(got_tag); ServerRpcContext *ctx = detag(got_tag);
// The tag is a pointer to an RPC context to invoke // The tag is a pointer to an RPC context to invoke
// Proceed while holding a lock to make sure that
// this thread isn't supposed to shut down
std::lock_guard<std::mutex> l(shutdown_state_[thread_idx]->mutex);
if (shutdown_state_[thread_idx]->shutdown) { return; }
const bool still_going = ctx->RunNextState(ok); const bool still_going = ctx->RunNextState(ok);
if (!shutdown_state_[rank]->shutdown()) { // if this RPC context is done, refresh it
// this RPC context is done, so refresh it if (!still_going) {
if (!still_going) { ctx->Reset();
ctx->Reset();
}
} else {
return;
} }
} }
return; return;
...@@ -333,24 +334,12 @@ class AsyncQpsServerTest : public Server { ...@@ -333,24 +334,12 @@ class AsyncQpsServerTest : public Server {
ServiceType async_service_; ServiceType async_service_;
std::forward_list<ServerRpcContext *> contexts_; std::forward_list<ServerRpcContext *> contexts_;
// Per-worker flag telling a server polling thread to stop; mirrors the
// client-side struct of the same name. Accessed only under `mutex`.
struct PerThreadShutdownState {
  mutable std::mutex mutex;  // mutable so const instances can still lock
  bool shutdown = false;     // set true by the destructor to stop the thread
};
std::vector<std::unique_ptr<PerThreadShutdownState>> shutdown_state_; std::vector<std::unique_ptr<PerThreadShutdownState>> shutdown_state_;
}; };
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment