diff --git a/src/aiohappyeyeballs/utils.py b/src/aiohappyeyeballs/utils.py index 696eb48..d40c644 100644 --- a/src/aiohappyeyeballs/utils.py +++ b/src/aiohappyeyeballs/utils.py @@ -28,19 +28,19 @@ def pop_addr_infos_interleave(addr_infos: List[AddrInfoType], interleave: int) - def remove_addr_infos( addr_infos: List[AddrInfoType], - addr: str, + address: str, ) -> None: - """Pop addr_info from the list of addr_infos by addr.""" + """Remove an address from the list of addr_infos.""" bad_addrs_infos: List[AddrInfoType] = [] for addr_info in addr_infos: - if addr_info[-1][0] == addr: + if addr_info[-1][0] == address: bad_addrs_infos.append(addr_info) if bad_addrs_infos: for bad_addr_info in bad_addrs_infos: addr_infos.remove(bad_addr_info) return # Slow path in case addr is formatted differently - ip_address = ipaddress.ip_address(addr) + ip_address = ipaddress.ip_address(address) for addr_info in addr_infos: if ip_address == ipaddress.ip_address(addr_info[-1][0]): bad_addrs_infos.append(addr_info) @@ -48,4 +48,4 @@ def remove_addr_infos( for bad_addr_info in bad_addrs_infos: addr_infos.remove(bad_addr_info) return - raise ValueError(f"Address {addr} not found in addr_infos") + raise ValueError(f"Address {address} not found in addr_infos") diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..c1e1022 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,107 @@ +import socket +from typing import List + +import pytest + +from aiohappyeyeballs import AddrInfoType, pop_addr_infos_interleave, remove_addr_infos + + +def test_pop_addr_infos_interleave(): + """Test pop_addr_infos_interleave.""" + ipv6_addr_info = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:beef::", 80, 0, 0), + ) + ipv6_addr_info_2 = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:aaaa::", 80, 0, 0), + ) + ipv4_addr_info = ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("107.6.106.83", 80), + ) + addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info] + addr_info_copy = addr_info.copy() + pop_addr_infos_interleave(addr_info_copy, 1) + assert addr_info_copy == [ipv6_addr_info_2] + pop_addr_infos_interleave(addr_info_copy, 1) + assert addr_info_copy == [] + addr_info_copy = addr_info.copy() + pop_addr_infos_interleave(addr_info_copy, 2) + assert addr_info_copy == [] + + +def test_remove_addr_infos(): + """Test remove_addr_infos.""" + ipv6_addr_info = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:beef::", 80, 0, 0), + ) + ipv6_addr_info_2 = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:aaaa::", 80, 0, 0), + ) + ipv4_addr_info = ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("107.6.106.83", 80), + ) + addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info] + addr_info_copy = addr_info.copy() + remove_addr_infos(addr_info_copy, "dead:beef::") + assert addr_info_copy == [ipv6_addr_info_2, ipv4_addr_info] + remove_addr_infos(addr_info_copy, "dead:aaaa::") + assert addr_info_copy == [ipv4_addr_info] + remove_addr_infos(addr_info_copy, "107.6.106.83") + assert addr_info_copy == [] + + +def test_remove_addr_infos_slow_path(): + """Test remove_addr_infos with mis-matched formatting.""" + ipv6_addr_info = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:beef::", 80, 0, 0), + ) + ipv6_addr_info_2 = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:aaaa::", 80, 0, 0), + ) + ipv4_addr_info = ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("107.6.106.83", 80), + ) + addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info] + addr_info_copy = addr_info.copy() + remove_addr_infos(addr_info_copy, "dead:beef:0000:0000:0000:0000:0000:0000") + assert addr_info_copy == [ipv6_addr_info_2, ipv4_addr_info] + remove_addr_infos(addr_info_copy, "dead:aaaa:0000:0000:0000:0000:0000:0000") + assert addr_info_copy == [ipv4_addr_info] + with pytest.raises(ValueError, match="Address 107.6.106.2 not found in addr_infos"): + remove_addr_infos(addr_info_copy, "107.6.106.2") + assert addr_info_copy == [ipv4_addr_info]