From ef210a893c919af7309d0a93e8d69dba9b4501bf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 30 Sep 2024 14:02:15 -0500 Subject: [PATCH] fix: handle uvloop raising RuntimeError when connecting a socket fixes #93 --- src/aiohappyeyeballs/impl.py | 26 ++++----- tests/test_impl.py | 101 +++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 12 deletions(-) diff --git a/src/aiohappyeyeballs/impl.py b/src/aiohappyeyeballs/impl.py index 1017e82..f1efeda 100644 --- a/src/aiohappyeyeballs/impl.py +++ b/src/aiohappyeyeballs/impl.py @@ -6,7 +6,7 @@ import itertools import socket import sys -from typing import List, Optional, Sequence +from typing import List, Optional, Sequence, Union from . import staggered from .types import AddrInfoType @@ -73,7 +73,8 @@ async def start_connection( addr_infos = _interleave_addrinfos(addr_infos, interleave) sock: Optional[socket.socket] = None - exceptions: List[List[OSError]] = [] + # uvloop can raise RuntimeError instead of OSError + exceptions: List[List[Union[OSError, RuntimeError]]] = [] if happy_eyeballs_delay is None or single_addr_info: # not using happy eyeballs for addrinfo in addr_infos: @@ -82,7 +83,7 @@ async def start_connection( current_loop, exceptions, addrinfo, local_addr_infos ) break - except OSError: + except (RuntimeError, OSError): continue else: # using happy eyeballs sock, _, _ = await staggered.staggered_race( @@ -113,12 +114,13 @@ async def start_connection( ) # If the errno is the same for all exceptions, raise # an OSError with that errno. - first_errno = first_exception.errno - if all( - isinstance(exc, OSError) and exc.errno == first_errno - for exc in all_exceptions - ): - raise OSError(first_errno, msg) + if isinstance(first_exception, OSError): + first_errno = first_exception.errno + if all( + isinstance(exc, OSError) and exc.errno == first_errno + for exc in all_exceptions + ): + raise OSError(first_errno, msg) raise OSError(msg) finally: all_exceptions = None # type: ignore[assignment] @@ -129,12 +131,12 @@ async def start_connection( async def _connect_sock( loop: asyncio.AbstractEventLoop, - exceptions: List[List[OSError]], + exceptions: List[List[Union[OSError, RuntimeError]]], addr_info: AddrInfoType, local_addr_infos: Optional[Sequence[AddrInfoType]] = None, ) -> socket.socket: """Create, bind and connect one socket.""" - my_exceptions: list[OSError] = [] + my_exceptions: List[Union[OSError, RuntimeError]] = [] exceptions.append(my_exceptions) family, type_, proto, _, address = addr_info sock = None @@ -164,7 +166,7 @@ async def _connect_sock( raise OSError(f"no matching local address with {family=} found") await loop.sock_connect(sock, address) return sock - except OSError as exc: + except (RuntimeError, OSError) as exc: my_exceptions.append(exc) if sock is not None: sock.close() diff --git a/tests/test_impl.py b/tests/test_impl.py index 2eb384f..81af241 100644 --- a/tests/test_impl.py +++ b/tests/test_impl.py @@ -1368,6 +1368,107 @@ async def _sock_connect( ] +@patch_socket +@pytest.mark.asyncio +async def test_handling_uvloop_runtime_error( + m_socket: ModuleType, +) -> None: + """ + Test RuntimeError is handled when connecting a socket with uvloop. + + Connecting a socket can raise a RuntimeError, OSError or ValueError. + + - OSError: If the address is invalid or the connection fails. + - ValueError: if a non-sock it passed (this should never happen). + https://github.com/python/cpython/blob/e44eebfc1eccdaaebc219accbfc705c9a9de068d/Lib/asyncio/selector_events.py#L271 + - RuntimeError: If the file descriptor is already in use by a transport. + + We should never get ValueError since we are using the correct types. + + selector_events.py never seems to raise a RuntimeError, but it is possible + with uvloop. This test is to ensure that we handle it correctly. + """ + mock_socket = mock.MagicMock( + family=socket.AF_INET, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + fileno=mock.MagicMock(return_value=1), + ) + create_calls = [] + + def _socket(*args, **kw): + for attr in kw: + setattr(mock_socket, attr, kw[attr]) + return mock_socket + + async def _sock_connect( + sock: socket.socket, address: Tuple[str, int, int, int] + ) -> None: + create_calls.append(address) + raise RuntimeError("all fail") + + m_socket.socket = _socket # type: ignore + ipv6_addr_info = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:beef::", 80, 0, 0), + ) + ipv6_addr_info_2 = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:aaaa::", 80, 0, 0), + ) + ipv4_addr_info = ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("107.6.106.83", 80), + ) + addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info] + local_addr_infos = [ + ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("::1", 0, 0, 0), + ), + ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("127.0.0.1", 0), + ), + ] + loop = asyncio.get_running_loop() + # We should get the same exception raised if they are all the same + with mock.patch.object(loop, "sock_connect", _sock_connect), pytest.raises( + RuntimeError, match="all fail" + ): + assert ( + await start_connection( + addr_info, + happy_eyeballs_delay=0.3, + interleave=2, + local_addr_infos=local_addr_infos, + ) + == mock_socket + ) + + # All calls failed + assert create_calls == [ + ("dead:beef::", 80, 0, 0), + ("dead:aaaa::", 80, 0, 0), + ("107.6.106.83", 80), + ] + + @patch_socket @pytest.mark.asyncio @pytest.mark.xfail(reason="raises RuntimeError: coroutine ignored GeneratorExit")