Skip to content
Snippets Groups Projects
Commit db13f68d authored by Nathaniel Manista's avatar Nathaniel Manista
Browse files

Add a server_host_override to stub creation

This optional value should only be passed in tests.
parent a65e4631
No related branches found
No related tags found
No related merge requests found
...@@ -70,7 +70,8 @@ class _CTest(unittest.TestCase): ...@@ -70,7 +70,8 @@ class _CTest(unittest.TestCase):
def testChannel(self): def testChannel(self):
_c.init() _c.init()
channel = _c.Channel('test host:12345', None) channel = _c.Channel(
'test host:12345', None, server_host_override='ignored')
del channel del channel
_c.shut_down() _c.shut_down()
......
...@@ -42,19 +42,35 @@ ...@@ -42,19 +42,35 @@
static int pygrpc_channel_init(Channel *self, PyObject *args, PyObject *kwds) { static int pygrpc_channel_init(Channel *self, PyObject *args, PyObject *kwds) {
const char *hostport; const char *hostport;
PyObject *client_credentials; PyObject *client_credentials;
static char *kwlist[] = {"hostport", "client_credentials", NULL}; char *server_host_override = NULL;
static char *kwlist[] = {"hostport", "client_credentials",
"server_host_override", NULL};
grpc_arg server_host_override_arg;
grpc_channel_args channel_args;
if (!(PyArg_ParseTupleAndKeywords(args, kwds, "sO:Channel", kwlist, if (!(PyArg_ParseTupleAndKeywords(args, kwds, "sO|z:Channel", kwlist,
&hostport, &client_credentials))) { &hostport, &client_credentials,
&server_host_override))) {
return -1; return -1;
} }
if (client_credentials == Py_None) { if (client_credentials == Py_None) {
self->c_channel = grpc_channel_create(hostport, NULL); self->c_channel = grpc_channel_create(hostport, NULL);
return 0; return 0;
} else { } else {
self->c_channel = grpc_secure_channel_create( if (server_host_override == NULL) {
((ClientCredentials *)client_credentials)->c_client_credentials, self->c_channel = grpc_secure_channel_create(
hostport, NULL); ((ClientCredentials *)client_credentials)->c_client_credentials,
hostport, NULL);
} else {
server_host_override_arg.type = GRPC_ARG_STRING;
server_host_override_arg.key = GRPC_SSL_TARGET_NAME_OVERRIDE_ARG;
server_host_override_arg.value.string = server_host_override;
channel_args.num_args = 1;
channel_args.args = &server_host_override_arg;
self->c_channel = grpc_secure_channel_create(
((ClientCredentials *)client_credentials)->c_client_credentials,
hostport, &channel_args);
}
return 0; return 0;
} }
} }
......
...@@ -93,7 +93,8 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated): ...@@ -93,7 +93,8 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated):
def __init__( def __init__(
self, host, port, pool, request_serializers, response_deserializers, self, host, port, pool, request_serializers, response_deserializers,
secure, root_certificates, private_key, certificate_chain): secure, root_certificates, private_key, certificate_chain,
server_host_override=None):
"""Constructor. """Constructor.
Args: Args:
...@@ -111,6 +112,8 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated): ...@@ -111,6 +112,8 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated):
key should be used. key should be used.
certificate_chain: The PEM-encoded certificate chain to use or None if certificate_chain: The PEM-encoded certificate chain to use or None if
no certificate chain should be used. no certificate chain should be used.
server_host_override: (For testing only) the target name used for SSL
host name checking.
""" """
self._condition = threading.Condition() self._condition = threading.Condition()
self._host = host self._host = host
...@@ -132,6 +135,7 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated): ...@@ -132,6 +135,7 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated):
self._root_certificates = root_certificates self._root_certificates = root_certificates
self._private_key = private_key self._private_key = private_key
self._certificate_chain = certificate_chain self._certificate_chain = certificate_chain
self._server_host_override = server_host_override
def _on_write_event(self, operation_id, event, rpc_state): def _on_write_event(self, operation_id, event, rpc_state):
if event.write_accepted: if event.write_accepted:
...@@ -327,7 +331,8 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated): ...@@ -327,7 +331,8 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated):
with self._condition: with self._condition:
self._completion_queue = _low.CompletionQueue() self._completion_queue = _low.CompletionQueue()
self._channel = _low.Channel( self._channel = _low.Channel(
'%s:%d' % (self._host, self._port), self._client_credentials) '%s:%d' % (self._host, self._port), self._client_credentials,
server_host_override=self._server_host_override)
return self return self
def _stop(self): def _stop(self):
...@@ -388,7 +393,8 @@ class _ActivatedRearLink(ticket_interfaces.RearLink, activated.Activated): ...@@ -388,7 +393,8 @@ class _ActivatedRearLink(ticket_interfaces.RearLink, activated.Activated):
def __init__( def __init__(
self, host, port, request_serializers, response_deserializers, secure, self, host, port, request_serializers, response_deserializers, secure,
root_certificates, private_key, certificate_chain): root_certificates, private_key, certificate_chain,
server_host_override=None):
self._host = host self._host = host
self._port = port self._port = port
self._request_serializers = request_serializers self._request_serializers = request_serializers
...@@ -397,6 +403,7 @@ class _ActivatedRearLink(ticket_interfaces.RearLink, activated.Activated): ...@@ -397,6 +403,7 @@ class _ActivatedRearLink(ticket_interfaces.RearLink, activated.Activated):
self._root_certificates = root_certificates self._root_certificates = root_certificates
self._private_key = private_key self._private_key = private_key
self._certificate_chain = certificate_chain self._certificate_chain = certificate_chain
self._server_host_override = server_host_override
self._lock = threading.Lock() self._lock = threading.Lock()
self._pool = None self._pool = None
...@@ -415,7 +422,8 @@ class _ActivatedRearLink(ticket_interfaces.RearLink, activated.Activated): ...@@ -415,7 +422,8 @@ class _ActivatedRearLink(ticket_interfaces.RearLink, activated.Activated):
self._rear_link = RearLink( self._rear_link = RearLink(
self._host, self._port, self._pool, self._request_serializers, self._host, self._port, self._pool, self._request_serializers,
self._response_deserializers, self._secure, self._root_certificates, self._response_deserializers, self._secure, self._root_certificates,
self._private_key, self._certificate_chain) self._private_key, self._certificate_chain,
server_host_override=self._server_host_override)
self._rear_link.join_fore_link(self._fore_link) self._rear_link.join_fore_link(self._fore_link)
self._rear_link.start() self._rear_link.start()
return self return self
...@@ -477,7 +485,7 @@ def activated_rear_link( ...@@ -477,7 +485,7 @@ def activated_rear_link(
def secure_activated_rear_link( def secure_activated_rear_link(
host, port, request_serializers, response_deserializers, root_certificates, host, port, request_serializers, response_deserializers, root_certificates,
private_key, certificate_chain): private_key, certificate_chain, server_host_override=None):
"""Creates a RearLink that is also an activated.Activated. """Creates a RearLink that is also an activated.Activated.
The returned object is only valid for use between calls to its start and stop The returned object is only valid for use between calls to its start and stop
...@@ -496,7 +504,10 @@ def secure_activated_rear_link( ...@@ -496,7 +504,10 @@ def secure_activated_rear_link(
should be used. should be used.
certificate_chain: The PEM-encoded certificate chain to use or None if no certificate_chain: The PEM-encoded certificate chain to use or None if no
certificate chain should be used. certificate chain should be used.
server_host_override: (For testing only) the target name used for SSL
host name checking.
""" """
return _ActivatedRearLink( return _ActivatedRearLink(
host, port, request_serializers, response_deserializers, True, host, port, request_serializers, response_deserializers, True,
root_certificates, private_key, certificate_chain) root_certificates, private_key, certificate_chain,
server_host_override=server_host_override)
...@@ -125,7 +125,8 @@ def insecure_stub(methods, host, port): ...@@ -125,7 +125,8 @@ def insecure_stub(methods, host, port):
def secure_stub( def secure_stub(
methods, host, port, root_certificates, private_key, certificate_chain): methods, host, port, root_certificates, private_key, certificate_chain,
server_host_override=None):
"""Constructs an insecure interfaces.Stub. """Constructs an insecure interfaces.Stub.
Args: Args:
...@@ -140,6 +141,8 @@ def secure_stub( ...@@ -140,6 +141,8 @@ def secure_stub(
should be used. should be used.
certificate_chain: The PEM-encoded certificate chain to use or None if no certificate_chain: The PEM-encoded certificate chain to use or None if no
certificate chain should be used. certificate chain should be used.
server_host_override: (For testing only) the target name used for SSL
host name checking.
Returns: Returns:
An interfaces.Stub affording RPC invocation. An interfaces.Stub affording RPC invocation.
...@@ -148,7 +151,7 @@ def secure_stub( ...@@ -148,7 +151,7 @@ def secure_stub(
activated_rear_link = _rear.secure_activated_rear_link( activated_rear_link = _rear.secure_activated_rear_link(
host, port, breakdown.request_serializers, host, port, breakdown.request_serializers,
breakdown.response_deserializers, root_certificates, private_key, breakdown.response_deserializers, root_certificates, private_key,
certificate_chain) certificate_chain, server_host_override=server_host_override)
return _build_stub(breakdown, activated_rear_link) return _build_stub(breakdown, activated_rear_link)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment