Skip to content

Commit

Permalink
fix: preserve errno if all exceptions have the same errno (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Aug 7, 2024
1 parent 56e7ba6 commit 7bbb2bf
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 12 deletions.
28 changes: 18 additions & 10 deletions src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def start_connection(
addr_infos = _interleave_addrinfos(addr_infos, interleave)

sock: Optional[socket.socket] = None
exceptions: List[List[Exception]] = []
exceptions: List[List[OSError]] = []
if happy_eyeballs_delay is None or single_addr_info:
# not using happy eyeballs
for addrinfo in addr_infos:
Expand All @@ -99,20 +99,28 @@ async def start_connection(
if sock is None:
all_exceptions = [exc for sub in exceptions for exc in sub]
try:
first_exception = all_exceptions[0]
if len(all_exceptions) == 1:
raise all_exceptions[0]
raise first_exception
else:
# If they all have the same str(), raise one.
model = str(all_exceptions[0])
model = str(first_exception)
if all(str(exc) == model for exc in all_exceptions):
raise all_exceptions[0]
raise first_exception
# Raise a combined exception so the user can see all
# the various error messages.
raise OSError(
"Multiple exceptions: {}".format(
", ".join(str(exc) for exc in all_exceptions)
)
msg = "Multiple exceptions: {}".format(
", ".join(str(exc) for exc in all_exceptions)
)
# If the errno is the same for all exceptions, raise
# an OSError with that errno.
first_errno = first_exception.errno
if all(
isinstance(exc, OSError) and exc.errno == first_errno
for exc in all_exceptions
):
raise OSError(first_errno, msg)
raise OSError(msg)
finally:
all_exceptions = None # type: ignore[assignment]
exceptions = None # type: ignore[assignment]
Expand All @@ -122,12 +130,12 @@ async def start_connection(

async def _connect_sock(
loop: asyncio.AbstractEventLoop,
exceptions: List[List[Exception]],
exceptions: List[List[OSError]],
addr_info: AddrInfoType,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
) -> socket.socket:
"""Create, bind and connect one socket."""
my_exceptions: list[Exception] = []
my_exceptions: list[OSError] = []
exceptions.append(my_exceptions)
family, type_, proto, _, address = addr_info
sock = None
Expand Down
97 changes: 95 additions & 2 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,9 +1191,10 @@ async def _sock_connect(

@patch_socket
@pytest.mark.asyncio
async def test_all_same_exception(
async def test_all_same_exception_and_same_errno(
m_socket: ModuleType,
) -> None:
"""Test that all exceptions are the same and have the same errno."""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
Expand Down Expand Up @@ -1256,7 +1257,96 @@ async def _sock_connect(
# We should get the same exception raised if they are all the same
with mock.patch.object(loop, "sock_connect", _sock_connect), pytest.raises(
OSError, match="all fail"
):
) as exc_info:
assert (
await start_connection(
addr_info,
happy_eyeballs_delay=0.3,
interleave=2,
local_addr_infos=local_addr_infos,
)
== mock_socket
)

assert exc_info.value.errno == 5

# All calls failed
assert create_calls == [
("dead:beef::", 80, 0, 0),
("dead:aaaa::", 80, 0, 0),
("107.6.106.83", 80),
]


@patch_socket
@pytest.mark.asyncio
async def test_all_same_exception_and_with_different_errno(
m_socket: ModuleType,
) -> None:
"""Test no errno is set if all OSError have different errno."""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)
create_calls = []

def _socket(*args, **kw):
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

async def _sock_connect(
sock: socket.socket, address: Tuple[str, int, int, int]
) -> None:
create_calls.append(address)
raise OSError(len(create_calls), "all fail")

m_socket.socket = _socket # type: ignore
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
local_addr_infos = [
(
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("::1", 0, 0, 0),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("127.0.0.1", 0),
),
]
loop = asyncio.get_running_loop()
# We should get the same exception raised if they are all the same
with mock.patch.object(loop, "sock_connect", _sock_connect), pytest.raises(
OSError, match="all fail"
) as exc_info:
assert (
await start_connection(
addr_info,
Expand All @@ -1267,6 +1357,9 @@ async def _sock_connect(
== mock_socket
)

# No errno is set if they are all different
assert exc_info.value.errno is None

# All calls failed
assert create_calls == [
("dead:beef::", 80, 0, 0),
Expand Down

0 comments on commit 7bbb2bf

Please sign in to comment.