From 18d428863b9d4ee427d133cafa7259115dd0c2a8 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Wed, 24 Jun 2020 10:52:14 -0700 Subject: [PATCH] Add local_address= kwarg to open_tcp_stream Fixes gh-275 --- newsfragments/275.feature.rst | 3 + trio/_highlevel_open_tcp_stream.py | 70 ++++++++++++++++++-- trio/socket.py | 6 ++ trio/tests/test_highlevel_open_tcp_stream.py | 53 +++++++++++++++ 4 files changed, 127 insertions(+), 5 deletions(-) create mode 100644 newsfragments/275.feature.rst diff --git a/newsfragments/275.feature.rst b/newsfragments/275.feature.rst new file mode 100644 index 0000000000..26b8ebb521 --- /dev/null +++ b/newsfragments/275.feature.rst @@ -0,0 +1,3 @@ +`trio.open_tcp_stream` has a new ``local_address=`` keyword argument +that can be used on machines with multiple IP addresses to control +which IP is used for the outgoing connection. diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index 99cf8bb1c3..27b1ed0672 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -165,11 +165,7 @@ def format_host_port(host, port): # AF_INET6: "..."} # this might be simpler after async def open_tcp_stream( - host, - port, - *, - # No trailing comma b/c bpo-9232 (fixed in py36) - happy_eyeballs_delay=DEFAULT_DELAY, + host, port, *, happy_eyeballs_delay=DEFAULT_DELAY, local_address=None, ): """Connect to the given host and port over TCP. @@ -205,13 +201,30 @@ async def open_tcp_stream( Args: host (str or bytes): The host to connect to. Can be an IPv4 address, IPv6 address, or a hostname. + port (int): The port to connect to. + happy_eyeballs_delay (float): How many seconds to wait for each connection attempt to succeed or fail before getting impatient and starting another one in parallel. Set to `math.inf` if you want to limit to only one connection attempt at a time (like :func:`socket.create_connection`). Default: 0.25 (250 ms). 
+ local_address (None or str): The local IP address or hostname to use as + the source for outgoing connections. If ``None``, we let the OS pick + the source IP. + + This is useful in some exotic networking configurations where your + host has multiple IP addresses, and you want to force the use of a + specific one. + + Note that if you pass an IPv4 ``local_address``, then you won't be + able to connect to IPv6 hosts, and vice-versa. If you want to take + advantage of this to force the use of IPv4 or IPv6 without + specifying an exact source address, you can use the IPv4 wildcard + address ``local_address="0.0.0.0"``, or the IPv6 wildcard address + ``local_address="::"``. + Returns: SocketStream: a :class:`~trio.abc.Stream` connected to the given server. @@ -269,6 +282,53 @@ async def attempt_connect(socket_args, sockaddr, attempt_failed): sock = socket(*socket_args) open_sockets.add(sock) + if local_address is not None: + # TCP connections are identified by a 4-tuple: + # + # (local IP, local port, remote IP, remote port) + # + # So if a single local IP wants to make multiple connections + # to the same (remote IP, remote port) pair, then those + # connections have to use different local ports, or else TCP + # won't be able to tell them apart. OTOH, if you have multiple + # connections to different remote IP/ports, then those + # connections can share a local port. + # + # Normally, when you call bind(), the kernel will immediately + # assign a specific local port to your socket. At this point + # the kernel doesn't know which (remote IP, remote port) + # you're going to use, so it has to pick a local port that + # *no* other connection is using. That's the only way to + # guarantee that this local port will be usable later when we + # call connect(). 
(Alternatively, you can set SO_REUSEADDR to + # allow multiple nascent connections to share the same port, + # but then connect() might fail with EADDRNOTAVAIL if we get + # unlucky and our TCP 4-tuple ends up colliding with another + # unrelated connection.) + # + # So calling bind() before connect() works, but it disables + # sharing of local ports. This is inefficient: it makes you + # more likely to run out of local ports. + # + # But on some versions of Linux, we can re-enable sharing of + # local ports by setting a special flag. This flag tells + # bind() to only bind the IP, and not the port. That way, + # connect() is allowed to pick the port, and it can do a + # better job of it because it knows the remote IP/port. + try: + sock.setsockopt( + trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT, 1 + ) + except (OSError, NameError): + pass + try: + await sock.bind((local_address, 0)) + except OSError: + raise OSError( + f"local_address={local_address!r} is incompatible " + f"with remote address {sockaddr}" + ) + await sock.connect(sockaddr) # Success! 
Save the winning socket and cancel all outstanding diff --git a/trio/socket.py b/trio/socket.py index ebbccd50ea..0f20d2f5d4 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -194,3 +194,9 @@ TCP_NOTSENT_LOWAT = 0x201 elif _sys.platform == "linux": TCP_NOTSENT_LOWAT = 25 + +try: + IP_BIND_ADDRESS_NO_PORT +except NameError: + if _sys.platform == "linux": + IP_BIND_ADDRESS_NO_PORT = 24 diff --git a/trio/tests/test_highlevel_open_tcp_stream.py b/trio/tests/test_highlevel_open_tcp_stream.py index 9fd0f3992a..c04a86d9c4 100644 --- a/trio/tests/test_highlevel_open_tcp_stream.py +++ b/trio/tests/test_highlevel_open_tcp_stream.py @@ -1,4 +1,6 @@ import pytest +import sys +import socket import attr @@ -112,6 +114,57 @@ async def test_open_tcp_stream_input_validation(): await open_tcp_stream("127.0.0.1", b"80") +def can_bind_127_0_0_2(): + with socket.socket() as s: + try: + s.bind(("127.0.0.2", 0)) + except OSError: + return False + return s.getsockname()[0] == "127.0.0.2" + + +async def test_local_address_real(): + with trio.socket.socket() as listener: + await listener.bind(("127.0.0.1", 0)) + listener.listen() + + # It's hard to test local_address properly, because you need multiple + # local addresses that you can bind to. Fortunately, on most Linux + # systems, you can bind to any 127.*.*.* address, and they all go + # through the loopback interface. So we can use a non-standard + # loopback address. On other systems, the only address we know for + # certain we have is 127.0.0.1, so we can't really test local_address= + # properly -- passing local_address=127.0.0.1 is indistinguishable + # from not passing local_address= at all. But, we can still do a smoke + # test to make sure the local_address= code doesn't crash. 
+ if can_bind_127_0_0_2(): + local_address = "127.0.0.2" + else: + local_address = "127.0.0.1" + + async with await open_tcp_stream( + *listener.getsockname(), local_address=local_address + ) as client_stream: + assert client_stream.socket.getsockname()[0] == local_address + server_sock, remote_addr = await listener.accept() + await client_stream.aclose() + server_sock.close() + assert remote_addr[0] == local_address + + # Trying to connect to an ipv4 address with the ipv6 wildcard + # local_address should fail + with pytest.raises(OSError): + await open_tcp_stream(*listener.getsockname(), local_address="::") + + # But the ipv4 wildcard address should work + async with await open_tcp_stream( + *listener.getsockname(), local_address="0.0.0.0" + ) as client_stream: + server_sock, remote_addr = await listener.accept() + server_sock.close() + assert remote_addr == client_stream.socket.getsockname() + + # Now, thorough tests using fake sockets