diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 625c4dc682b..5214b60aac6 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -165,6 +165,17 @@ def __init__(self, resp, stream): self.resp = resp self.stream = stream + @asyncio.coroutine + def __aiter__(self): + return self + + @asyncio.coroutine + def __anext__(self): + part = yield from self.next() + if part is None: + raise StopAsyncIteration # NOQA + return part + def at_eof(self): """Returns ``True`` when all response data had been read. @@ -202,6 +213,17 @@ def __init__(self, boundary, headers, content): self._read_bytes = 0 self._unread = deque() + @asyncio.coroutine + def __aiter__(self): + return self + + @asyncio.coroutine + def __anext__(self): + part = yield from self.next() + if part is None: + raise StopAsyncIteration # NOQA + return part + @asyncio.coroutine def next(self): item = yield from self.read() @@ -430,6 +452,17 @@ def __init__(self, headers, content): self._at_eof = False self._unread = [] + @asyncio.coroutine + def __aiter__(self): + return self + + @asyncio.coroutine + def __anext__(self): + part = yield from self.next() + if part is None: + raise StopAsyncIteration # NOQA + return part + @classmethod def from_response(cls, response): """Constructs reader instance from HTTP response. diff --git a/tests/test_py35/test_multipart_35.py b/tests/test_py35/test_multipart_35.py new file mode 100644 index 00000000000..9d73b2eae48 --- /dev/null +++ b/tests/test_py35/test_multipart_35.py @@ -0,0 +1,68 @@ +import aiohttp +import aiohttp.hdrs as h +import io +import json +import pytest + + +class Stream(object): + + def __init__(self, content): + self.content = io.BytesIO(content) + + async def read(self, size=None): + return self.content.read(size) + + async def readline(self): + return self.content.readline() + + +@pytest.mark.run_loop +async def test_async_for_reader(loop): + data = [{"test": "passed"}, 42, b'plain text', b'aiohttp\n'] + reader = aiohttp.MultipartReader( + headers={h.CONTENT_TYPE: 'multipart/mixed; boundary=":"'}, + content=Stream(b'\r\n'.join([ + b'--:', + b'Content-Type: application/json', + b'', + json.dumps(data[0]).encode(), + b'--:', + b'Content-Type: application/json', + b'', + json.dumps(data[1]).encode(), + b'--:', + b'Content-Type: multipart/related; boundary="::"', + b'', + b'--::', + b'Content-Type: text/plain', + b'', + data[2], + b'--::', + b'Content-Disposition: attachment; filename="aiohttp"', + b'Content-Type: text/plain', + b'Content-Length: 28', + b'Content-Encoding: gzip', + b'', + b'\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03K\xcc\xcc\xcf())' + b'\xe0\x02\x00\xd6\x90\xe2O\x08\x00\x00\x00', + b'--::--', + b'--:--', + b'']))) + idata = iter(data) + async for part in reader: + if isinstance(part, aiohttp.BodyPartReader): + assert next(idata) == (await part.json()) + else: + async for subpart in part: + assert next(idata) == await subpart.read(decode=True) + + +@pytest.mark.run_loop +async def test_async_for_bodypart(loop): + part = aiohttp.BodyPartReader( + boundary=b'--:', + headers={}, + content=Stream(b'foobarbaz\r\n--:--')) + async for data in part: + assert data == b'foobarbaz'