Skip to content

Commit

Permalink
Add tracking signals for getting request/response bodies. (#2767)
Browse files Browse the repository at this point in the history
* Add tracking signals for getting request/response bodies.

* Revert automatic pep8 fix.

Mark pep8 rules E225 and E226 as ignored, to prevent
automatic changes in code formating.

* Remove internal usage of Signal in favor of simple callbacks.

* Document new signals

* Move callback to a private method.

* Make check more idiomatic

* Reorder classes in __all__

* Update request lifecycle diagram to include new signals

* Don't use mutable defaults for traces. Make it private in ClientRequest

* Further updates to tracing documentation

* Polish docs

* Revert ignoring pep8 rules

* Subtle optimisation - don't create list instance if not needed

* Remove assert statement

* Add test case ensuring StreamWriter calls callback

* Add test checking that response.read() trigger trace callback

* Add CHANGES record
  • Loading branch information
kowalski authored and asvetlov committed Mar 1, 2018
1 parent 5382822 commit d95ff20
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGES/2767.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add tracking signals for getting request/response bodies.
2 changes: 1 addition & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ async def _request(self, method, url, *,
response_class=self._response_class,
proxy=proxy, proxy_auth=proxy_auth, timer=timer,
session=self, auto_decompress=self._auto_decompress,
ssl=ssl, proxy_headers=proxy_headers)
ssl=ssl, proxy_headers=proxy_headers, traces=traces)

# connection timeout
try:
Expand Down
28 changes: 23 additions & 5 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def __init__(self, method, url, *,
proxy=None, proxy_auth=None,
timer=None, session=None, auto_decompress=True,
ssl=None,
proxy_headers=None):
proxy_headers=None,
traces=None):

if loop is None:
loop = asyncio.get_event_loop()
Expand Down Expand Up @@ -209,6 +210,9 @@ def __init__(self, method, url, *,
if data or self.method not in self.GET_METHODS:
self.update_transfer_encoding()
self.update_expect_continue(expect100)
if traces is None:
traces = []
self._traces = traces

def is_ssl(self):
return self.url.scheme in ('https', 'wss')
Expand Down Expand Up @@ -475,7 +479,10 @@ async def send(self, conn):
if self.url.raw_query_string:
path += '?' + self.url.raw_query_string

writer = StreamWriter(conn.protocol, conn.transport, self.loop)
writer = StreamWriter(
conn.protocol, conn.transport, self.loop,
on_chunk_sent=self._on_chunk_request_sent
)

if self.compress:
writer.enable_compression(self.compress)
Expand Down Expand Up @@ -513,8 +520,9 @@ async def send(self, conn):
self.method, self.original_url,
writer=self._writer, continue100=self._continue, timer=self._timer,
request_info=self.request_info,
auto_decompress=self._auto_decompress)

auto_decompress=self._auto_decompress,
traces=self._traces,
)
self.response._post_init(self.loop, self._session)
return self.response

Expand All @@ -531,6 +539,10 @@ def terminate(self):
self._writer.cancel()
self._writer = None

async def _on_chunk_request_sent(self, chunk):
for trace in self._traces:
await trace.send_request_chunk_sent(chunk)


class ClientResponse(HeadersMixin):

Expand All @@ -555,7 +567,8 @@ class ClientResponse(HeadersMixin):

def __init__(self, method, url, *,
writer=None, continue100=None, timer=None,
request_info=None, auto_decompress=True):
request_info=None, auto_decompress=True,
traces=None):
assert isinstance(url, URL)

