Skip to content

Commit

Permalink
Fix __aiter__
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Sep 10, 2016
1 parent 34c3647 commit 70f8a1d
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 63 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ language: python
python:
- 3.4
- 3.5
- 3.5-dev # for 3.5.2
# - nightly

os:
Expand Down
6 changes: 4 additions & 2 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import sys

from ._ws_impl import CLOSED_MESSAGE, WebSocketError, WSMessage, WSMsgType
from .helpers import _decorate_aiter

PY_35 = sys.version_info >= (3, 5)
PY_352 = sys.version_info >= (3, 5, 2)


class ClientWebSocketResponse:
Expand Down Expand Up @@ -179,10 +179,12 @@ def receive_json(self, *, loads=json.loads):
return loads(data)

if PY_35:
@_decorate_aiter
def __aiter__(self):
return self

if not PY_352:
__aiter__ = asyncio.coroutine(__aiter__)

@asyncio.coroutine
def __anext__(self):
msg = yield from self.receive()
Expand Down
10 changes: 0 additions & 10 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import io
import os
import re
import sys
import warnings
from collections import namedtuple
from http.cookies import Morsel, SimpleCookie
Expand All @@ -32,8 +31,6 @@
'Timeout', 'CookieJar', 'ensure_future')


PY_352 = sys.version_info >= (3, 5, 2)

sentinel = object()


Expand Down Expand Up @@ -837,10 +834,3 @@ def _get_kwarg(kwargs, old, new, value):
return val
else:
return value


def _decorate_aiter(coro): # pragma: no cover
if PY_352:
return coro
else:
return asyncio.coroutine(coro)
69 changes: 41 additions & 28 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import mimetypes
import os
import re
import sys
import uuid
import warnings
import zlib
Expand All @@ -17,7 +18,7 @@

from .hdrs import (CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LENGTH,
CONTENT_TRANSFER_ENCODING, CONTENT_TYPE)
from .helpers import _decorate_aiter, parse_mimetype
from .helpers import parse_mimetype
from .protocol import HttpParser

__all__ = ('MultipartReader', 'MultipartWriter',
Expand All @@ -32,6 +33,9 @@
'?', '=', '{', '}', ' ', chr(9)}
TOKEN = CHAR ^ CTL ^ SEPARATORS

PY_35 = sys.version_info >= (3, 5)
PY_352 = sys.version_info >= (3, 5, 2)


class BadContentDispositionHeader(RuntimeWarning):
pass
Expand Down Expand Up @@ -161,16 +165,19 @@ def __init__(self, resp, stream):
self.resp = resp
self.stream = stream

@_decorate_aiter
def __aiter__(self):
return self
if PY_35:
def __aiter__(self):
return self

@asyncio.coroutine
def __anext__(self):
part = yield from self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part
if not PY_352:
__aiter__ = asyncio.coroutine(__aiter__)

@asyncio.coroutine
def __anext__(self):
part = yield from self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part

