Skip to content

Commit

Permalink
Add local_address= kwarg to open_tcp_stream
Browse files Browse the repository at this point in the history
  • Loading branch information
njsmith committed Jun 24, 2020
1 parent a3ef2e2 commit 18d4288
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 5 deletions.
3 changes: 3 additions & 0 deletions newsfragments/275.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
`trio.open_tcp_stream` has a new ``local_address=`` keyword argument
that can be used on machines with multiple IP addresses to control
which IP is used for the outgoing connection.
70 changes: 65 additions & 5 deletions trio/_highlevel_open_tcp_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,7 @@ def format_host_port(host, port):
# AF_INET6: "..."}
# this might be simpler after
async def open_tcp_stream(
host,
port,
*,
# No trailing comma b/c bpo-9232 (fixed in py36)
happy_eyeballs_delay=DEFAULT_DELAY,
host, port, *, happy_eyeballs_delay=DEFAULT_DELAY, local_address=None,
):
"""Connect to the given host and port over TCP.
Expand Down Expand Up @@ -205,13 +201,30 @@ async def open_tcp_stream(
Args:
host (str or bytes): The host to connect to. Can be an IPv4 address,
IPv6 address, or a hostname.
port (int): The port to connect to.
happy_eyeballs_delay (float): How many seconds to wait for each
connection attempt to succeed or fail before getting impatient and
starting another one in parallel. Set to `math.inf` if you want
to limit to only one connection attempt at a time (like
:func:`socket.create_connection`). Default: 0.25 (250 ms).
local_address (None or str): The local IP address or hostname to use as
the source for outgoing connections. If ``None``, we let the OS pick
the source IP.
This is useful in some exotic networking configurations where your
host has multiple IP addresses, and you want to force the use of a
specific one.
Note that if you pass an IPv4 ``local_address``, then you won't be
able to connect to IPv6 hosts, and vice-versa. If you want to take
advantage of this to force the use of IPv4 or IPv6 without
specifying an exact source address, you can use the IPv4 wildcard
address ``local_address="0.0.0.0"``, or the IPv6 wildcard address
``local_address="::"``.
Returns:
SocketStream: a :class:`~trio.abc.Stream` connected to the given server.
Expand Down Expand Up @@ -269,6 +282,53 @@ async def attempt_connect(socket_args, sockaddr, attempt_failed):
sock = socket(*socket_args)
open_sockets.add(sock)

if local_address is not None:
# TCP connections are identified by a 4-tuple:
#
# (local IP, local port, remote IP, remote port)
#
# So if a single local IP wants to make multiple connections
# to the same (remote IP, remote port) pair, then those
# connections have to use different local ports, or else TCP
# won't be able to tell them apart. OTOH, if you have multiple
# connections to different remote IP/ports, then those
# connections can share a local port.
#
# Normally, when you call bind(), the kernel will immediately
# assign a specific local port to your socket. At this point
# the kernel doesn't know which (remote IP, remote port)
# you're going to use, so it has to pick a local port that
# *no* other connection is using. That's the only way to
# guarantee that this local port will be usable later when we
# call connect(). (Alternatively, you can set SO_REUSEADDR to
# allow multiple nascent connections to share the same port,
# but then connect() might fail with EADDRNOTAVAIL if we get
# unlucky and our TCP 4-tuple ends up colliding with another
# unrelated connection.)
#
# So calling bind() before connect() works, but it disables
# sharing of local ports. This is inefficient: it makes you
# more likely to run out of local ports.
#
# But on some versions of Linux, we can re-enable sharing of
# local ports by setting a special flag. This flag tells
# bind() to only bind the IP, and not the port. That way,
# connect() is allowed to pick the the port, and it can do a
# better job of it because it knows the remote IP/port.
try:
sock.setsockopt(
trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT, 1
)
except (OSError, NameError):
pass
try:
await sock.bind((local_address, 0))
except OSError:
raise OSError(
f"local_address={local_address!r} is incompatible "
f"with remote address {sockaddr}"
)

await sock.connect(sockaddr)

# Success! Save the winning socket and cancel all outstanding
Expand Down
6 changes: 6 additions & 0 deletions trio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,9 @@
TCP_NOTSENT_LOWAT = 0x201
elif _sys.platform == "linux":
TCP_NOTSENT_LOWAT = 25

# If nothing earlier in this module has already defined
# IP_BIND_ADDRESS_NO_PORT, fall back to the hard-coded value 24 on Linux.
# On other platforms the name is deliberately left undefined.
if "IP_BIND_ADDRESS_NO_PORT" not in globals() and _sys.platform == "linux":
    IP_BIND_ADDRESS_NO_PORT = 24
53 changes: 53 additions & 0 deletions trio/tests/test_highlevel_open_tcp_stream.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest
import sys
import socket

import attr

Expand Down Expand Up @@ -112,6 +114,57 @@ async def test_open_tcp_stream_input_validation():
await open_tcp_stream("127.0.0.1", b"80")


def can_bind_127_0_0_2():
    """Return True if a socket can be bound to the loopback alias 127.0.0.2.

    Returns False when ``bind()`` fails with OSError, or when the socket
    ends up reporting a local IP other than 127.0.0.2.
    """
    probe_ip = "127.0.0.2"
    with socket.socket() as probe:
        try:
            # Port 0: any free port will do, only the IP matters here.
            probe.bind((probe_ip, 0))
        except OSError:
            return False
        else:
            bound_ip = probe.getsockname()[0]
            return bound_ip == probe_ip


async def test_local_address_real():
    """Check ``open_tcp_stream(local_address=...)`` against a real listener.

    Covers three cases: an explicit source IP, a wildcard of the wrong
    address family (must fail), and a wildcard of the right family.
    """
    with trio.socket.socket() as listener:
        await listener.bind(("127.0.0.1", 0))
        listener.listen()
        host, port = listener.getsockname()

        # Testing local_address= properly requires more than one local
        # address we can bind to. On most Linux systems every 127.*.*.*
        # address is bindable and routed through loopback, so a
        # non-standard loopback address gives a meaningful check. Elsewhere
        # the only address we can count on is 127.0.0.1, where passing
        # local_address= is indistinguishable from omitting it -- but it
        # still works as a smoke test that the code path doesn't crash.
        source_ip = "127.0.0.2" if can_bind_127_0_0_2() else "127.0.0.1"

        async with await open_tcp_stream(
            host, port, local_address=source_ip
        ) as client_stream:
            assert client_stream.socket.getsockname()[0] == source_ip
            accepted, peer = await listener.accept()
            await client_stream.aclose()
            accepted.close()
            assert peer[0] == source_ip

        # An IPv6 wildcard source address cannot reach an IPv4 listener,
        # so this connection attempt has to fail.
        with pytest.raises(OSError):
            await open_tcp_stream(host, port, local_address="::")

        # The IPv4 wildcard, on the other hand, should connect fine.
        async with await open_tcp_stream(
            host, port, local_address="0.0.0.0"
        ) as client_stream:
            accepted, peer = await listener.accept()
            accepted.close()
            assert peer == client_stream.socket.getsockname()


# Now, thorough tests using fake sockets


Expand Down

0 comments on commit 18d4288

Please sign in to comment.