self.method = method
Expand All @@ -572,6 +585,9 @@ def __init__(self, method, url, *,
self._timer = timer if timer is not None else TimerNoop()
self._auto_decompress = auto_decompress
self._cache = {} # reqired for @reify method decorator
if traces is None:
traces = []
self._traces = traces

@property
def url(self):
Expand Down Expand Up @@ -796,6 +812,8 @@ async def read(self):
if self._content is None:
try:
self._content = await self.content.read()
for trace in self._traces:
await trace.send_response_chunk_received(self._content)
except BaseException:
self.close()
raise
Expand Down
9 changes: 7 additions & 2 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class StreamWriter(AbstractStreamWriter):

def __init__(self, protocol, transport, loop):
def __init__(self, protocol, transport, loop, on_chunk_sent=None):
self._protocol = protocol
self._transport = transport

Expand All @@ -30,6 +30,8 @@ def __init__(self, protocol, transport, loop):
self._compress = None
self._drain_waiter = None

self._on_chunk_sent = on_chunk_sent

@property
def transport(self):
return self._transport
Expand All @@ -55,13 +57,16 @@ def _write(self, chunk):
raise asyncio.CancelledError('Cannot write to closing transport')
self._transport.write(chunk)

async def write(self, chunk, *, drain=True, LIMIT=64*1024):
async def write(self, chunk, *, drain=True, LIMIT=0x10000):
"""Writes chunk of data to a stream.
write_eof() indicates end of stream.
writer can't be used after write_eof() method being called.
write() return drain future.
"""
if self._on_chunk_sent is not None:
await self._on_chunk_sent(chunk)

if self._compress is not None:
chunk = self._compress.compress(chunk)
if not chunk:
Expand Down
41 changes: 40 additions & 1 deletion aiohttp/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
'TraceConnectionCreateEndParams', 'TraceConnectionReuseconnParams',
'TraceDnsResolveHostStartParams', 'TraceDnsResolveHostEndParams',
'TraceDnsCacheHitParams', 'TraceDnsCacheMissParams',
'TraceRequestRedirectParams'
'TraceRequestRedirectParams',
'TraceRequestChunkSentParams', 'TraceResponseChunkReceivedParams',
)


Expand All @@ -25,6 +26,8 @@ class TraceConfig:

def __init__(self, trace_config_ctx_factory=SimpleNamespace):
self._on_request_start = Signal(self)
self._on_request_chunk_sent = Signal(self)
self._on_response_chunk_received = Signal(self)
self._on_request_end = Signal(self)
self._on_request_exception = Signal(self)
self._on_request_redirect = Signal(self)
Expand All @@ -47,6 +50,8 @@ def trace_config_ctx(self, trace_request_ctx=None):

def freeze(self):
self._on_request_start.freeze()
self._on_request_chunk_sent.freeze()
self._on_response_chunk_received.freeze()
self._on_request_end.freeze()
self._on_request_exception.freeze()
self._on_request_redirect.freeze()
Expand All @@ -64,6 +69,14 @@ def freeze(self):
def on_request_start(self):
return self._on_request_start

@property
def on_request_chunk_sent(self):
return self._on_request_chunk_sent

@property
def on_response_chunk_received(self):
return self._on_response_chunk_received

@property
def on_request_end(self):
return self._on_request_end
Expand Down Expand Up @@ -121,6 +134,18 @@ class TraceRequestStartParams:
headers = attr.ib(type=CIMultiDict)


@attr.s(frozen=True, slots=True)
class TraceRequestChunkSentParams:
""" Parameters sent by the `on_request_chunk_sent` signal"""
chunk = attr.ib(type=bytes)


@attr.s(frozen=True, slots=True)
class TraceResponseChunkReceivedParams:
""" Parameters sent by the `on_response_chunk_received` signal"""
chunk = attr.ib(type=bytes)


@attr.s(frozen=True, slots=True)
class TraceRequestEndParams:
""" Parameters sent by the `on_request_end` signal"""
Expand Down Expand Up @@ -213,6 +238,20 @@ async def send_request_start(self, method, url, headers):
TraceRequestStartParams(method, url, headers)
)

async def send_request_chunk_sent(self, chunk):
return await self._trace_config.on_request_chunk_sent.send(
self._session,
self._trace_config_ctx,
TraceRequestChunkSentParams(chunk)
)

async def send_response_chunk_received(self, chunk):
return await self._trace_config.on_response_chunk_received.send(
self._session,
self._trace_config_ctx,
TraceResponseChunkReceivedParams(chunk)
)

async def send_request_end(self, method, url, headers, response):
return await self._trace_config.on_request_end.send(
self._session,
Expand Down
75 changes: 67 additions & 8 deletions docs/tracing_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,26 @@ Overview
exception[shape=flowchart.terminator, description="on_request_exception"];

acquire_connection[description="Connection acquiring"];
got_response;
send_request;
headers_received;
headers_sent;
chunk_sent[description="on_request_chunk_sent"];
chunk_received[description="on_response_chunk_received"];

start -> acquire_connection;
acquire_connection -> send_request;
send_request -> got_response;
got_response -> redirect;
got_response -> end;
redirect -> send_request;
send_request -> exception;
acquire_connection -> headers_sent;
headers_sent -> headers_received;
headers_sent -> chunk_sent;
chunk_sent -> chunk_sent;
chunk_sent -> headers_received;
headers_received -> chunk_received;
chunk_received -> chunk_received;
chunk_received -> end;
headers_received -> redirect;
headers_received -> end;
redirect -> headers_sent;
chunk_received -> exception;
chunk_sent -> exception;
headers_sent -> exception;

}

Expand Down Expand Up @@ -147,6 +157,26 @@ TraceConfig

``params`` is :class:`aiohttp.TraceRequestStartParams` instance.

.. attribute:: on_request_chunk_sent


Property that gives access to the signals that will be executed
when a chunk of request body is sent.

``params`` is :class:`aiohttp.TraceRequestChunkSentParams` instance.

.. versionadded:: 3.1

.. attribute:: on_response_chunk_received


Property that gives access to the signals that will be executed
when a chunk of response body is received.

``params`` is :class:`aiohttp.TraceResponseChunkReceivedParams` instance.

.. versionadded:: 3.1

.. attribute:: on_request_redirect

Property that gives access to the signals that will be executed when a
Expand Down Expand Up @@ -259,6 +289,35 @@ TraceRequestStartParams

Headers that will be used for the request, can be mutated.


TraceRequestChunkSentParams
---------------------------

.. class:: TraceRequestChunkSentParams

.. versionadded:: 3.1

See :attr:`TraceConfig.on_request_chunk_sent` for details.

.. attribute:: chunk

Bytes of chunk sent


TraceResponseChunkSentParams
----------------------------

.. class:: TraceResponseChunkSentParams

.. versionadded:: 3.1

See :attr:`TraceConfig.on_response_chunk_received` for details.

.. attribute:: chunk

Bytes of chunk received


TraceRequestEndParams
---------------------

Expand Down
33 changes: 33 additions & 0 deletions tests/test_client_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import aiohttp
from aiohttp import http
from aiohttp.client_reqrep import ClientResponse, RequestInfo
from aiohttp.test_utils import make_mocked_coro


@pytest.fixture
Expand Down Expand Up @@ -613,3 +614,35 @@ def test_redirect_history_in_exception():
with pytest.raises(aiohttp.ClientResponseError) as cm:
response.raise_for_status()
assert [hist_response] == cm.value.history


async def test_response_read_triggers_callback(loop, session):
trace = mock.Mock()
trace.send_response_chunk_received = make_mocked_coro()
response_body = b'This is response'

response = ClientResponse(
'get', URL('http://def-cl-resp.org'),
traces=[trace]
)
response._post_init(loop, session)

def side_effect(*args, **kwargs):
fut = loop.create_future()
fut.set_result(response_body)
return fut

response.headers = {
'Content-Type': 'application/json;charset=cp1251'}
content = response.content = mock.Mock()
content.read.side_effect = side_effect

res = await response.read()
assert res == response_body
assert response._connection is None

assert trace.send_response_chunk_received.called
assert (
trace.send_response_chunk_received.call_args ==
mock.call(response_body)
)
Loading

0 comments on commit d95ff20

Please sign in to comment.