def at_eof(self):
"""Returns ``True`` when all response data had been read.
Expand Down Expand Up @@ -211,16 +218,19 @@ def __init__(self, boundary, headers, content):
self._prev_chunk = None
self._content_eof = 0

@_decorate_aiter
def __aiter__(self):
return self
if PY_35:
def __aiter__(self):
return self

@asyncio.coroutine
def __anext__(self):
part = yield from self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part
if not PY_352:
__aiter__ = asyncio.coroutine(__aiter__)

@asyncio.coroutine
def __anext__(self):
part = yield from self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part

@asyncio.coroutine
def next(self):
Expand Down Expand Up @@ -507,16 +517,19 @@ def __init__(self, headers, content):
self._at_bof = True
self._unread = []

@_decorate_aiter
def __aiter__(self):
return self
if PY_35:
def __aiter__(self):
return self

@asyncio.coroutine
def __anext__(self):
part = yield from self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part
if not PY_352:
__aiter__ = asyncio.coroutine(__aiter__)

@asyncio.coroutine
def __anext__(self):
part = yield from self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part

@classmethod
def from_response(cls, response):
Expand Down
42 changes: 25 additions & 17 deletions aiohttp/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
'FlowControlDataQueue', 'FlowControlChunksQueue')

PY_35 = sys.version_info >= (3, 5)
PY_352 = sys.version_info >= (3, 5, 2)

EOF_MARKER = b''
DEFAULT_LIMIT = 2 ** 16
Expand All @@ -22,33 +23,38 @@ class EofStream(Exception):
"""eof stream indication."""


class AsyncStreamIterator:
if PY_35:
class AsyncStreamIterator:

def __init__(self, read_func):
self.read_func = read_func
def __init__(self, read_func):
self.read_func = read_func

@helpers._decorate_aiter
def __aiter__(self):
return self
def __aiter__(self):
return self

@asyncio.coroutine
def __anext__(self):
try:
rv = yield from self.read_func()
except EofStream:
raise StopAsyncIteration # NOQA
if rv == EOF_MARKER:
raise StopAsyncIteration # NOQA
return rv
if not PY_352:
__aiter__ = asyncio.coroutine(__aiter__)

@asyncio.coroutine
def __anext__(self):
try:
rv = yield from self.read_func()
except EofStream:
raise StopAsyncIteration # NOQA
if rv == EOF_MARKER:
raise StopAsyncIteration # NOQA
return rv


class AsyncStreamReaderMixin:

if PY_35:
@helpers._decorate_aiter
def __aiter__(self):
return AsyncStreamIterator(self.readline)

if not PY_352:
__aiter__ = asyncio.coroutine(__aiter__)

def iter_chunked(self, n):
"""Returns an asynchronous iterator that yields chunks of size n.
Expand Down Expand Up @@ -470,10 +476,12 @@ def read(self):
raise EofStream

if PY_35:
@helpers._decorate_aiter
def __aiter__(self):
return AsyncStreamIterator(self.read)

if not PY_352:
__aiter__ = asyncio.coroutine(__aiter__)


class ChunksQueue(DataQueue):
"""Like a :class:`DataQueue`, but for binary chunked data transfer."""
Expand Down
6 changes: 4 additions & 2 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from ._ws_impl import (CLOSED_MESSAGE, WebSocketError, WSMessage, WSMsgType,
do_handshake)
from .errors import ClientDisconnectedError, HttpProcessingError
from .helpers import _decorate_aiter
from .web_exceptions import (HTTPBadRequest, HTTPInternalServerError,
HTTPMethodNotAllowed)
from .web_reqrep import StreamResponse

__all__ = ('WebSocketResponse',)

PY_35 = sys.version_info >= (3, 5)
PY_352 = sys.version_info >= (3, 5, 2)

THRESHOLD_CONNLOST_ACCESS = 5

Expand Down Expand Up @@ -302,10 +302,12 @@ def write(self, data):
raise RuntimeError("Cannot call .write() for websocket")

if PY_35:
@_decorate_aiter
def __aiter__(self):
return self

if not PY_352:
__aiter__ = asyncio.coroutine(__aiter__)

@asyncio.coroutine
def __anext__(self):
msg = yield from self.receive()
Expand Down
9 changes: 5 additions & 4 deletions tests/test_py35/test_streams_35.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ async def test_stream_reader_chunks_complete(loop):
(i.e. the data is divisible by the chunk size)
"""
chunk_iter = chunkify(DATA, 9)
async for line in create_stream(loop).iter_chunked(9):
assert line == next(chunk_iter, None)
async for data in create_stream(loop).iter_chunked(9):
assert data == next(chunk_iter, None)
pytest.raises(StopIteration, next, chunk_iter)


async def test_stream_reader_chunks_incomplete(loop):
"""Tests if chunked iteration works if the last chunk is incomplete"""
chunk_iter = chunkify(DATA, 8)
async for line in create_stream(loop).iter_chunked(8):
assert line == next(chunk_iter, None)
async for data in create_stream(loop).iter_chunked(8):
assert data == next(chunk_iter, None)
pytest.raises(StopIteration, next, chunk_iter)


Expand Down Expand Up @@ -72,6 +72,7 @@ async def test_stream_reader_iter_any(loop):
assert raw == next(it)
pytest.raises(StopIteration, next, it)


async def test_stream_reader_iter(loop):
it = iter([b'line1\n',
b'line2\n',
Expand Down

0 comments on commit 70f8a1d

Please sign in to comment.