From 0c711ad88b632bea173bdea9ea24372052aa231d Mon Sep 17 00:00:00 2001 From: Julien Boeuf <jboeuf@google.com> Date: Fri, 28 Aug 2015 14:10:58 -0700 Subject: [PATCH] Adding C++ metadata processor. - Had to chnage the core API to add a destroy function pointer in grpc_auth_metadata_processor. - Tested end to end. - Fixed some issues in the server_auth_filter (we were not checking the length which put us at risk of an overflow). --- include/grpc++/auth_metadata_processor.h | 5 +- include/grpc++/server_credentials.h | 10 +- include/grpc++/support/auth_context.h | 5 +- include/grpc/grpc_security.h | 1 + src/core/security/credentials.c | 15 +- src/core/security/security_context.c | 13 -- src/core/security/server_auth_filter.c | 10 +- src/core/security/server_secure_chttp2.c | 9 + src/cpp/client/create_channel.cc | 1 + src/cpp/common/secure_auth_context.cc | 15 +- src/cpp/common/secure_auth_context.h | 7 +- src/cpp/common/secure_create_auth_context.cc | 2 +- src/cpp/server/secure_server_credentials.cc | 17 +- src/cpp/server/secure_server_credentials.h | 2 + .../end2end/fixtures/chttp2_fake_security.c | 2 +- .../fixtures/chttp2_simple_ssl_fullstack.c | 2 +- .../chttp2_simple_ssl_fullstack_with_poll.c | 2 +- .../chttp2_simple_ssl_fullstack_with_proxy.c | 2 +- .../chttp2_simple_ssl_with_oauth2_fullstack.c | 45 ++++- test/cpp/end2end/end2end_test.cc | 175 ++++++++++++++++-- test/cpp/util/messages.proto | 1 + 21 files changed, 276 insertions(+), 65 deletions(-) diff --git a/include/grpc++/auth_metadata_processor.h b/include/grpc++/auth_metadata_processor.h index c0631bc11f..a42abef416 100644 --- a/include/grpc++/auth_metadata_processor.h +++ b/include/grpc++/auth_metadata_processor.h @@ -58,7 +58,10 @@ class AuthMetadataProcessor { // from the passed-in auth_metadata. // consumed_auth_metadata needs to be filled with metadata that has been // consumed by the processor and will be removed from the call. - // TODO(jboeuf). + // response_metadata is the metadata that will be sent as part of the + // response. + // If the return value is not Status::OK, the rpc call will be aborted with + // the error code and error message sent back to the client. virtual Status Process(const InputMetadata& auth_metadata, AuthContext* context, OutputMetadata* consumed_auth_metadata, diff --git a/include/grpc++/server_credentials.h b/include/grpc++/server_credentials.h index 486c35c56b..e006f3a180 100644 --- a/include/grpc++/server_credentials.h +++ b/include/grpc++/server_credentials.h @@ -50,16 +50,16 @@ class ServerCredentials { public: virtual ~ServerCredentials(); + // This method is not thread-safe and has to be called before the server is + // started. The last call to this function wins. + virtual void SetAuthMetadataProcessor( + const std::shared_ptr<AuthMetadataProcessor>& processor) = 0; + private: friend class ::grpc::Server; virtual int AddPortToServer(const grpc::string& addr, grpc_server* server) = 0; - - // This method is not thread-safe and has to be called before the server is - // started. The last call to this function wins. - virtual void SetAuthMetadataProcessor( - const std::shared_ptr<AuthMetadataProcessor>& processor) = 0; }; // Options to create ServerCredentials with SSL diff --git a/include/grpc++/support/auth_context.h b/include/grpc++/support/auth_context.h index 5d5f8e837d..fc2701e806 100644 --- a/include/grpc++/support/auth_context.h +++ b/include/grpc++/support/auth_context.h @@ -77,6 +77,9 @@ class AuthContext { public: virtual ~AuthContext() {} + // Returns true if the peer is authenticated. + virtual bool IsPeerAuthenticated() const = 0; + // A peer identity, in general is one or more properties (in which case they // have the same name). virtual std::vector<grpc::string_ref> GetPeerIdentity() const = 0; @@ -92,7 +95,7 @@ class AuthContext { // Mutation functions: should only be used by an AuthMetadataProcessor. virtual void AddProperty(const grpc::string& key, - const grpc::string& value) = 0; + const grpc::string_ref& value) = 0; virtual bool SetPeerIdentityPropertyName(const grpc::string& name) = 0; }; diff --git a/include/grpc/grpc_security.h b/include/grpc/grpc_security.h index e2205eea30..0b540c0e66 100644 --- a/include/grpc/grpc_security.h +++ b/include/grpc/grpc_security.h @@ -293,6 +293,7 @@ typedef struct { void (*process)(void *state, grpc_auth_context *context, const grpc_metadata *md, size_t num_md, grpc_process_auth_metadata_done_cb cb, void *user_data); + void (*destroy)(void *state); void *state; } grpc_auth_metadata_processor; diff --git a/src/core/security/credentials.c b/src/core/security/credentials.c index 362d5f4b6f..cae80f0216 100644 --- a/src/core/security/credentials.c +++ b/src/core/security/credentials.c @@ -152,9 +152,20 @@ grpc_security_status grpc_server_credentials_create_security_connector( void grpc_server_credentials_set_auth_metadata_processor( grpc_server_credentials *creds, grpc_auth_metadata_processor processor) { if (creds == NULL) return; + if (creds->processor.destroy != NULL && creds->processor.state != NULL) { + creds->processor.destroy(creds->processor.state); + } creds->processor = processor; } +void grpc_server_credentials_destroy(grpc_server_credentials *creds) { + if (creds == NULL) return; + if (creds->processor.destroy != NULL && creds->processor.state != NULL) { + creds->processor.destroy(creds->processor.state); + } + gpr_free(creds); +} + /* -- Ssl credentials. -- */ static void ssl_destroy(grpc_credentials *creds) { @@ -185,7 +196,7 @@ static void ssl_server_destroy(grpc_server_credentials *creds) { gpr_free(c->config.pem_cert_chains_sizes); } if (c->config.pem_root_certs != NULL) gpr_free(c->config.pem_root_certs); - gpr_free(creds); + grpc_server_credentials_destroy(creds); } static int ssl_has_request_metadata(const grpc_credentials *creds) { return 0; } @@ -902,7 +913,7 @@ static void fake_transport_security_credentials_destroy( static void fake_transport_security_server_credentials_destroy( grpc_server_credentials *creds) { - gpr_free(creds); + grpc_server_credentials_destroy(creds); } static int fake_transport_security_has_request_metadata( diff --git a/src/core/security/security_context.c b/src/core/security/security_context.c index c1b434f302..95d80ba122 100644 --- a/src/core/security/security_context.c +++ b/src/core/security/security_context.c @@ -42,19 +42,6 @@ #include <grpc/support/log.h> #include <grpc/support/string_util.h> -/* --- grpc_process_auth_metadata_func --- */ - -static grpc_auth_metadata_processor server_processor = {NULL, NULL}; - -grpc_auth_metadata_processor grpc_server_get_auth_metadata_processor(void) { - return server_processor; -} - -void grpc_server_register_auth_metadata_processor( - grpc_auth_metadata_processor processor) { - server_processor = processor; -} - /* --- grpc_call --- */ grpc_call_error grpc_call_set_credentials(grpc_call *call, diff --git a/src/core/security/server_auth_filter.c b/src/core/security/server_auth_filter.c index 57729be32c..b767f85498 100644 --- a/src/core/security/server_auth_filter.c +++ b/src/core/security/server_auth_filter.c @@ -91,13 +91,17 @@ static grpc_mdelem *remove_consumed_md(void *user_data, grpc_mdelem *md) { call_data *calld = elem->call_data; size_t i; for (i = 0; i < calld->num_consumed_md; i++) { + const grpc_metadata *consumed_md = &calld->consumed_md[i]; /* Maybe we could do a pointer comparison but we do not have any guarantee that the metadata processor used the same pointers for consumed_md in the callback. */ - if (memcmp(GPR_SLICE_START_PTR(md->key->slice), calld->consumed_md[i].key, + if (GPR_SLICE_LENGTH(md->key->slice) != strlen(consumed_md->key) || + GPR_SLICE_LENGTH(md->value->slice) != consumed_md->value_length) { + continue; + } + if (memcmp(GPR_SLICE_START_PTR(md->key->slice), consumed_md->key, GPR_SLICE_LENGTH(md->key->slice)) == 0 && - memcmp(GPR_SLICE_START_PTR(md->value->slice), - calld->consumed_md[i].value, + memcmp(GPR_SLICE_START_PTR(md->value->slice), consumed_md->value, GPR_SLICE_LENGTH(md->value->slice)) == 0) { return NULL; /* Delete. */ } diff --git a/src/core/security/server_secure_chttp2.c b/src/core/security/server_secure_chttp2.c index 8d9d036d80..96ca4cbd76 100644 --- a/src/core/security/server_secure_chttp2.c +++ b/src/core/security/server_secure_chttp2.c @@ -79,6 +79,9 @@ static void state_unref(grpc_server_secure_state *state) { gpr_mu_unlock(&state->mu); /* clean up */ GRPC_SECURITY_CONNECTOR_UNREF(state->sc, "server"); + if (state->processor.state != NULL && state->processor.destroy != NULL) { + state->processor.destroy(state->processor.state); + } gpr_free(state); } } @@ -262,7 +265,13 @@ int grpc_server_add_secure_http2_port(grpc_server *server, const char *addr, state->server = server; state->tcp = tcp; state->sc = sc; + + /* Transfer ownership of the processor. */ state->processor = creds->processor; + creds->processor.state = NULL; + creds->processor.destroy = NULL; + creds->processor.process = NULL; + state->handshaking_tcp_endpoints = NULL; state->is_shutdown = 0; gpr_mu_init(&state->mu); diff --git a/src/cpp/client/create_channel.cc b/src/cpp/client/create_channel.cc index 8c571cbbaa..4f6f2dd068 100644 --- a/src/cpp/client/create_channel.cc +++ b/src/cpp/client/create_channel.cc @@ -46,6 +46,7 @@ class ChannelArguments; std::shared_ptr<Channel> CreateChannel( const grpc::string& target, const std::shared_ptr<Credentials>& creds, const ChannelArguments& args) { + GrpcLibrary init_lib; // We need to call init in case of a bad creds. ChannelArguments cp_args = args; std::ostringstream user_agent_prefix; user_agent_prefix << "grpc-c++/" << grpc_version_string(); diff --git a/src/cpp/common/secure_auth_context.cc b/src/cpp/common/secure_auth_context.cc index 823ad8b0e2..8615ac8aeb 100644 --- a/src/cpp/common/secure_auth_context.cc +++ b/src/cpp/common/secure_auth_context.cc @@ -37,9 +37,13 @@ namespace grpc { -SecureAuthContext::SecureAuthContext(grpc_auth_context* ctx) : ctx_(ctx) {} +SecureAuthContext::SecureAuthContext(grpc_auth_context* ctx, + bool take_ownership) + : ctx_(ctx), take_ownership_(take_ownership) {} -SecureAuthContext::~SecureAuthContext() { grpc_auth_context_release(ctx_); } +SecureAuthContext::~SecureAuthContext() { + if (take_ownership_) grpc_auth_context_release(ctx_); +} std::vector<grpc::string_ref> SecureAuthContext::GetPeerIdentity() const { if (!ctx_) { @@ -95,7 +99,7 @@ AuthPropertyIterator SecureAuthContext::end() const { } void SecureAuthContext::AddProperty(const grpc::string& key, - const grpc::string& value) { + const grpc::string_ref& value) { if (!ctx_) return; grpc_auth_context_add_property(ctx_, key.c_str(), value.data(), value.size()); } @@ -106,4 +110,9 @@ bool SecureAuthContext::SetPeerIdentityPropertyName(const grpc::string& name) { name.c_str()) != 0; } +bool SecureAuthContext::IsPeerAuthenticated() const { + if (!ctx_) return false; + return grpc_auth_context_peer_is_authenticated(ctx_) != 0; +} + } // namespace grpc diff --git a/src/cpp/common/secure_auth_context.h b/src/cpp/common/secure_auth_context.h index cc09a29ecb..6edab0cee1 100644 --- a/src/cpp/common/secure_auth_context.h +++ b/src/cpp/common/secure_auth_context.h @@ -42,10 +42,12 @@ namespace grpc { class SecureAuthContext GRPC_FINAL : public AuthContext { public: - SecureAuthContext(grpc_auth_context* ctx); + SecureAuthContext(grpc_auth_context* ctx, bool take_ownership); ~SecureAuthContext() GRPC_OVERRIDE; + bool IsPeerAuthenticated() const GRPC_OVERRIDE; + std::vector<grpc::string_ref> GetPeerIdentity() const GRPC_OVERRIDE; grpc::string GetPeerIdentityPropertyName() const GRPC_OVERRIDE; @@ -58,13 +60,14 @@ class SecureAuthContext GRPC_FINAL : public AuthContext { AuthPropertyIterator end() const GRPC_OVERRIDE; void AddProperty(const grpc::string& key, - const grpc::string& value) GRPC_OVERRIDE; + const grpc::string_ref& value) GRPC_OVERRIDE; virtual bool SetPeerIdentityPropertyName(const grpc::string& name) GRPC_OVERRIDE; private: grpc_auth_context* ctx_; + bool take_ownership_; }; } // namespace grpc diff --git a/src/cpp/common/secure_create_auth_context.cc b/src/cpp/common/secure_create_auth_context.cc index f13d25a1dd..5a907d0c0b 100644 --- a/src/cpp/common/secure_create_auth_context.cc +++ b/src/cpp/common/secure_create_auth_context.cc @@ -44,7 +44,7 @@ std::shared_ptr<const AuthContext> CreateAuthContext(grpc_call* call) { return std::shared_ptr<const AuthContext>(); } return std::shared_ptr<const AuthContext>( - new SecureAuthContext(grpc_call_auth_context(call))); + new SecureAuthContext(grpc_call_auth_context(call), true)); } } // namespace grpc diff --git a/src/cpp/server/secure_server_credentials.cc b/src/cpp/server/secure_server_credentials.cc index f8847b80f7..0e2927bafe 100644 --- a/src/cpp/server/secure_server_credentials.cc +++ b/src/cpp/server/secure_server_credentials.cc @@ -43,6 +43,11 @@ namespace grpc { +void AuthMetadataProcessorAyncWrapper::Destroy(void *wrapper) { + auto* w = reinterpret_cast<AuthMetadataProcessorAyncWrapper*>(wrapper); + delete w; +} + void AuthMetadataProcessorAyncWrapper::Process( void* wrapper, grpc_auth_context* context, const grpc_metadata* md, size_t num_md, grpc_process_auth_metadata_done_cb cb, void* user_data) { @@ -71,14 +76,14 @@ void AuthMetadataProcessorAyncWrapper::InvokeProcessor( metadata.insert(std::make_pair( md[i].key, grpc::string_ref(md[i].value, md[i].value_length))); } - SecureAuthContext context(ctx); + SecureAuthContext context(ctx, false); AuthMetadataProcessor::OutputMetadata consumed_metadata; AuthMetadataProcessor::OutputMetadata response_metadata; Status status = processor_->Process(metadata, &context, &consumed_metadata, &response_metadata); - std::vector<grpc_metadata> consumed_md(consumed_metadata.size()); + std::vector<grpc_metadata> consumed_md; for (auto it = consumed_metadata.begin(); it != consumed_metadata.end(); ++it) { consumed_md.push_back({it->first.c_str(), @@ -87,8 +92,7 @@ void AuthMetadataProcessorAyncWrapper::InvokeProcessor( 0, {{nullptr, nullptr, nullptr, nullptr}}}); } - - std::vector<grpc_metadata> response_md(response_metadata.size()); + std::vector<grpc_metadata> response_md; for (auto it = response_metadata.begin(); it != response_metadata.end(); ++it) { response_md.push_back({it->first.c_str(), @@ -109,9 +113,10 @@ int SecureServerCredentials::AddPortToServer(const grpc::string& addr, void SecureServerCredentials::SetAuthMetadataProcessor( const std::shared_ptr<AuthMetadataProcessor>& processor) { - processor_.reset(new AuthMetadataProcessorAyncWrapper(processor)); + auto *wrapper = new AuthMetadataProcessorAyncWrapper(processor); grpc_server_credentials_set_auth_metadata_processor( - creds_, {AuthMetadataProcessorAyncWrapper::Process, processor_.get()}); + creds_, {AuthMetadataProcessorAyncWrapper::Process, + AuthMetadataProcessorAyncWrapper::Destroy, wrapper}); } std::shared_ptr<ServerCredentials> SslServerCredentials( diff --git a/src/cpp/server/secure_server_credentials.h b/src/cpp/server/secure_server_credentials.h index d15b793b15..17be9e310d 100644 --- a/src/cpp/server/secure_server_credentials.h +++ b/src/cpp/server/secure_server_credentials.h @@ -46,6 +46,8 @@ namespace grpc { class AuthMetadataProcessorAyncWrapper GRPC_FINAL { public: + static void Destroy(void *wrapper); + static void Process(void* wrapper, grpc_auth_context* context, const grpc_metadata* md, size_t num_md, grpc_process_auth_metadata_done_cb cb, void* user_data); diff --git a/test/core/end2end/fixtures/chttp2_fake_security.c b/test/core/end2end/fixtures/chttp2_fake_security.c index b4a248fb52..3e64cc08e8 100644 --- a/test/core/end2end/fixtures/chttp2_fake_security.c +++ b/test/core/end2end/fixtures/chttp2_fake_security.c @@ -128,7 +128,7 @@ static void chttp2_init_server_fake_secure_fullstack( grpc_server_credentials *fake_ts_creds = grpc_fake_transport_security_server_credentials_create(); if (fail_server_auth_check(server_args)) { - grpc_auth_metadata_processor processor = {process_auth_failure, NULL}; + grpc_auth_metadata_processor processor = {process_auth_failure, NULL, NULL}; grpc_server_credentials_set_auth_metadata_processor(fake_ts_creds, processor); } diff --git a/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack.c b/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack.c index 201d202dff..9193a09b17 100644 --- a/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack.c +++ b/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack.c @@ -138,7 +138,7 @@ static void chttp2_init_server_simple_ssl_secure_fullstack( grpc_server_credentials *ssl_creds = grpc_ssl_server_credentials_create(NULL, &pem_cert_key_pair, 1, 0, NULL); if (fail_server_auth_check(server_args)) { - grpc_auth_metadata_processor processor = {process_auth_failure, NULL}; + grpc_auth_metadata_processor processor = {process_auth_failure, NULL, NULL}; grpc_server_credentials_set_auth_metadata_processor(ssl_creds, processor); } chttp2_init_server_secure_fullstack(f, server_args, ssl_creds); diff --git a/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack_with_poll.c b/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack_with_poll.c index e7375f15e6..2c605d1471 100644 --- a/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack_with_poll.c +++ b/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack_with_poll.c @@ -138,7 +138,7 @@ static void chttp2_init_server_simple_ssl_secure_fullstack( grpc_server_credentials *ssl_creds = grpc_ssl_server_credentials_create(NULL, &pem_cert_key_pair, 1, 0, NULL); if (fail_server_auth_check(server_args)) { - grpc_auth_metadata_processor processor = {process_auth_failure, NULL}; + grpc_auth_metadata_processor processor = {process_auth_failure, NULL, NULL}; grpc_server_credentials_set_auth_metadata_processor(ssl_creds, processor); } chttp2_init_server_secure_fullstack(f, server_args, ssl_creds); diff --git a/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack_with_proxy.c b/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack_with_proxy.c index be0dda25a6..8133a69a0c 100644 --- a/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack_with_proxy.c +++ b/test/core/end2end/fixtures/chttp2_simple_ssl_fullstack_with_proxy.c @@ -167,7 +167,7 @@ static void chttp2_init_server_simple_ssl_secure_fullstack( grpc_server_credentials *ssl_creds = grpc_ssl_server_credentials_create(NULL, &pem_cert_key_pair, 1, 0, NULL); if (fail_server_auth_check(server_args)) { - grpc_auth_metadata_processor processor = {process_auth_failure, NULL}; + grpc_auth_metadata_processor processor = {process_auth_failure, NULL, NULL}; grpc_server_credentials_set_auth_metadata_processor(ssl_creds, processor); } chttp2_init_server_secure_fullstack(f, server_args, ssl_creds); diff --git a/test/core/end2end/fixtures/chttp2_simple_ssl_with_oauth2_fullstack.c b/test/core/end2end/fixtures/chttp2_simple_ssl_with_oauth2_fullstack.c index 9a545b1e3d..e61e276fff 100644 --- a/test/core/end2end/fixtures/chttp2_simple_ssl_with_oauth2_fullstack.c +++ b/test/core/end2end/fixtures/chttp2_simple_ssl_with_oauth2_fullstack.c @@ -67,13 +67,21 @@ static const grpc_metadata *find_metadata(const grpc_metadata *md, return NULL; } +typedef struct { + size_t pseudo_refcount; +} test_processor_state; + static void process_oauth2_success(void *state, grpc_auth_context *ctx, const grpc_metadata *md, size_t md_count, grpc_process_auth_metadata_done_cb cb, void *user_data) { const grpc_metadata *oauth2 = find_metadata(md, md_count, "Authorization", oauth2_md); - GPR_ASSERT(state == NULL); + test_processor_state *s; + + GPR_ASSERT(state != NULL); + s = (test_processor_state *)state; + GPR_ASSERT(s->pseudo_refcount == 1); GPR_ASSERT(oauth2 != NULL); grpc_auth_context_add_cstring_property(ctx, client_identity_property_name, client_identity); @@ -88,7 +96,10 @@ static void process_oauth2_failure(void *state, grpc_auth_context *ctx, void *user_data) { const grpc_metadata *oauth2 = find_metadata(md, md_count, "Authorization", oauth2_md); - GPR_ASSERT(state == NULL); + test_processor_state *s; + GPR_ASSERT(state != NULL); + s = (test_processor_state *)state; + GPR_ASSERT(s->pseudo_refcount == 1); GPR_ASSERT(oauth2 != NULL); cb(user_data, oauth2, 1, NULL, 0, GRPC_STATUS_UNAUTHENTICATED, NULL); } @@ -171,20 +182,34 @@ static int fail_server_auth_check(grpc_channel_args *server_args) { return 0; } +static void processor_destroy(void *state) { + test_processor_state *s = (test_processor_state *)state; + GPR_ASSERT((s->pseudo_refcount--) == 1); + gpr_free(s); +} + +static grpc_auth_metadata_processor test_processor_create(int failing) { + test_processor_state *s = gpr_malloc(sizeof(*s)); + grpc_auth_metadata_processor result; + s->pseudo_refcount = 1; + result.state = s; + result.destroy = processor_destroy; + if (failing) { + result.process = process_oauth2_failure; + } else { + result.process = process_oauth2_success; + } + return result; +} + static void chttp2_init_server_simple_ssl_secure_fullstack( grpc_end2end_test_fixture *f, grpc_channel_args *server_args) { grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {test_server1_key, test_server1_cert}; grpc_server_credentials *ssl_creds = grpc_ssl_server_credentials_create(NULL, &pem_key_cert_pair, 1, 0, NULL); - grpc_auth_metadata_processor processor; - processor.state = NULL; - if (fail_server_auth_check(server_args)) { - processor.process = process_oauth2_failure; - } else { - processor.process = process_oauth2_success; - } - grpc_server_credentials_set_auth_metadata_processor(ssl_creds, processor); + grpc_server_credentials_set_auth_metadata_processor( + ssl_creds, test_processor_create(fail_server_auth_check(server_args))); chttp2_init_server_secure_fullstack(f, server_args, ssl_creds); } diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index 9826837c60..72188e4da5 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -37,6 +37,7 @@ #include <grpc/grpc.h> #include <grpc/support/thd.h> #include <grpc/support/time.h> +#include <grpc++/auth_metadata_processor.h> #include <grpc++/channel.h> #include <grpc++/client_context.h> #include <grpc++/create_channel.h> @@ -79,14 +80,23 @@ void MaybeEchoDeadline(ServerContext* context, const EchoRequest* request, } } -void CheckServerAuthContext(const ServerContext* context) { +void CheckServerAuthContext(const ServerContext* context, + const grpc::string& expected_client_identity) { std::shared_ptr<const AuthContext> auth_ctx = context->auth_context(); std::vector<grpc::string_ref> ssl = auth_ctx->FindPropertyValues("transport_security_type"); EXPECT_EQ(1u, ssl.size()); EXPECT_EQ("ssl", ToString(ssl[0])); - EXPECT_TRUE(auth_ctx->GetPeerIdentityPropertyName().empty()); - EXPECT_TRUE(auth_ctx->GetPeerIdentity().empty()); + if (expected_client_identity.length() == 0) { + EXPECT_TRUE(auth_ctx->GetPeerIdentityPropertyName().empty()); + EXPECT_TRUE(auth_ctx->GetPeerIdentity().empty()); + EXPECT_FALSE(auth_ctx->IsPeerAuthenticated()); + } else { + auto identity = auth_ctx->GetPeerIdentity(); + EXPECT_TRUE(auth_ctx->IsPeerAuthenticated()); + EXPECT_EQ(1u, identity.size()); + EXPECT_EQ(expected_client_identity, identity[0]); + } } bool CheckIsLocalhost(const grpc::string& addr) { @@ -98,6 +108,54 @@ bool CheckIsLocalhost(const grpc::string& addr) { addr.substr(0, kIpv6.size()) == kIpv6; } +class TestAuthMetadataProcessor : public AuthMetadataProcessor { + public: + static const char kGoodGuy[]; + + TestAuthMetadataProcessor(bool is_blocking) : is_blocking_(is_blocking) {} + + std::shared_ptr<Credentials> GetCompatibleClientCreds() { + return AccessTokenCredentials(kGoodGuy); + } + std::shared_ptr<Credentials> GetIncompatibleClientCreds() { + return AccessTokenCredentials("Mr Hyde"); + } + + // Interface implementation + bool IsBlocking() const GRPC_OVERRIDE { return is_blocking_; } + + Status Process(const InputMetadata& auth_metadata, AuthContext* context, + OutputMetadata* consumed_auth_metadata, + OutputMetadata* response_metadata) GRPC_OVERRIDE { + EXPECT_TRUE(consumed_auth_metadata != nullptr); + EXPECT_TRUE(context != nullptr); + EXPECT_TRUE(response_metadata != nullptr); + auto auth_md = auth_metadata.find(GRPC_AUTHORIZATION_METADATA_KEY); + EXPECT_NE(auth_md, auth_metadata.end()); + string_ref auth_md_value = auth_md->second; + if (auth_md_value.ends_with("Dr Jekyll")) { + context->AddProperty(kIdentityPropName, kGoodGuy); + context->SetPeerIdentityPropertyName(kIdentityPropName); + consumed_auth_metadata->insert( + std::make_pair(string(auth_md->first.data(), auth_md->first.length()), + auth_md->second)); + return Status::OK; + } else { + return Status(StatusCode::UNAUTHENTICATED, + string("Invalid principal: ") + + string(auth_md_value.data(), auth_md_value.length())); + } + } + + protected: + static const char kIdentityPropName[]; + bool is_blocking_; +}; + +const char TestAuthMetadataProcessor::kGoodGuy[] = "Dr Jekyll"; +const char TestAuthMetadataProcessor::kIdentityPropName[] = "novel identity"; + + } // namespace class Proxy : public ::grpc::cpp::test::util::TestService::Service { @@ -162,8 +220,10 @@ class TestServiceImpl : public ::grpc::cpp::test::util::TestService::Service { ToString(iter->second)); } } - if (request->has_param() && request->param().check_auth_context()) { - CheckServerAuthContext(context); + if (request->has_param() && + (request->param().expected_client_identity().length() > 0 || + request->param().check_auth_context())) { + CheckServerAuthContext(context, request->param().expected_client_identity()); } if (request->has_param() && request->param().response_message_length() > 0) { @@ -259,9 +319,18 @@ class TestServiceImplDupPkg class End2endTest : public ::testing::TestWithParam<bool> { protected: End2endTest() - : kMaxMessageSize_(8192), special_service_("special") {} + : is_server_started_(false), + kMaxMessageSize_(8192), + special_service_("special") {} + + void TearDown() GRPC_OVERRIDE { + if (is_server_started_) { + server_->Shutdown(); + if (proxy_server_) proxy_server_->Shutdown(); + } + } - void SetUp() GRPC_OVERRIDE { + void StartServer(const std::shared_ptr<AuthMetadataProcessor>& processor) { int port = grpc_pick_unused_port_or_die(); server_address_ << "127.0.0.1:" << port; // Setup server @@ -271,22 +340,23 @@ class End2endTest : public ::testing::TestWithParam<bool> { SslServerCredentialsOptions ssl_opts; ssl_opts.pem_root_certs = ""; ssl_opts.pem_key_cert_pairs.push_back(pkcp); - builder.AddListeningPort(server_address_.str(), - SslServerCredentials(ssl_opts)); + auto server_creds = SslServerCredentials(ssl_opts); + server_creds->SetAuthMetadataProcessor(processor); + builder.AddListeningPort(server_address_.str(), server_creds); builder.RegisterService(&service_); builder.RegisterService("foo.test.youtube.com", &special_service_); builder.SetMaxMessageSize( kMaxMessageSize_); // For testing max message size. builder.RegisterService(&dup_pkg_service_); server_ = builder.BuildAndStart(); - } - - void TearDown() GRPC_OVERRIDE { - server_->Shutdown(); - if (proxy_server_) proxy_server_->Shutdown(); + is_server_started_ = true; } void ResetChannel() { + if (!is_server_started_) { + StartServer(std::shared_ptr<AuthMetadataProcessor>()); + } + EXPECT_TRUE(is_server_started_); SslCredentialsOptions ssl_opts = {test_root_cert, "", ""}; ChannelArguments args; args.SetSslTargetNameOverride("foo.test.google.fr"); @@ -314,6 +384,7 @@ class End2endTest : public ::testing::TestWithParam<bool> { stub_ = std::move(grpc::cpp::test::util::TestService::NewStub(channel_)); } + bool is_server_started_; std::shared_ptr<Channel> channel_; std::unique_ptr<grpc::cpp::test::util::TestService::Stub> stub_; std::unique_ptr<Server> server_; @@ -806,6 +877,82 @@ TEST_F(End2endTest, OverridePerCallCredentials) { EXPECT_TRUE(s.ok()); } +TEST_F(End2endTest, NonBlockingAuthMetadataProcessorSuccess) { + auto* processor = new TestAuthMetadataProcessor(false); + StartServer(std::shared_ptr<AuthMetadataProcessor>(processor)); + ResetStub(false); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(processor->GetCompatibleClientCreds()); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + request.mutable_param()->set_expected_client_identity( + TestAuthMetadataProcessor::kGoodGuy); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + + // Metadata should have been consumed by the processor. + EXPECT_FALSE(MetadataContains( + context.GetServerTrailingMetadata(), GRPC_AUTHORIZATION_METADATA_KEY, + grpc::string("Bearer ") + TestAuthMetadataProcessor::kGoodGuy)); +} + +TEST_F(End2endTest, NonBlockingAuthMetadataProcessorFailure) { + auto* processor = new TestAuthMetadataProcessor(false); + StartServer(std::shared_ptr<AuthMetadataProcessor>(processor)); + ResetStub(false); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(processor->GetIncompatibleClientCreds()); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAUTHENTICATED); +} + +TEST_F(End2endTest, BlockingAuthMetadataProcessorSuccess) { + auto* processor = new TestAuthMetadataProcessor(true); + StartServer(std::shared_ptr<AuthMetadataProcessor>(processor)); + ResetStub(false); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(processor->GetCompatibleClientCreds()); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + request.mutable_param()->set_expected_client_identity( + TestAuthMetadataProcessor::kGoodGuy); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + + // Metadata should have been consumed by the processor. + EXPECT_FALSE(MetadataContains( + context.GetServerTrailingMetadata(), GRPC_AUTHORIZATION_METADATA_KEY, + grpc::string("Bearer ") + TestAuthMetadataProcessor::kGoodGuy)); +} + +TEST_F(End2endTest, BlockingAuthMetadataProcessorFailure) { + auto* processor = new TestAuthMetadataProcessor(true); + StartServer(std::shared_ptr<AuthMetadataProcessor>(processor)); + ResetStub(false); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(processor->GetIncompatibleClientCreds()); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAUTHENTICATED); +} + // Client sends 20 requests and the server returns CANCELLED status after // reading 10 requests. TEST_F(End2endTest, RequestStreamServerEarlyCancelTest) { diff --git a/test/cpp/util/messages.proto b/test/cpp/util/messages.proto index 359d1db74f..a022707be9 100644 --- a/test/cpp/util/messages.proto +++ b/test/cpp/util/messages.proto @@ -40,6 +40,7 @@ message RequestParams { bool check_auth_context = 5; int32 response_message_length = 6; bool echo_peer = 7; + string expected_client_identity = 8; // will force check_auth_context. } message EchoRequest { -- GitLab