Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: re-raise RuntimeError when uvloop raises RuntimeError during connect #105

Merged
merged 3 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import itertools
import socket
import sys
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Union

from . import staggered
from .types import AddrInfoType
Expand Down Expand Up @@ -73,7 +73,8 @@ async def start_connection(
addr_infos = _interleave_addrinfos(addr_infos, interleave)

sock: Optional[socket.socket] = None
exceptions: List[List[OSError]] = []
# uvloop can raise RuntimeError instead of OSError
exceptions: List[List[Union[OSError, RuntimeError]]] = []
if happy_eyeballs_delay is None or single_addr_info:
# not using happy eyeballs
for addrinfo in addr_infos:
Expand All @@ -82,7 +83,7 @@ async def start_connection(
current_loop, exceptions, addrinfo, local_addr_infos
)
break
except OSError:
except (RuntimeError, OSError):
continue
else: # using happy eyeballs
sock, _, _ = await staggered.staggered_race(
Expand Down Expand Up @@ -113,12 +114,20 @@ async def start_connection(
)
# If the errno is the same for all exceptions, raise
# an OSError with that errno.
first_errno = first_exception.errno
if all(
isinstance(exc, OSError) and exc.errno == first_errno
for exc in all_exceptions
if isinstance(first_exception, OSError):
first_errno = first_exception.errno
if all(
isinstance(exc, OSError) and exc.errno == first_errno
for exc in all_exceptions
):
raise OSError(first_errno, msg)
elif isinstance(first_exception, RuntimeError) and all(
isinstance(exc, RuntimeError) for exc in all_exceptions
):
raise OSError(first_errno, msg)
raise RuntimeError(msg)
# We have a mix of OSError and RuntimeError
# so we have to pick which one to raise.
# and we raise OSError for compatibility
raise OSError(msg)
finally:
all_exceptions = None # type: ignore[assignment]
Expand All @@ -129,12 +138,12 @@ async def start_connection(

async def _connect_sock(
loop: asyncio.AbstractEventLoop,
exceptions: List[List[OSError]],
exceptions: List[List[Union[OSError, RuntimeError]]],
addr_info: AddrInfoType,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
) -> socket.socket:
"""Create, bind and connect one socket."""
my_exceptions: list[OSError] = []
my_exceptions: List[Union[OSError, RuntimeError]] = []
exceptions.append(my_exceptions)
family, type_, proto, _, address = addr_info
sock = None
Expand Down Expand Up @@ -164,7 +173,7 @@ async def _connect_sock(
raise OSError(f"no matching local address with {family=} found")
await loop.sock_connect(sock, address)
return sock
except OSError as exc:
except (RuntimeError, OSError) as exc:
my_exceptions.append(exc)
if sock is not None:
sock.close()
Expand Down
283 changes: 283 additions & 0 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,289 @@ async def _sock_connect(
]


@patch_socket
@pytest.mark.asyncio
async def test_uvloop_runtime_error(
m_socket: ModuleType,
) -> None:
"""
Test RuntimeError is handled when connecting a socket with uvloop.

Connecting a socket can raise a RuntimeError, OSError or ValueError.

- OSError: If the address is invalid or the connection fails.
- ValueError: if a non-sock it passed (this should never happen).
https://github.com/python/cpython/blob/e44eebfc1eccdaaebc219accbfc705c9a9de068d/Lib/asyncio/selector_events.py#L271
- RuntimeError: If the file descriptor is already in use by a transport.

We should never get ValueError since we are using the correct types.

selector_events.py never seems to raise a RuntimeError, but it is possible
with uvloop. This test is to ensure that we handle it correctly.
"""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)
create_calls = []

def _socket(*args, **kw):
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

async def _sock_connect(
sock: socket.socket, address: Tuple[str, int, int, int]
) -> None:
create_calls.append(address)
raise RuntimeError("all fail")

m_socket.socket = _socket # type: ignore
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
local_addr_infos = [
(
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("::1", 0, 0, 0),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("127.0.0.1", 0),
),
]
loop = asyncio.get_running_loop()
# We should get the same exception raised if they are all the same
with mock.patch.object(loop, "sock_connect", _sock_connect), pytest.raises(
RuntimeError, match="all fail"
):
assert (
await start_connection(
addr_info,
happy_eyeballs_delay=0.3,
interleave=2,
local_addr_infos=local_addr_infos,
)
== mock_socket
)

# All calls failed
assert create_calls == [
("dead:beef::", 80, 0, 0),
("dead:aaaa::", 80, 0, 0),
("107.6.106.83", 80),
]


@patch_socket
@pytest.mark.asyncio
async def test_uvloop_different_runtime_error(
m_socket: ModuleType,
) -> None:
"""Test different RuntimeErrors are handled when connecting a socket with uvloop."""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)
create_calls = []
counter = 0

def _socket(*args, **kw):
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

async def _sock_connect(
sock: socket.socket, address: Tuple[str, int, int, int]
) -> None:
create_calls.append(address)
nonlocal counter
counter += 1
raise RuntimeError(counter)

m_socket.socket = _socket # type: ignore
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
local_addr_infos = [
(
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("::1", 0, 0, 0),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("127.0.0.1", 0),
),
]
loop = asyncio.get_running_loop()
# We should get the same exception raised if they are all the same
with mock.patch.object(loop, "sock_connect", _sock_connect), pytest.raises(
RuntimeError, match="Multiple exceptions: 1, 2, 3"
):
assert (
await start_connection(
addr_info,
happy_eyeballs_delay=0.3,
interleave=2,
local_addr_infos=local_addr_infos,
)
== mock_socket
)

# All calls failed
assert create_calls == [
("dead:beef::", 80, 0, 0),
("dead:aaaa::", 80, 0, 0),
("107.6.106.83", 80),
]


@patch_socket
@pytest.mark.asyncio
async def test_uvloop_mixing_os_and_runtime_error(
m_socket: ModuleType,
) -> None:
"""Test uvloop raising OSError and RuntimeError."""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)
create_calls = []
counter = 0

def _socket(*args, **kw):
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

async def _sock_connect(
sock: socket.socket, address: Tuple[str, int, int, int]
) -> None:
create_calls.append(address)
nonlocal counter
counter += 1
if counter == 1:
raise RuntimeError(counter)
raise OSError(counter, f"all fail {counter}")

m_socket.socket = _socket # type: ignore
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
local_addr_infos = [
(
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("::1", 0, 0, 0),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("127.0.0.1", 0),
),
]
loop = asyncio.get_running_loop()
# We should get the same exception raised if they are all the same
with mock.patch.object(loop, "sock_connect", _sock_connect), pytest.raises(
OSError, match="Multiple exceptions: 1"
):
assert (
await start_connection(
addr_info,
happy_eyeballs_delay=0.3,
interleave=2,
local_addr_infos=local_addr_infos,
)
== mock_socket
)

# All calls failed
assert create_calls == [
("dead:beef::", 80, 0, 0),
("dead:aaaa::", 80, 0, 0),
("107.6.106.83", 80),
]


@patch_socket
@pytest.mark.asyncio
@pytest.mark.xfail(reason="raises RuntimeError: coroutine ignored GeneratorExit")
Expand Down
Loading