Skip to content
Snippets Groups Projects
Commit 4350e748 authored by Makarand Dharmapurikar's avatar Makarand Dharmapurikar
Browse files

ability to deal with multiple streams in flight.

parent 28d19800
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ from twisted.internet.protocol import Protocol ...@@ -6,6 +6,7 @@ from twisted.internet.protocol import Protocol
from twisted.internet import reactor from twisted.internet import reactor
from h2.connection import H2Connection from h2.connection import H2Connection
from h2.events import RequestReceived, DataReceived, WindowUpdated, RemoteSettingsChanged, PingAcknowledged from h2.events import RequestReceived, DataReceived, WindowUpdated, RemoteSettingsChanged, PingAcknowledged
from h2.exceptions import ProtocolError
READ_CHUNK_SIZE = 16384 READ_CHUNK_SIZE = 16384
GRPC_HEADER_SIZE = 5 GRPC_HEADER_SIZE = 5
...@@ -13,7 +14,7 @@ GRPC_HEADER_SIZE = 5 ...@@ -13,7 +14,7 @@ GRPC_HEADER_SIZE = 5
class H2ProtocolBaseServer(Protocol): class H2ProtocolBaseServer(Protocol):
def __init__(self): def __init__(self):
self._conn = H2Connection(client_side=False) self._conn = H2Connection(client_side=False)
self._recv_buffer = '' self._recv_buffer = {}
self._handlers = {} self._handlers = {}
self._handlers['ConnectionMade'] = self.on_connection_made_default self._handlers['ConnectionMade'] = self.on_connection_made_default
self._handlers['DataReceived'] = self.on_data_received_default self._handlers['DataReceived'] = self.on_data_received_default
...@@ -23,6 +24,7 @@ class H2ProtocolBaseServer(Protocol): ...@@ -23,6 +24,7 @@ class H2ProtocolBaseServer(Protocol):
self._handlers['ConnectionLost'] = self.on_connection_lost self._handlers['ConnectionLost'] = self.on_connection_lost
self._handlers['PingAcknowledged'] = self.on_ping_acknowledged_default self._handlers['PingAcknowledged'] = self.on_ping_acknowledged_default
self._stream_status = {} self._stream_status = {}
self._send_remaining = {}
self._outstanding_pings = 0 self._outstanding_pings = 0
def set_handlers(self, handlers): def set_handlers(self, handlers):
...@@ -45,18 +47,23 @@ class H2ProtocolBaseServer(Protocol): ...@@ -45,18 +47,23 @@ class H2ProtocolBaseServer(Protocol):
reactor.callFromThread(reactor.stop) reactor.callFromThread(reactor.stop)
def dataReceived(self, data): def dataReceived(self, data):
events = self._conn.receive_data(data) try:
events = self._conn.receive_data(data)
except ProtocolError:
# this try/except block catches exceptions due to race between sending
# GOAWAY and processing a response in flight.
return
if self._conn.data_to_send: if self._conn.data_to_send:
self.transport.write(self._conn.data_to_send()) self.transport.write(self._conn.data_to_send())
for event in events: for event in events:
if isinstance(event, RequestReceived) and self._handlers.has_key('RequestReceived'): if isinstance(event, RequestReceived) and self._handlers.has_key('RequestReceived'):
logging.info('RequestReceived Event') logging.info('RequestReceived Event for stream: %d'%event.stream_id)
self._handlers['RequestReceived'](event) self._handlers['RequestReceived'](event)
elif isinstance(event, DataReceived) and self._handlers.has_key('DataReceived'): elif isinstance(event, DataReceived) and self._handlers.has_key('DataReceived'):
logging.info('DataReceived Event') logging.info('DataReceived Event for stream: %d'%event.stream_id)
self._handlers['DataReceived'](event) self._handlers['DataReceived'](event)
elif isinstance(event, WindowUpdated) and self._handlers.has_key('WindowUpdated'): elif isinstance(event, WindowUpdated) and self._handlers.has_key('WindowUpdated'):
logging.info('WindowUpdated Event') logging.info('WindowUpdated Event for stream: %d'%event.stream_id)
self._handlers['WindowUpdated'](event) self._handlers['WindowUpdated'](event)
elif isinstance(event, PingAcknowledged) and self._handlers.has_key('PingAcknowledged'): elif isinstance(event, PingAcknowledged) and self._handlers.has_key('PingAcknowledged'):
logging.info('PingAcknowledged Event') logging.info('PingAcknowledged Event')
...@@ -68,10 +75,10 @@ class H2ProtocolBaseServer(Protocol): ...@@ -68,10 +75,10 @@ class H2ProtocolBaseServer(Protocol):
def on_data_received_default(self, event): def on_data_received_default(self, event):
self._conn.acknowledge_received_data(len(event.data), event.stream_id) self._conn.acknowledge_received_data(len(event.data), event.stream_id)
self._recv_buffer += event.data self._recv_buffer[event.stream_id] += event.data
def on_request_received_default(self, event): def on_request_received_default(self, event):
self._recv_buffer = '' self._recv_buffer[event.stream_id] = ''
self._stream_id = event.stream_id self._stream_id = event.stream_id
self._stream_status[event.stream_id] = True self._stream_status[event.stream_id] = True
self._conn.send_headers( self._conn.send_headers(
...@@ -86,48 +93,57 @@ class H2ProtocolBaseServer(Protocol): ...@@ -86,48 +93,57 @@ class H2ProtocolBaseServer(Protocol):
self.transport.write(self._conn.data_to_send()) self.transport.write(self._conn.data_to_send())
def on_window_update_default(self, event): def on_window_update_default(self, event):
pass # send pending data, if any
self.default_send(event.stream_id)
def send_reset_stream(self): def send_reset_stream(self):
self._conn.reset_stream(self._stream_id) self._conn.reset_stream(self._stream_id)
self.transport.write(self._conn.data_to_send()) self.transport.write(self._conn.data_to_send())
def setup_send(self, data_to_send): def setup_send(self, data_to_send, stream_id):
self._send_remaining = len(data_to_send) logging.info('Setting up data to send for stream_id: %d'%stream_id)
self._send_remaining[stream_id] = len(data_to_send)
self._send_offset = 0 self._send_offset = 0
self._data_to_send = data_to_send self._data_to_send = data_to_send
self.default_send() self.default_send(stream_id)
def default_send(self): def default_send(self, stream_id):
while self._send_remaining > 0: if not self._send_remaining.has_key(stream_id):
lfcw = self._conn.local_flow_control_window(self._stream_id) # not setup to send data yet
return
while self._send_remaining[stream_id] > 0:
if self._stream_status[stream_id] is False:
logging.info('Stream %d is closed.'%stream_id)
break
lfcw = self._conn.local_flow_control_window(stream_id)
if lfcw == 0: if lfcw == 0:
break break
chunk_size = min(lfcw, READ_CHUNK_SIZE) chunk_size = min(lfcw, READ_CHUNK_SIZE)
bytes_to_send = min(chunk_size, self._send_remaining) bytes_to_send = min(chunk_size, self._send_remaining[stream_id])
logging.info('flow_control_window = %d. sending [%d:%d] stream_id %d'% logging.info('flow_control_window = %d. sending [%d:%d] stream_id %d'%
(lfcw, self._send_offset, self._send_offset + bytes_to_send, (lfcw, self._send_offset, self._send_offset + bytes_to_send,
self._stream_id)) stream_id))
data = self._data_to_send[self._send_offset : self._send_offset + bytes_to_send] data = self._data_to_send[self._send_offset : self._send_offset + bytes_to_send]
self._conn.send_data(self._stream_id, data, False) self._conn.send_data(stream_id, data, False)
self._send_remaining -= bytes_to_send self._send_remaining[stream_id] -= bytes_to_send
self._send_offset += bytes_to_send self._send_offset += bytes_to_send
if self._send_remaining == 0: if self._send_remaining[stream_id] == 0:
self._handlers['SendDone']() self._handlers['SendDone'](stream_id)
def default_ping(self): def default_ping(self):
self._outstanding_pings += 1 self._outstanding_pings += 1
self._conn.ping(b'\x00'*8) self._conn.ping(b'\x00'*8)
self.transport.write(self._conn.data_to_send()) self.transport.write(self._conn.data_to_send())
def on_send_done_default(self): def on_send_done_default(self, stream_id):
if self._stream_status[self._stream_id]: if self._stream_status[stream_id]:
self._stream_status[self._stream_id] = False self._stream_status[stream_id] = False
self.default_send_trailer() self.default_send_trailer(stream_id)
def default_send_trailer(self): def default_send_trailer(self, stream_id):
logging.info('Sending trailer for stream id %d'%self._stream_id) logging.info('Sending trailer for stream id %d'%stream_id)
self._conn.send_headers(self._stream_id, self._conn.send_headers(stream_id,
headers=[ ('grpc-status', '0') ], headers=[ ('grpc-status', '0') ],
end_stream=True end_stream=True
) )
...@@ -141,8 +157,8 @@ class H2ProtocolBaseServer(Protocol): ...@@ -141,8 +157,8 @@ class H2ProtocolBaseServer(Protocol):
response_data = b'\x00' + struct.pack('i', len(serialized_resp_proto))[::-1] + serialized_resp_proto response_data = b'\x00' + struct.pack('i', len(serialized_resp_proto))[::-1] + serialized_resp_proto
return response_data return response_data
@staticmethod def parse_received_data(self, stream_id):
def parse_received_data(recv_buffer): recv_buffer = self._recv_buffer[stream_id]
""" returns a grpc framed string of bytes containing response proto of the size """ returns a grpc framed string of bytes containing response proto of the size
asked in request """ asked in request """
grpc_msg_size = struct.unpack('i',recv_buffer[1:5][::-1])[0] grpc_msg_size = struct.unpack('i',recv_buffer[1:5][::-1])[0]
...@@ -152,5 +168,5 @@ class H2ProtocolBaseServer(Protocol): ...@@ -152,5 +168,5 @@ class H2ProtocolBaseServer(Protocol):
req_proto_str = recv_buffer[5:5+grpc_msg_size] req_proto_str = recv_buffer[5:5+grpc_msg_size]
sr = messages_pb2.SimpleRequest() sr = messages_pb2.SimpleRequest()
sr.ParseFromString(req_proto_str) sr.ParseFromString(req_proto_str)
logging.info('Parsed request: response_size=%s'%sr.response_size) logging.info('Parsed request for stream %d: response_size=%s'%(stream_id, sr.response_size))
return sr return sr
...@@ -12,7 +12,6 @@ class TestcaseGoaway(object): ...@@ -12,7 +12,6 @@ class TestcaseGoaway(object):
self._base_server = http2_base_server.H2ProtocolBaseServer() self._base_server = http2_base_server.H2ProtocolBaseServer()
self._base_server._handlers['RequestReceived'] = self.on_request_received self._base_server._handlers['RequestReceived'] = self.on_request_received
self._base_server._handlers['DataReceived'] = self.on_data_received self._base_server._handlers['DataReceived'] = self.on_data_received
self._base_server._handlers['WindowUpdated'] = self.on_window_update_default
self._base_server._handlers['SendDone'] = self.on_send_done self._base_server._handlers['SendDone'] = self.on_send_done
self._base_server._handlers['ConnectionLost'] = self.on_connection_lost self._base_server._handlers['ConnectionLost'] = self.on_connection_lost
self._ready_to_send = False self._ready_to_send = False
...@@ -27,11 +26,11 @@ class TestcaseGoaway(object): ...@@ -27,11 +26,11 @@ class TestcaseGoaway(object):
if self._iteration == 2: if self._iteration == 2:
self._base_server.on_connection_lost(reason) self._base_server.on_connection_lost(reason)
def on_send_done(self): def on_send_done(self, stream_id):
self._base_server.on_send_done_default() self._base_server.on_send_done_default(stream_id)
if self._base_server._stream_id == 1: logging.info('Sending GOAWAY for stream %d:'%stream_id)
logging.info('Sending GOAWAY for stream 1') self._base_server._conn.close_connection(error_code=0, additional_data=None, last_stream_id=stream_id)
self._base_server._conn.close_connection(error_code=0, additional_data=None, last_stream_id=1) self._base_server._stream_status[stream_id] = False
def on_request_received(self, event): def on_request_received(self, event):
self._ready_to_send = False self._ready_to_send = False
...@@ -39,13 +38,9 @@ class TestcaseGoaway(object): ...@@ -39,13 +38,9 @@ class TestcaseGoaway(object):
def on_data_received(self, event): def on_data_received(self, event):
self._base_server.on_data_received_default(event) self._base_server.on_data_received_default(event)
sr = self._base_server.parse_received_data(self._base_server._recv_buffer) sr = self._base_server.parse_received_data(event.stream_id)
if sr: if sr:
logging.info('Creating response size = %s'%sr.response_size) logging.info('Creating response size = %s'%sr.response_size)
response_data = self._base_server.default_response_data(sr.response_size) response_data = self._base_server.default_response_data(sr.response_size)
self._ready_to_send = True self._ready_to_send = True
self._base_server.setup_send(response_data) self._base_server.setup_send(response_data, event.stream_id)
def on_window_update_default(self, event):
if self._ready_to_send:
self._base_server.default_send()
...@@ -24,7 +24,8 @@ class TestcaseSettingsMaxStreams(object): ...@@ -24,7 +24,8 @@ class TestcaseSettingsMaxStreams(object):
def on_data_received(self, event): def on_data_received(self, event):
self._base_server.on_data_received_default(event) self._base_server.on_data_received_default(event)
sr = self._base_server.parse_received_data(self._base_server._recv_buffer) sr = self._base_server.parse_received_data(event.stream_id)
logging.info('Creating response size = %s'%sr.response_size) if sr:
response_data = self._base_server.default_response_data(sr.response_size) logging.info('Creating response of size = %s'%sr.response_size)
self._base_server.setup_send(response_data) response_data = self._base_server.default_response_data(sr.response_size)
self._base_server.setup_send(response_data, event.stream_id)
...@@ -23,12 +23,13 @@ class TestcasePing(object): ...@@ -23,12 +23,13 @@ class TestcasePing(object):
def on_data_received(self, event): def on_data_received(self, event):
self._base_server.on_data_received_default(event) self._base_server.on_data_received_default(event)
sr = self._base_server.parse_received_data(self._base_server._recv_buffer) sr = self._base_server.parse_received_data(event.stream_id)
logging.info('Creating response size = %s'%sr.response_size) if sr:
response_data = self._base_server.default_response_data(sr.response_size) logging.info('Creating response size = %s'%sr.response_size)
self._base_server.default_ping() response_data = self._base_server.default_response_data(sr.response_size)
self._base_server.setup_send(response_data) self._base_server.default_ping()
self._base_server.default_ping() self._base_server.setup_send(response_data, event.stream_id)
self._base_server.default_ping()
def on_connection_lost(self, reason): def on_connection_lost(self, reason):
logging.info('Disconnect received. Ping Count %d'%self._base_server._outstanding_pings) logging.info('Disconnect received. Ping Count %d'%self._base_server._outstanding_pings)
......
...@@ -14,10 +14,10 @@ class TestcaseRstStreamAfterData(object): ...@@ -14,10 +14,10 @@ class TestcaseRstStreamAfterData(object):
def on_data_received(self, event): def on_data_received(self, event):
self._base_server.on_data_received_default(event) self._base_server.on_data_received_default(event)
sr = self._base_server.parse_received_data(self._base_server._recv_buffer) sr = self._base_server.parse_received_data(event.stream_id)
assert(sr is not None) if sr:
response_data = self._base_server.default_response_data(sr.response_size) response_data = self._base_server.default_response_data(sr.response_size)
self._ready_to_send = True self._ready_to_send = True
self._base_server.setup_send(response_data) self._base_server.setup_send(response_data, event.stream_id)
# send reset stream # send reset stream
self._base_server.send_reset_stream() self._base_server.send_reset_stream()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment