Skip to content
Snippets Groups Projects
Commit 24cb78d8 authored by Jan Tattermusch's avatar Jan Tattermusch Committed by GitHub
Browse files

Merge pull request #6751 from soltanmm/6522

Add a test for grpc/grpc#6522
parents a68b07e0 c045fd96
No related branches found
No related tags found
No related merge requests found
...@@ -143,22 +143,60 @@ class TypeSmokeTest(unittest.TestCase): ...@@ -143,22 +143,60 @@ class TypeSmokeTest(unittest.TestCase):
del completion_queue del completion_queue
class InsecureServerInsecureClient(unittest.TestCase): class ServerClientMixin(object):
def setUp(self): def setUpMixin(self, server_credentials, client_credentials, host_override):
self.server_completion_queue = cygrpc.CompletionQueue() self.server_completion_queue = cygrpc.CompletionQueue()
self.server = cygrpc.Server() self.server = cygrpc.Server()
self.server.register_completion_queue(self.server_completion_queue) self.server.register_completion_queue(self.server_completion_queue)
self.port = self.server.add_http2_port('[::]:0') if server_credentials:
self.port = self.server.add_http2_port('[::]:0', server_credentials)
else:
self.port = self.server.add_http2_port('[::]:0')
self.server.start() self.server.start()
self.client_completion_queue = cygrpc.CompletionQueue() self.client_completion_queue = cygrpc.CompletionQueue()
self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port)) if client_credentials:
client_channel_arguments = cygrpc.ChannelArgs([
def tearDown(self): cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
host_override)])
self.client_channel = cygrpc.Channel(
'localhost:{}'.format(self.port), client_channel_arguments,
client_credentials)
else:
self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port))
if host_override:
self.host_argument = None # default host
self.expected_host = host_override
else:
# arbitrary host name necessitating no further identification
self.host_argument = b'hostess'
self.expected_host = self.host_argument
def tearDownMixin(self):
del self.server del self.server
del self.client_completion_queue del self.client_completion_queue
del self.server_completion_queue del self.server_completion_queue
def _perform_operations(self, operations, call, queue, deadline, description):
"""Perform the list of operations with given call, queue, and deadline.
Invocation errors are reported with as an exception with `description` in
the message. Performs the operations asynchronously, returning a future.
"""
def performer():
tag = object()
try:
call_result = call.start_batch(cygrpc.Operations(operations), tag)
self.assertEqual(cygrpc.CallError.ok, call_result)
event = queue.poll(deadline)
self.assertEqual(cygrpc.CompletionType.operation_complete, event.type)
self.assertTrue(event.success)
self.assertIs(tag, event.tag)
except Exception as error:
raise Exception("Error in '{}': {}".format(description, error.message))
return event
return test_utilities.SimpleFuture(performer)
def testEcho(self): def testEcho(self):
DEADLINE = time.time()+5 DEADLINE = time.time()+5
DEADLINE_TOLERANCE = 0.25 DEADLINE_TOLERANCE = 0.25
...@@ -175,7 +213,6 @@ class InsecureServerInsecureClient(unittest.TestCase): ...@@ -175,7 +213,6 @@ class InsecureServerInsecureClient(unittest.TestCase):
REQUEST = b'in death a member of project mayhem has a name' REQUEST = b'in death a member of project mayhem has a name'
RESPONSE = b'his name is robert paulson' RESPONSE = b'his name is robert paulson'
METHOD = b'twinkies' METHOD = b'twinkies'
HOST = b'hostess'
cygrpc_deadline = cygrpc.Timespec(DEADLINE) cygrpc_deadline = cygrpc.Timespec(DEADLINE)
...@@ -188,7 +225,8 @@ class InsecureServerInsecureClient(unittest.TestCase): ...@@ -188,7 +225,8 @@ class InsecureServerInsecureClient(unittest.TestCase):
client_call_tag = object() client_call_tag = object()
client_call = self.client_channel.create_call( client_call = self.client_channel.create_call(
None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline) None, 0, self.client_completion_queue, METHOD, self.host_argument,
cygrpc_deadline)
client_initial_metadata = cygrpc.Metadata([ client_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY, cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE), CLIENT_METADATA_ASCII_VALUE),
...@@ -216,7 +254,8 @@ class InsecureServerInsecureClient(unittest.TestCase): ...@@ -216,7 +254,8 @@ class InsecureServerInsecureClient(unittest.TestCase):
test_common.metadata_transmitted(client_initial_metadata, test_common.metadata_transmitted(client_initial_metadata,
request_event.request_metadata)) request_event.request_metadata))
self.assertEqual(METHOD, request_event.request_call_details.method) self.assertEqual(METHOD, request_event.request_call_details.method)
self.assertEqual(HOST, request_event.request_call_details.host) self.assertEqual(self.expected_host,
request_event.request_call_details.host)
self.assertLess( self.assertLess(
abs(DEADLINE - float(request_event.request_call_details.deadline)), abs(DEADLINE - float(request_event.request_call_details.deadline)),
DEADLINE_TOLERANCE) DEADLINE_TOLERANCE)
...@@ -292,172 +331,101 @@ class InsecureServerInsecureClient(unittest.TestCase): ...@@ -292,172 +331,101 @@ class InsecureServerInsecureClient(unittest.TestCase):
del client_call del client_call
del server_call del server_call
def test6522(self):
class SecureServerSecureClient(unittest.TestCase):
def setUp(self):
server_credentials = cygrpc.server_credentials_ssl(
None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
resources.certificate_chain())], False)
channel_credentials = cygrpc.channel_credentials_ssl(
resources.test_root_certificates(), None)
self.server_completion_queue = cygrpc.CompletionQueue()
self.server = cygrpc.Server()
self.server.register_completion_queue(self.server_completion_queue)
self.port = self.server.add_http2_port('[::]:0', server_credentials)
self.server.start()
self.client_completion_queue = cygrpc.CompletionQueue()
client_channel_arguments = cygrpc.ChannelArgs([
cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
_SSL_HOST_OVERRIDE)])
self.client_channel = cygrpc.Channel(
'localhost:{}'.format(self.port), client_channel_arguments,
channel_credentials)
def tearDown(self):
del self.server
del self.client_completion_queue
del self.server_completion_queue
def testEcho(self):
DEADLINE = time.time()+5 DEADLINE = time.time()+5
DEADLINE_TOLERANCE = 0.25 DEADLINE_TOLERANCE = 0.25
CLIENT_METADATA_ASCII_KEY = b'key' METHOD = b'twinkies'
CLIENT_METADATA_ASCII_VALUE = b'val'
CLIENT_METADATA_BIN_KEY = b'key-bin'
CLIENT_METADATA_BIN_VALUE = b'\0'*1000
SERVER_INITIAL_METADATA_KEY = b'init_me_me_me'
SERVER_INITIAL_METADATA_VALUE = b'whodawha?'
SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought'
SERVER_TRAILING_METADATA_VALUE = b'zomg it is'
SERVER_STATUS_CODE = cygrpc.StatusCode.ok
SERVER_STATUS_DETAILS = b'our work is never over'
REQUEST = b'in death a member of project mayhem has a name'
RESPONSE = b'his name is robert paulson'
METHOD = b'/twinkies'
HOST = None # Default host
cygrpc_deadline = cygrpc.Timespec(DEADLINE) cygrpc_deadline = cygrpc.Timespec(DEADLINE)
empty_metadata = cygrpc.Metadata([])
server_request_tag = object() server_request_tag = object()
request_call_result = self.server.request_call( self.server.request_call(
self.server_completion_queue, self.server_completion_queue, self.server_completion_queue, self.server_completion_queue,
server_request_tag) server_request_tag)
client_call = self.client_channel.create_call(
None, 0, self.client_completion_queue, METHOD, self.host_argument,
cygrpc_deadline)
self.assertEqual(cygrpc.CallError.ok, request_call_result) # Prologue
def perform_client_operations(operations, description):
plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '') return self._perform_operations(
call_credentials = cygrpc.call_credentials_metadata_plugin(plugin) operations, client_call,
self.client_completion_queue, cygrpc_deadline, description)
client_call_tag = object() client_event_future = perform_client_operations([
client_call = self.client_channel.create_call( cygrpc.operation_send_initial_metadata(empty_metadata,
None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline) _EMPTY_FLAGS),
client_call.set_credentials(call_credentials) cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
client_initial_metadata = cygrpc.Metadata([ ], "Client prologue")
cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE),
cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
client_start_batch_result = client_call.start_batch(cygrpc.Operations([
cygrpc.operation_send_initial_metadata(client_initial_metadata,
_EMPTY_FLAGS),
cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
]), client_call_tag)
self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
client_event_future = test_utilities.CompletionQueuePollFuture(
self.client_completion_queue, cygrpc_deadline)
request_event = self.server_completion_queue.poll(cygrpc_deadline) request_event = self.server_completion_queue.poll(cygrpc_deadline)
self.assertEqual(cygrpc.CompletionType.operation_complete,
request_event.type)
self.assertIsInstance(request_event.operation_call, cygrpc.Call)
self.assertIs(server_request_tag, request_event.tag)
self.assertEqual(0, len(request_event.batch_operations))
client_metadata_with_credentials = list(client_initial_metadata) + [
(_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE)]
self.assertTrue(
test_common.metadata_transmitted(client_metadata_with_credentials,
request_event.request_metadata))
self.assertEqual(METHOD, request_event.request_call_details.method)
self.assertEqual(_SSL_HOST_OVERRIDE,
request_event.request_call_details.host)
self.assertLess(
abs(DEADLINE - float(request_event.request_call_details.deadline)),
DEADLINE_TOLERANCE)
server_call_tag = object()
server_call = request_event.operation_call server_call = request_event.operation_call
server_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY,
SERVER_INITIAL_METADATA_VALUE)])
server_trailing_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
SERVER_TRAILING_METADATA_VALUE)])
server_start_batch_result = server_call.start_batch([
cygrpc.operation_send_initial_metadata(server_initial_metadata,
_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_send_message(RESPONSE, _EMPTY_FLAGS),
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
server_trailing_metadata, SERVER_STATUS_CODE,
SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
], server_call_tag)
self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)
client_event = client_event_future.result() def perform_server_operations(operations, description):
server_event = self.server_completion_queue.poll(cygrpc_deadline) return self._perform_operations(
operations, server_call,
self.server_completion_queue, cygrpc_deadline, description)
self.assertEqual(6, len(client_event.batch_operations)) server_event_future = perform_server_operations([
found_client_op_types = set() cygrpc.operation_send_initial_metadata(empty_metadata,
for client_result in client_event.batch_operations: _EMPTY_FLAGS),
# we expect each op type to be unique ], "Server prologue")
self.assertNotIn(client_result.type, found_client_op_types)
found_client_op_types.add(client_result.type)
if client_result.type == cygrpc.OperationType.receive_initial_metadata:
self.assertTrue(
test_common.metadata_transmitted(server_initial_metadata,
client_result.received_metadata))
elif client_result.type == cygrpc.OperationType.receive_message:
self.assertEqual(RESPONSE, client_result.received_message.bytes())
elif client_result.type == cygrpc.OperationType.receive_status_on_client:
self.assertTrue(
test_common.metadata_transmitted(server_trailing_metadata,
client_result.received_metadata))
self.assertEqual(SERVER_STATUS_DETAILS,
client_result.received_status_details)
self.assertEqual(SERVER_STATUS_CODE, client_result.received_status_code)
self.assertEqual(set([
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.send_message,
cygrpc.OperationType.send_close_from_client,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.receive_status_on_client
]), found_client_op_types)
self.assertEqual(5, len(server_event.batch_operations)) client_event_future.result() # force completion
found_server_op_types = set() server_event_future.result()
for server_result in server_event.batch_operations:
self.assertNotIn(client_result.type, found_server_op_types)
found_server_op_types.add(server_result.type)
if server_result.type == cygrpc.OperationType.receive_message:
self.assertEqual(REQUEST, server_result.received_message.bytes())
elif server_result.type == cygrpc.OperationType.receive_close_on_server:
self.assertFalse(server_result.received_cancelled)
self.assertEqual(set([
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.send_message,
cygrpc.OperationType.receive_close_on_server,
cygrpc.OperationType.send_status_from_server
]), found_server_op_types)
del client_call # Messaging
del server_call for _ in range(10):
client_event_future = perform_client_operations([
cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
], "Client message")
server_event_future = perform_server_operations([
cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
], "Server receive")
client_event_future.result() # force completion
server_event_future.result()
# Epilogue
client_event_future = perform_client_operations([
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
], "Client epilogue")
server_event_future = perform_server_operations([
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
], "Server epilogue")
client_event_future.result() # force completion
server_event_future.result()
class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
def setUp(self):
self.setUpMixin(None, None, None)
def tearDown(self):
self.tearDownMixin()
class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
def setUp(self):
server_credentials = cygrpc.server_credentials_ssl(
None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
resources.certificate_chain())], False)
client_credentials = cygrpc.channel_credentials_ssl(
resources.test_root_certificates(), None)
self.setUpMixin(server_credentials, client_credentials, _SSL_HOST_OVERRIDE)
def tearDown(self):
self.tearDownMixin()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -32,15 +32,35 @@ import threading ...@@ -32,15 +32,35 @@ import threading
from grpc._cython import cygrpc from grpc._cython import cygrpc
class CompletionQueuePollFuture: class SimpleFuture(object):
"""A simple future mechanism."""
def __init__(self, completion_queue, deadline): def __init__(self, function, *args, **kwargs):
def poller_function(): def wrapped_function():
self._event_result = completion_queue.poll(deadline) try:
self._event_result = None self._result = function(*args, **kwargs)
self._thread = threading.Thread(target=poller_function) except Exception as error:
self._error = error
self._result = None
self._error = None
self._thread = threading.Thread(target=wrapped_function)
self._thread.start() self._thread.start()
def result(self): def result(self):
"""The resulting value of this future.
Re-raises any exceptions.
"""
self._thread.join() self._thread.join()
return self._event_result if self._error:
# TODO(atash): re-raise exceptions in a way that preserves tracebacks
raise self._error
return self._result
class CompletionQueuePollFuture(SimpleFuture):
def __init__(self, completion_queue, deadline):
super(CompletionQueuePollFuture, self).__init__(
lambda: completion_queue.poll(deadline))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment