From d59ad7ef393c7624c7035a09b488f630cbd96730 Mon Sep 17 00:00:00 2001
From: yang-g <yangg@google.com>
Date: Wed, 10 Feb 2016 12:42:53 -0800
Subject: [PATCH] Provide explicit API for user to set user agent string prefix

---
 include/grpc++/support/channel_arguments.h |   9 +-
 src/cpp/client/create_channel.cc           |   8 +-
 src/cpp/common/channel_arguments.cc        |  33 +++-
 test/cpp/common/channel_arguments_test.cc  | 183 +++++++++++++--------
 test/cpp/end2end/end2end_test.cc           |  23 +++
 5 files changed, 174 insertions(+), 82 deletions(-)

diff --git a/include/grpc++/support/channel_arguments.h b/include/grpc++/support/channel_arguments.h
index a2960a7ecc..72f52657cd 100644
--- a/include/grpc++/support/channel_arguments.h
+++ b/include/grpc++/support/channel_arguments.h
@@ -51,7 +51,7 @@ class ChannelArgumentsTest;
 /// concrete setters are provided.
 class ChannelArguments {
  public:
-  ChannelArguments() {}
+  ChannelArguments();
   ~ChannelArguments() {}
 
   ChannelArguments(const ChannelArguments& other);
@@ -62,8 +62,8 @@ class ChannelArguments {
 
   void Swap(ChannelArguments& other);
 
-  /// Populates this instance with the arguments from \a channel_args. Does not
-  /// take ownership of \a channel_args.
+  /// Dump arguments in this instance to \a channel_args. Does not take
+  /// ownership of \a channel_args.
   ///
   /// Note that the underlying arguments are shared. Changes made to either \a
   /// channel_args or this instance would be reflected on both.
@@ -77,6 +77,9 @@ class ChannelArguments {
   /// Set the compression algorithm for the channel.
   void SetCompressionAlgorithm(grpc_compression_algorithm algorithm);
 
+  /// The given string will be sent at the front of the user agent string.
+  void SetUserAgentPrefix(const grpc::string& user_agent_prefix);
+
   // Generic channel argument setters. Only for advanced use cases.
   /// Set an integer argument \a value under \a key.
   void SetInt(const grpc::string& key, int value);
diff --git a/src/cpp/client/create_channel.cc b/src/cpp/client/create_channel.cc
index fdaa28ffef..76a1b31e2f 100644
--- a/src/cpp/client/create_channel.cc
+++ b/src/cpp/client/create_channel.cc
@@ -32,7 +32,6 @@
  */
 
 #include <memory>
-#include <sstream>
 
 #include <grpc++/channel.h>
 #include <grpc++/create_channel.h>
@@ -56,13 +55,8 @@ std::shared_ptr<Channel> CreateCustomChannel(
     const ChannelArguments& args) {
   internal::GrpcLibrary
       init_lib;  // We need to call init in case of a bad creds.
-  ChannelArguments cp_args = args;
-  std::ostringstream user_agent_prefix;
-  user_agent_prefix << "grpc-c++/" << grpc_version_string();
-  cp_args.SetString(GRPC_ARG_PRIMARY_USER_AGENT_STRING,
-                    user_agent_prefix.str());
   return creds
-             ? creds->CreateChannel(target, cp_args)
+             ? creds->CreateChannel(target, args)
              : CreateChannelInternal("", grpc_lame_client_channel_create(
                                              NULL, GRPC_STATUS_INVALID_ARGUMENT,
                                              "Invalid credentials."));
diff --git a/src/cpp/common/channel_arguments.cc b/src/cpp/common/channel_arguments.cc
index 90cd5136af..e23c964797 100644
--- a/src/cpp/common/channel_arguments.cc
+++ b/src/cpp/common/channel_arguments.cc
@@ -30,14 +30,23 @@
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  */
-
 #include <grpc++/support/channel_arguments.h>
 
+#include <sstream>
+
+#include <grpc/impl/codegen/grpc_types.h>
 #include <grpc/support/log.h>
 #include "src/core/channel/channel_args.h"
 
 namespace grpc {
 
+ChannelArguments::ChannelArguments() {
+  std::ostringstream user_agent_prefix;
+  user_agent_prefix << "grpc-c++/" << grpc_version_string();
+  // This will be ignored if used on the server side.
+  SetString(GRPC_ARG_PRIMARY_USER_AGENT_STRING, user_agent_prefix.str());
+}
+
 ChannelArguments::ChannelArguments(const ChannelArguments& other)
     : strings_(other.strings_) {
   args_.reserve(other.args_.size());
@@ -81,6 +90,28 @@ void ChannelArguments::SetCompressionAlgorithm(
   SetInt(GRPC_COMPRESSION_ALGORITHM_ARG, algorithm);
 }
 
+// Note: a second call to this will add in front the result of the first call.
+void ChannelArguments::SetUserAgentPrefix(
+    const grpc::string& user_agent_prefix) {
+  if (user_agent_prefix.empty()) {
+    return;
+  }
+  bool replaced = false;
+  for (auto it = args_.begin(); it != args_.end(); ++it) {
+    const grpc_arg& arg = *it;
+    if (arg.type == GRPC_ARG_STRING &&
+        grpc::string(arg.key) == GRPC_ARG_PRIMARY_USER_AGENT_STRING) {
+      strings_.push_back(user_agent_prefix + " " + arg.value.string);
+      it->value.string = const_cast<char*>(strings_.back().c_str());
+      replaced = true;
+      break;
+    }
+  }
+  if (!replaced) {
+    SetString(GRPC_ARG_PRIMARY_USER_AGENT_STRING, user_agent_prefix);
+  }
+}
+
 void ChannelArguments::SetInt(const grpc::string& key, int value) {
   grpc_arg arg;
   arg.type = GRPC_ARG_INTEGER;
diff --git a/test/cpp/common/channel_arguments_test.cc b/test/cpp/common/channel_arguments_test.cc
index e010d375cf..8a6e9b2222 100644
--- a/test/cpp/common/channel_arguments_test.cc
+++ b/test/cpp/common/channel_arguments_test.cc
@@ -45,90 +45,131 @@ class ChannelArgumentsTest : public ::testing::Test {
                       grpc_channel_args* args) {
     channel_args.SetChannelArgs(args);
   }
+
+  grpc::string GetDefaultUserAgentPrefix() {
+    std::ostringstream user_agent_prefix;
+    user_agent_prefix << "grpc-c++/" << grpc_version_string();
+    return user_agent_prefix.str();
+  }
+
+  void VerifyDefaultChannelArgs() {
+    grpc_channel_args args;
+    SetChannelArgs(channel_args_, &args);
+    EXPECT_EQ(static_cast<size_t>(1), args.num_args);
+    EXPECT_STREQ(GRPC_ARG_PRIMARY_USER_AGENT_STRING, args.args[0].key);
+    EXPECT_EQ(GetDefaultUserAgentPrefix(),
+              grpc::string(args.args[0].value.string));
+  }
+
+  bool HasArg(grpc_arg expected_arg) {
+    grpc_channel_args args;
+    SetChannelArgs(channel_args_, &args);
+    for (size_t i = 0; i < args.num_args; i++) {
+      const grpc_arg& arg = args.args[i];
+      if (arg.type == expected_arg.type &&
+          grpc::string(arg.key) == expected_arg.key) {
+        if (arg.type == GRPC_ARG_INTEGER) {
+          return arg.value.integer == expected_arg.value.integer;
+        } else if (arg.type == GRPC_ARG_STRING) {
+          return grpc::string(arg.value.string) == expected_arg.value.string;
+        } else if (arg.type == GRPC_ARG_POINTER) {
+          return arg.value.pointer.p == expected_arg.value.pointer.p &&
+                 arg.value.pointer.copy == expected_arg.value.pointer.copy &&
+                 arg.value.pointer.destroy ==
+                     expected_arg.value.pointer.destroy;
+        }
+      }
+    }
+    return false;
+  }
+  ChannelArguments channel_args_;
 };
 
 TEST_F(ChannelArgumentsTest, SetInt) {
-  grpc_channel_args args;
-  ChannelArguments channel_args;
-  // Empty arguments.
-  SetChannelArgs(channel_args, &args);
-  EXPECT_EQ(static_cast<size_t>(0), args.num_args);
-
-  grpc::string key("key0");
-  channel_args.SetInt(key, 0);
+  VerifyDefaultChannelArgs();
+  grpc::string key0("key0");
+  grpc_arg arg0;
+  arg0.type = GRPC_ARG_INTEGER;
+  arg0.key = const_cast<char*>(key0.c_str());
+  arg0.value.integer = 0;
+  grpc::string key1("key1");
+  grpc_arg arg1;
+  arg1.type = GRPC_ARG_INTEGER;
+  arg1.key = const_cast<char*>(key1.c_str());
+  arg1.value.integer = 1;
+
+  grpc::string arg_key0(key0);
+  channel_args_.SetInt(arg_key0, arg0.value.integer);
   // Clear key early to make sure channel_args takes a copy
-  key = "";
-  SetChannelArgs(channel_args, &args);
-  EXPECT_EQ(static_cast<size_t>(1), args.num_args);
-  EXPECT_EQ(GRPC_ARG_INTEGER, args.args[0].type);
-  EXPECT_STREQ("key0", args.args[0].key);
-  EXPECT_EQ(0, args.args[0].value.integer);
-
-  key = "key1";
-  channel_args.SetInt(key, 1);
-  key = "";
-  SetChannelArgs(channel_args, &args);
-  EXPECT_EQ(static_cast<size_t>(2), args.num_args);
-  // We do not enforce order on the arguments.
-  for (size_t i = 0; i < args.num_args; i++) {
-    EXPECT_EQ(GRPC_ARG_INTEGER, args.args[i].type);
-    if (grpc::string(args.args[i].key) == "key0") {
-      EXPECT_EQ(0, args.args[i].value.integer);
-    } else if (grpc::string(args.args[i].key) == "key1") {
-      EXPECT_EQ(1, args.args[i].value.integer);
-    }
-  }
+  arg_key0.clear();
+  EXPECT_TRUE(HasArg(arg0));
+
+  grpc::string arg_key1(key1);
+  channel_args_.SetInt(arg_key1, arg1.value.integer);
+  arg_key1.clear();
+  EXPECT_TRUE(HasArg(arg0));
+  EXPECT_TRUE(HasArg(arg1));
 }
 
 TEST_F(ChannelArgumentsTest, SetString) {
-  grpc_channel_args args;
-  ChannelArguments channel_args;
-  // Empty arguments.
-  SetChannelArgs(channel_args, &args);
-  EXPECT_EQ(static_cast<size_t>(0), args.num_args);
-
-  grpc::string key("key0");
-  grpc::string val("val0");
-  channel_args.SetString(key, val);
+  VerifyDefaultChannelArgs();
+  grpc::string key0("key0");
+  grpc::string val0("val0");
+  grpc_arg arg0;
+  arg0.type = GRPC_ARG_STRING;
+  arg0.key = const_cast<char*>(key0.c_str());
+  arg0.value.string = const_cast<char*>(val0.c_str());
+  grpc::string key1("key1");
+  grpc::string val1("val1");
+  grpc_arg arg1;
+  arg1.type = GRPC_ARG_STRING;
+  arg1.key = const_cast<char*>(key1.c_str());
+  arg1.value.string = const_cast<char*>(val1.c_str());
+
+  grpc::string key(key0);
+  grpc::string val(val0);
+  channel_args_.SetString(key, val);
   // Clear key/val early to make sure channel_args takes a copy
   key = "";
   val = "";
-  SetChannelArgs(channel_args, &args);
-  EXPECT_EQ(static_cast<size_t>(1), args.num_args);
-  EXPECT_EQ(GRPC_ARG_STRING, args.args[0].type);
-  EXPECT_STREQ("key0", args.args[0].key);
-  EXPECT_STREQ("val0", args.args[0].value.string);
-
-  key = "key1";
-  val = "val1";
-  channel_args.SetString(key, val);
-  SetChannelArgs(channel_args, &args);
-  EXPECT_EQ(static_cast<size_t>(2), args.num_args);
-  // We do not enforce order on the arguments.
-  for (size_t i = 0; i < args.num_args; i++) {
-    EXPECT_EQ(GRPC_ARG_STRING, args.args[i].type);
-    if (grpc::string(args.args[i].key) == "key0") {
-      EXPECT_STREQ("val0", args.args[i].value.string);
-    } else if (grpc::string(args.args[i].key) == "key1") {
-      EXPECT_STREQ("val1", args.args[i].value.string);
-    }
-  }
+  EXPECT_TRUE(HasArg(arg0));
+
+  key = key1;
+  val = val1;
+  channel_args_.SetString(key, val);
+  // Clear key/val early to make sure channel_args takes a copy
+  key = "";
+  val = "";
+  EXPECT_TRUE(HasArg(arg0));
+  EXPECT_TRUE(HasArg(arg1));
 }
 
 TEST_F(ChannelArgumentsTest, SetPointer) {
-  grpc_channel_args args;
-  ChannelArguments channel_args;
-  // Empty arguments.
-  SetChannelArgs(channel_args, &args);
-  EXPECT_EQ(static_cast<size_t>(0), args.num_args);
-
-  grpc::string key("key0");
-  channel_args.SetPointer(key, &key);
-  SetChannelArgs(channel_args, &args);
-  EXPECT_EQ(static_cast<size_t>(1), args.num_args);
-  EXPECT_EQ(GRPC_ARG_POINTER, args.args[0].type);
-  EXPECT_STREQ("key0", args.args[0].key);
-  EXPECT_EQ(&key, args.args[0].value.pointer.p);
+  VerifyDefaultChannelArgs();
+  grpc::string key0("key0");
+  grpc_arg arg0;
+  arg0.type = GRPC_ARG_POINTER;
+  arg0.key = const_cast<char*>(key0.c_str());
+  arg0.value.pointer.p = &key0;
+  arg0.value.pointer.copy = nullptr;
+  arg0.value.pointer.destroy = nullptr;
+
+  grpc::string key(key0);
+  channel_args_.SetPointer(key, arg0.value.pointer.p);
+  EXPECT_TRUE(HasArg(arg0));
+}
+
+TEST_F(ChannelArgumentsTest, SetUserAgentPrefix) {
+  VerifyDefaultChannelArgs();
+  grpc::string prefix("prefix");
+  grpc::string whole_prefix = prefix + " " + GetDefaultUserAgentPrefix();
+  grpc_arg arg0;
+  arg0.type = GRPC_ARG_STRING;
+  arg0.key = const_cast<char*>(GRPC_ARG_PRIMARY_USER_AGENT_STRING);
+  arg0.value.string = const_cast<char*>(whole_prefix.c_str());
+
+  channel_args_.SetUserAgentPrefix(prefix);
+  EXPECT_TRUE(HasArg(arg0));
 }
 
 }  // namespace testing
diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc
index 65da71b391..c8523847ab 100644
--- a/test/cpp/end2end/end2end_test.cc
+++ b/test/cpp/end2end/end2end_test.cc
@@ -252,6 +252,9 @@ class End2endTest : public ::testing::TestWithParam<TestScenario> {
       args.SetSslTargetNameOverride("foo.test.google.fr");
       channel_creds = SslCredentials(ssl_opts);
     }
+    if (!user_agent_prefix_.empty()) {
+      args.SetUserAgentPrefix(user_agent_prefix_);
+    }
     args.SetString(GRPC_ARG_SECONDARY_USER_AGENT_STRING, "end2end_test");
     channel_ = CreateCustomChannel(server_address_.str(), channel_creds, args);
   }
@@ -285,6 +288,7 @@ class End2endTest : public ::testing::TestWithParam<TestScenario> {
   TestServiceImpl service_;
   TestServiceImpl special_service_;
   TestServiceImplDupPkg dup_pkg_service_;
+  grpc::string user_agent_prefix_;
 };
 
 static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs,
@@ -601,6 +605,25 @@ TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelAfter) {
   TestBidiStreamServerCancel(CANCEL_AFTER_PROCESSING, 5);
 }
 
+TEST_P(End2endTest, SimpleRpcWithCustomeUserAgentPrefix) {
+  user_agent_prefix_ = "custom_prefix";
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello hello hello hello");
+  request.mutable_param()->set_echo_metadata(true);
+
+  ClientContext context;
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(s.ok());
+  const auto& trailing_metadata = context.GetServerTrailingMetadata();
+  auto iter = trailing_metadata.find("user-agent");
+  EXPECT_TRUE(iter != trailing_metadata.end());
+  grpc::string expected_prefix = user_agent_prefix_ + " grpc-c++/";
+  EXPECT_TRUE(iter->second.starts_with(expected_prefix));
+}
+
 TEST_P(End2endTest, MultipleRpcsWithVariedBinaryMetadataValue) {
   ResetStub();
   std::vector<std::thread*> threads;
-- 
GitLab