Skip to content

Commit

Permalink
feat: optimize for single case (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Dec 9, 2023
1 parent a77a7f3 commit c7d72f3
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/aiohappyeyeballs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__version__ = "0.2.0"

from .impl import create_connection
from .impl import AddrInfoType, create_connection

__all__ = ("create_connection",)
__all__ = ("create_connection", "AddrInfoType")
10 changes: 5 additions & 5 deletions src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
import itertools
import socket
from asyncio import staggered
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union

AddrInfoType = Tuple[
int, int, int, str, Union[Tuple[str, int], Tuple[str, int, int, int]]
]


async def create_connection(
addr_infos: List[AddrInfoType],
addr_infos: Sequence[AddrInfoType],
*,
local_addr_infos: Optional[List[AddrInfoType]] = None,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
happy_eyeballs_delay: Optional[float] = None,
interleave: Optional[int] = None,
all_errors: bool = False,
Expand Down Expand Up @@ -97,7 +97,7 @@ async def _connect_sock(
loop: asyncio.AbstractEventLoop,
exceptions: List[List[Exception]],
addr_info: AddrInfoType,
local_addr_infos: Optional[List[AddrInfoType]] = None,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
) -> socket.socket:
"""Create, bind and connect one socket."""
my_exceptions: list[Exception] = []
Expand Down Expand Up @@ -144,7 +144,7 @@ async def _connect_sock(


def _interleave_addrinfos(
addrinfos: List[AddrInfoType], first_address_family_count: int = 1
addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1
) -> List[AddrInfoType]:
"""Interleave list of addrinfo tuples by family."""
# Group addresses by family
Expand Down
58 changes: 58 additions & 0 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import socket
from test.test_asyncio import utils as test_utils
from types import ModuleType
from unittest import mock

import pytest

from aiohappyeyeballs import create_connection

MOCK_ANY = mock.ANY


def mock_socket_module():
m_socket = mock.MagicMock(spec=socket)
for name in (
"AF_INET",
"AF_INET6",
"AF_UNSPEC",
"IPPROTO_TCP",
"IPPROTO_UDP",
"SOCK_STREAM",
"SOCK_DGRAM",
"SOL_SOCKET",
"SO_REUSEADDR",
"inet_pton",
):
if hasattr(socket, name):
setattr(m_socket, name, getattr(socket, name))
else:
delattr(m_socket, name)

m_socket.socket = mock.MagicMock()
m_socket.socket.return_value = test_utils.mock_nonblocking_socket()

return m_socket


def patch_socket(f):
return mock.patch("aiohappyeyeballs.impl.socket", new_callable=mock_socket_module)(
f
)


@pytest.mark.asyncio
@patch_socket
async def test_create_connection_single_addr_info_errors(m_socket: ModuleType) -> None:
idx = -1
errors = ["err1", "err2"]

def _socket(*args, **kw):
nonlocal idx, errors
idx += 1
raise OSError(errors[idx])

m_socket.socket = _socket # type: ignore
addr_info = [(2, 1, 6, "", ("107.6.106.82", 80))]
with pytest.raises(OSError, match=errors[0]):
await create_connection(addr_info)

0 comments on commit c7d72f3

Please sign in to comment.