Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: copy staggered from standard lib for python 3.12+ #95

Merged
merged 7 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions src/aiohappyeyeballs/_staggered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import asyncio
import contextlib
from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, TypeVar


class _Done(Exception):
pass


_T = TypeVar("_T")


async def staggered_race(
coro_fns: Iterable[Callable[[], Awaitable[_T]]], delay: Optional[float]
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
"""
Run coroutines with staggered start times and take the first to finish.

This method takes an iterable of coroutine functions. The first one is
started immediately. From then on, whenever the immediately preceding one
fails (raises an exception), or when *delay* seconds has passed, the next
coroutine is started. This continues until one of the coroutines complete
successfully, in which case all others are cancelled, or until all
coroutines fail.

The coroutines provided should be well-behaved in the following way:

* They should only ``return`` if completed successfully.

* They should always raise an exception if they did not complete
successfully. In particular, if they handle cancellation, they should
probably reraise, like this::

try:
# do work
except asyncio.CancelledError:
# undo partially completed work
raise

Args:
coro_fns: an iterable of coroutine functions, i.e. callables that
return a coroutine object when called. Use ``functools.partial`` or
lambdas to pass arguments.

delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially.

Returns:
tuple *(winner_result, winner_index, exceptions)* where

- *winner_result*: the result of the winning coroutine, or ``None``
if no coroutines won.

- *winner_index*: the index of the winning coroutine in
``coro_fns``, or ``None`` if no coroutines won. If the winning
coroutine may return None on success, *winner_index* can be used
to definitively determine whether any coroutine won.

- *exceptions*: list of exceptions returned by the coroutines.
``len(exceptions)`` is equal to the number of coroutines actually
started, and the order is the same as in ``coro_fns``. The winning
coroutine's entry is ``None``.

"""
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
winner_result = None
winner_index = None
exceptions: List[Optional[BaseException]] = []

async def run_one_coro(
this_index: int,
coro_fn: Callable[[], Awaitable[_T]],
this_failed: asyncio.Event,
) -> None:
try:
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as e:
exceptions[this_index] = e
this_failed.set() # Kickstart the next coroutine
else:
# Store winner's results
nonlocal winner_index, winner_result
assert winner_index is None # noqa: S101
winner_index = this_index
winner_result = result
raise _Done

try:
async with asyncio.TaskGroup() as tg:
for this_index, coro_fn in enumerate(coro_fns):
this_failed = asyncio.Event()
exceptions.append(None)
tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
with contextlib.suppress(TimeoutError):
await asyncio.wait_for(this_failed.wait(), delay)
except* _Done:
pass

return winner_result, winner_index, exceptions
2 changes: 1 addition & 1 deletion src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import itertools
import socket
import sys
from asyncio import staggered
from typing import List, Optional, Sequence

from . import staggered
from .types import AddrInfoType

if sys.version_info < (3, 8, 2): # noqa: UP036
Expand Down
9 changes: 9 additions & 0 deletions src/aiohappyeyeballs/staggered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import sys

if sys.version_info > (3, 11):
# https://github.com/python/cpython/issues/124639#issuecomment-2378129834
from ._staggered import staggered_race
else:
from asyncio.staggered import staggered_race

Check warning on line 7 in src/aiohappyeyeballs/staggered.py

View check run for this annotation

Codecov / codecov/patch

src/aiohappyeyeballs/staggered.py#L7

Added line #L7 was not covered by tests

__all__ = ["staggered_race"]
82 changes: 82 additions & 0 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,88 @@ async def _sock_connect(
]


@patch_socket
@pytest.mark.asyncio
@pytest.mark.xfail(reason="raises RuntimeError: coroutine ignored GeneratorExit")
async def test_handling_system_exit(
m_socket: ModuleType,
) -> None:
"""Test handling SystemExit."""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)
create_calls = []

def _socket(*args, **kw):
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

async def _sock_connect(
sock: socket.socket, address: Tuple[str, int, int, int]
) -> None:
create_calls.append(address)
raise SystemExit

m_socket.socket = _socket # type: ignore
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
local_addr_infos = [
(
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("::1", 0, 0, 0),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("127.0.0.1", 0),
),
]
loop = asyncio.get_running_loop()
with pytest.raises(SystemExit), mock.patch.object(
loop, "sock_connect", _sock_connect
):
await start_connection(
addr_info,
happy_eyeballs_delay=0.3,
interleave=2,
local_addr_infos=local_addr_infos,
)

# Stopped after the first call
assert create_calls == [
("dead:beef::", 80, 0, 0),
]


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info >= (3, 8, 2), reason="requires < python 3.8.2")
def test_python_38_compat() -> None:
Expand Down
Loading