Skip to content
Snippets Groups Projects
Commit f9329217 authored by Yuchen Zeng's avatar Yuchen Zeng
Browse files

Support client streaming

parent 47f1f9e1
No related branches found
No related tags found
No related merge requests found
......@@ -37,8 +37,6 @@
#include <grpc++/channel.h>
#include <grpc++/client_context.h>
#include <grpc++/completion_queue.h>
#include <grpc++/generic/generic_stub.h>
#include <grpc++/support/byte_buffer.h>
#include <grpc/grpc.h>
#include <grpc/slice.h>
......@@ -50,49 +48,61 @@ namespace {
void* tag(int i) { return (void*)(intptr_t)i; }
} // namespace
enum CliCall::CallStatus : intptr_t { CREATE, PROCESS, FINISH };
Status CliCall::Call(std::shared_ptr<grpc::Channel> channel,
const grpc::string& method, const grpc::string& request,
grpc::string* response,
const OutgoingMetadataContainer& metadata,
IncomingMetadataContainer* server_initial_metadata,
IncomingMetadataContainer* server_trailing_metadata) {
std::unique_ptr<grpc::GenericStub> stub(new grpc::GenericStub(channel));
grpc::ClientContext ctx;
CliCall call(channel, method, metadata);
call.Write(request);
call.WritesDone();
call.Read(response, server_initial_metadata);
return call.Finish(server_trailing_metadata);
}
CliCall::CliCall(std::shared_ptr<grpc::Channel> channel,
const grpc::string& method,
const OutgoingMetadataContainer& metadata)
: stub_(new grpc::GenericStub(channel)) {
if (!metadata.empty()) {
for (OutgoingMetadataContainer::const_iterator iter = metadata.begin();
iter != metadata.end(); ++iter) {
ctx.AddMetadata(iter->first, iter->second);
ctx_.AddMetadata(iter->first, iter->second);
}
}
grpc::CompletionQueue cq;
std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call(
stub->Call(&ctx, method, &cq, tag(1)));
call_ = stub_->Call(&ctx_, method, &cq_, tag(1));
void* got_tag;
bool ok;
cq.Next(&got_tag, &ok);
cq_.Next(&got_tag, &ok);
GPR_ASSERT(ok);
}
void CliCall::Write(const grpc::string& request) {
void* got_tag;
bool ok;
grpc_slice s = grpc_slice_from_copied_string(request.c_str());
grpc::Slice req_slice(s, grpc::Slice::STEAL_REF);
grpc::ByteBuffer send_buffer(&req_slice, 1);
call->Write(send_buffer, tag(2));
cq.Next(&got_tag, &ok);
GPR_ASSERT(ok);
call->WritesDone(tag(3));
cq.Next(&got_tag, &ok);
call_->Write(send_buffer, tag(2));
cq_.Next(&got_tag, &ok);
GPR_ASSERT(ok);
grpc::ByteBuffer recv_buffer;
call->Read(&recv_buffer, tag(4));
cq.Next(&got_tag, &ok);
if (!ok) {
std::cout << "Failed to read response." << std::endl;
}
grpc::Status status;
call->Finish(&status, tag(5));
cq.Next(&got_tag, &ok);
GPR_ASSERT(ok);
if (status.ok()) {
void CliCall::Read(grpc::string* response,
IncomingMetadataContainer* server_initial_metadata) {
void* got_tag;
bool ok;
grpc::ByteBuffer recv_buffer;
call_->Read(&recv_buffer, tag(4));
cq_.Next(&got_tag, &ok);
if (!ok) {
fprintf(stderr, "Failed to read response.");
} else {
std::vector<grpc::Slice> slices;
(void)recv_buffer.Dump(&slices);
......@@ -101,10 +111,33 @@ Status CliCall::Call(std::shared_ptr<grpc::Channel> channel,
response->append(reinterpret_cast<const char*>(slices[i].begin()),
slices[i].size());
}
if (server_initial_metadata) {
*server_initial_metadata = ctx_.GetServerInitialMetadata();
}
}
}
void CliCall::WritesDone() {
void* got_tag;
bool ok;
call_->WritesDone(tag(3));
cq_.Next(&got_tag, &ok);
GPR_ASSERT(ok);
}
Status CliCall::Finish(IncomingMetadataContainer* server_trailing_metadata) {
void* got_tag;
bool ok;
grpc::Status status;
call_->Finish(&status, tag(5));
cq_.Next(&got_tag, &ok);
GPR_ASSERT(ok);
if (server_trailing_metadata) {
*server_trailing_metadata = ctx_.GetServerTrailingMetadata();
}
*server_initial_metadata = ctx.GetServerInitialMetadata();
*server_trailing_metadata = ctx.GetServerTrailingMetadata();
return status;
}
......
......@@ -37,10 +37,15 @@
#include <map>
#include <grpc++/channel.h>
#include <grpc++/completion_queue.h>
#include <grpc++/generic/generic_stub.h>
#include <grpc++/support/status.h>
#include <grpc++/support/string_ref.h>
namespace grpc {
class ClientContext;
namespace testing {
class CliCall final {
......@@ -48,12 +53,32 @@ class CliCall final {
typedef std::multimap<grpc::string, grpc::string> OutgoingMetadataContainer;
typedef std::multimap<grpc::string_ref, grpc::string_ref>
IncomingMetadataContainer;
CliCall(std::shared_ptr<grpc::Channel> channel, const grpc::string& method,
const OutgoingMetadataContainer& metadata);
static Status Call(std::shared_ptr<grpc::Channel> channel,
const grpc::string& method, const grpc::string& request,
grpc::string* response,
const OutgoingMetadataContainer& metadata,
IncomingMetadataContainer* server_initial_metadata,
IncomingMetadataContainer* server_trailing_metadata);
void Write(const grpc::string& request);
void WritesDone();
void Read(grpc::string* response,
IncomingMetadataContainer* server_initial_metadata);
Status Finish(IncomingMetadataContainer* server_trailing_metadata);
private:
enum CallStatus : intptr_t;
std::unique_ptr<grpc::GenericStub> stub_;
grpc::ClientContext ctx_;
std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;
grpc::CompletionQueue cq_;
};
} // namespace testing
......
......@@ -419,6 +419,119 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
std::unique_ptr<grpc::testing::ProtoFileParser> parser;
grpc::string serialized_request_proto;
std::shared_ptr<grpc::Channel> channel =
FLAGS_remotedb
? grpc::CreateChannel(server_address, cred.GetCredentials())
: nullptr;
parser.reset(new grpc::testing::ProtoFileParser(channel, FLAGS_proto_path,
FLAGS_protofiles));
grpc::string formated_method_name =
parser->GetFormatedMethodName(method_name);
if (parser->HasError()) {
return false;
}
if (parser->IsStreaming(method_name, true /* is_request */)) {
fprintf(stderr, "streaming request\n");
std::istream* input_stream;
std::ifstream input_file;
if (argc == 3) {
request_text = argv[2];
if (!FLAGS_infile.empty()) {
fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
}
}
// std::stringstream input_stream;
std::multimap<grpc::string, grpc::string> client_metadata;
ParseMetadataFlag(&client_metadata);
PrintMetadata(client_metadata, "Sending client initial metadata:");
CliCall call(channel, formated_method_name, client_metadata);
if (FLAGS_infile.empty()) {
if (isatty(STDIN_FILENO)) {
fprintf(stderr, "reading request message from stdin...\n");
}
input_stream = &std::cin;
// rdbuf = std::cin.rdbuf();
// input_stream.rdbuf(std::cin.rdbuf());
// input_stream << std::cin.rdbuf();
} else {
input_file.open(FLAGS_infile, std::ios::in | std::ios::binary);
// rdbuf = input_file.rdbuf();
// input_stream.rdbuf(input_file.rdbuf());
input_stream = &input_file;
// input_file.close();
}
// request_text = input_stream.str();
std::stringstream request_ss;
grpc::string line;
while (!input_stream->eof() && getline(*input_stream, line)) {
if (line.length() == 0) {
// request_text = request_ss.str();
if (FLAGS_binary_input) {
serialized_request_proto = request_ss.str();
} else {
serialized_request_proto = parser->GetSerializedProtoFromMethod(
method_name, request_ss.str(), true /* is_request */);
if (parser->HasError()) {
return false;
}
}
request_ss.str(grpc::string());
request_ss.clear();
grpc::string response_text = parser->GetTextFormatFromMethod(
method_name, serialized_request_proto, true /* is_request */);
call.Write(serialized_request_proto);
fprintf(stderr, "%s", response_text.c_str());
} else {
request_ss << line << ' ';
}
}
if (input_file.is_open()) {
input_file.close();
}
call.WritesDone();
grpc::string serialized_response_proto;
std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
server_trailing_metadata;
call.Read(&serialized_response_proto, &server_initial_metadata);
Status status = call.Finish(&server_trailing_metadata);
PrintMetadata(server_initial_metadata,
"Received initial metadata from server:");
PrintMetadata(server_trailing_metadata,
"Received trailing metadata from server:");
if (status.ok()) {
fprintf(stderr, "Stream RPC succeeded with OK status\n");
if (FLAGS_binary_output) {
output_ss << serialized_response_proto;
} else {
grpc::string response_text = parser->GetTextFormatFromMethod(
method_name, serialized_response_proto, false /* is_request */);
if (parser->HasError()) {
return false;
}
output_ss << "Response: \n " << response_text << std::endl;
}
} else {
fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
status.error_code(), status.error_message().c_str());
}
} else { // parser->IsStreaming(method_name, true /* is_request */)
if (argc == 3) {
request_text = argv[2];
if (!FLAGS_infile.empty()) {
......@@ -439,22 +552,9 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
request_text = input_stream.str();
}
std::shared_ptr<grpc::Channel> channel =
grpc::CreateChannel(server_address, cred.GetCredentials());
if (!FLAGS_binary_input || !FLAGS_binary_output) {
parser.reset(
new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr,
FLAGS_proto_path, FLAGS_protofiles));
if (parser->HasError()) {
return false;
}
}
if (FLAGS_binary_input) {
serialized_request_proto = request_text;
formatted_method_name = method_name;
} else {
formatted_method_name = parser->GetFormattedMethodName(method_name);
serialized_request_proto = parser->GetSerializedProtoFromMethod(
method_name, request_text, true /* is_request */);
if (parser->HasError()) {
......@@ -470,7 +570,7 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
ParseMetadataFlag(&client_metadata);
PrintMetadata(client_metadata, "Sending client initial metadata:");
grpc::Status status = grpc::testing::CliCall::Call(
channel, formatted_method_name, serialized_request_proto,
channel, formated_method_name, serialized_request_proto,
&serialized_response_proto, client_metadata, &server_initial_metadata,
&server_trailing_metadata);
PrintMetadata(server_initial_metadata,
......@@ -493,6 +593,7 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
status.error_code(), status.error_message().c_str());
}
}
return callback(output_ss.str());
}
......
......@@ -144,12 +144,18 @@ ProtoFileParser::~ProtoFileParser() {}
grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) {
has_error_ = false;
if (known_methods_.find(method) != known_methods_.end()) {
return known_methods_[method];
}
const protobuf::MethodDescriptor* method_descriptor = nullptr;
for (auto it = service_desc_list_.begin(); it != service_desc_list_.end();
it++) {
const auto* service_desc = *it;
for (int j = 0; j < service_desc->method_count(); j++) {
const auto* method_desc = service_desc->method(j);
fprintf(stderr, "%s\n", method_desc->full_name().c_str());
if (MethodNameMatch(method_desc->full_name(), method)) {
if (method_descriptor) {
std::ostringstream error_stream;
......@@ -169,6 +175,8 @@ grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) {
return "";
}
known_methods_[method] = method_descriptor->full_name();
return method_descriptor->full_name();
}
......@@ -205,6 +213,25 @@ grpc::string ProtoFileParser::GetMessageTypeFromMethod(
: method_desc->output_type()->full_name();
}
bool ProtoFileParser::IsStreaming(const grpc::string& method, bool is_request) {
has_error_ = false;
grpc::string full_method_name = GetFullMethodName(method);
if (has_error_) {
return false;
}
const protobuf::MethodDescriptor* method_desc =
desc_pool_->FindMethodByName(full_method_name);
if (!method_desc) {
LogError("Method not found");
return false;
}
return is_request ? method_desc->client_streaming()
: method_desc->server_streaming();
}
grpc::string ProtoFileParser::GetSerializedProtoFromMethod(
const grpc::string& method, const grpc::string& text_format_proto,
bool is_request) {
......
......@@ -84,6 +84,8 @@ class ProtoFileParser {
const grpc::string& message_type_name,
const grpc::string& serialized_proto);
bool IsStreaming(const grpc::string& method, bool is_request);
bool HasError() const { return has_error_; }
void LogError(const grpc::string& error_msg);
......@@ -104,6 +106,7 @@ class ProtoFileParser {
std::unique_ptr<protobuf::DynamicMessageFactory> dynamic_factory_;
std::unique_ptr<grpc::protobuf::Message> request_prototype_;
std::unique_ptr<grpc::protobuf::Message> response_prototype_;
std::unordered_map<grpc::string, grpc::string> known_methods_;
std::vector<const protobuf::ServiceDescriptor*> service_desc_list_;
};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment