Skip to content

Commit

Permalink
Add open_ssl_over_tcp_listeners
Browse files Browse the repository at this point in the history
  • Loading branch information
njsmith committed Aug 14, 2017
1 parent 2c62a56 commit ba1b940
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 75 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference-io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ abstraction.

.. autofunction:: open_tcp_listeners

.. autofunction:: open_ssl_over_tcp_listeners


SSL / TLS support
~~~~~~~~~~~~~~~~~
Expand Down
30 changes: 29 additions & 1 deletion trio/_highlevel_ssl_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ._highlevel_open_tcp_stream import DEFAULT_DELAY

__all__ = ["open_ssl_over_tcp_stream"]
__all__ = ["open_ssl_over_tcp_stream", "open_ssl_over_tcp_listeners"]


# It might have been nice to take a ssl_protocols= argument here to set up
Expand Down Expand Up @@ -66,3 +66,31 @@ async def open_ssl_over_tcp_stream(
server_hostname=host,
https_compatible=https_compatible,
)


async def open_ssl_over_tcp_listeners(
port, ssl_context, *, host=None, https_compatible=False, backlog=None
):
"""Start listening for SSL/TLS-encrypted TCP connections to the given port.
Args:
port (int): The port to listen on. See :func:`open_tcp_listeners`.
ssl_context (~ssl.SSLContext): The SSL context to use for all incoming
connections.
host (str or None): The address to bind to; use ``None`` to bind to the
wildcard address. See :func:`open_tcp_listeners`.
https_compatible (bool): See :class:`~trio.ssl.SSLStream` for details.
backlog (int or None): See :class:`~trio.ssl.SSLStream` for details.
"""
tcp_listeners = await trio.open_tcp_listeners(
port, host=host, backlog=backlog
)
ssl_listeners = [
trio.ssl.SSLListener(
tcp_listener,
ssl_context,
https_compatible=https_compatible,
) for tcp_listener in tcp_listeners
]
return ssl_listeners
138 changes: 64 additions & 74 deletions trio/tests/test_highlevel_ssl_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,93 +3,62 @@
import attr

import trio
from trio.socket import AF_INET, SOCK_STREAM, IPPROTO_TCP
from trio.socket import socket, AF_INET, SOCK_STREAM, IPPROTO_TCP
import trio.testing
from .._util import acontextmanager
from .test_ssl import CLIENT_CTX, SERVER_CTX

from .._highlevel_ssl_helpers import open_ssl_over_tcp_stream

# this would be much simpler with a real fake network
# or just having trustme support for IP addresses so I could try connecting to
# 127.0.0.1

# Need to at least check making a successful connection, and making
# connections that fail CA and hostname validation.
#
# Also custom context and https_compatible I guess, though there isn't a whole
# lot that could go wrong here. Probably don't need to test
# happy_eyeballs_delay separately.


@attr.s
class FakeSocket:
stream = attr.ib()

async def connect(self, sockaddr):
pass

async def sendall(self, data):
await self.stream.send_all(data)

async def recv(self, max_bytes):
return await self.stream.receive_some(max_bytes)

def close(self):
# Memory{Send,Receive}Streams provide a synchronous close method just
# to support this case.
self.stream.receive_stream.close()
self.stream.send_stream.close()

# Stubs to make SocketStream happy:
def setsockopt(self, *args, **kwargs):
pass

def getpeername(self, *args):
pass

type = trio.socket.SOCK_STREAM
did_shutdown_SHUT_WR = False


# No matter who you connect to, you end up talking to an echo server with a
# cert for trio-test-1.example.com.
from .._highlevel_ssl_helpers import (
open_ssl_over_tcp_stream, open_ssl_over_tcp_listeners
)


# This is used by test_open_ssl_over_tcp_stream to test it, and it also uses
# open_ssl_over_tcp_listeners so it tests that too.
async def start_echo_server(nursery):

async def accept_loop(listener):
async with listener:
while True:
server_stream = await listener.accept()
nursery.spawn(echo_server, server_stream)

async def echo_server(server_stream):
async with server_stream:
try:
while True:
data = await server_stream.receive_some(10000)
if not data:
break
await server_stream.send_all(data)
except trio.BrokenStreamError:
pass

(listener,) = await trio.open_ssl_over_tcp_listeners(
0, SERVER_CTX, host="127.0.0.1"
)
nursery.spawn(accept_loop, listener)
return listener.transport_listener.socket.getsockname()


# Resolver that always returns the given sockaddr, no matter what host/port
# you ask for.
@attr.s
class FakeNetwork(trio.abc.HostnameResolver, trio.abc.SocketFactory):
nursery = attr.ib()
class FakeHostnameResolver(trio.abc.HostnameResolver):
sockaddr = attr.ib()

async def getaddrinfo(self, *args):
return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", ("1.1.1.1", 443))]
return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)]

async def getnameinfo(self, *args): # pragma: no cover
raise NotImplementedError

def is_trio_socket(self, obj):
return isinstance(obj, FakeSocket)

def socket(self, family, type, proto):
client_stream, server_stream = trio.testing.memory_stream_pair()
self.nursery.spawn(self.echo_server, server_stream)
return FakeSocket(client_stream)

async def echo_server(self, raw_server_stream):
ssl_server_stream = trio.ssl.SSLStream(
raw_server_stream,
SERVER_CTX,
server_side=True,
)
while True:
data = await ssl_server_stream.receive_some(10000)
if not data:
break
await ssl_server_stream.send_all(data)


async def test_open_ssl_over_tcp_stream():
async with trio.open_nursery() as nursery:
network = FakeNetwork(nursery)
trio.socket.set_custom_hostname_resolver(network)
trio.socket.set_custom_socket_factory(network)
sockaddr = await start_echo_server(nursery)
hostname_resolver = FakeHostnameResolver(sockaddr)
trio.socket.set_custom_hostname_resolver(hostname_resolver)

# We don't have the right trust set up
# (checks that ssl_context=None is doing some validation)
Expand All @@ -113,6 +82,8 @@ async def test_open_ssl_over_tcp_stream():
80,
ssl_context=CLIENT_CTX,
)
assert isinstance(stream, trio.ssl.SSLStream)
assert stream.server_hostname == "trio-test-1.example.org"
await stream.send_all(b"x")
assert await stream.receive_some(1) == b"x"
await stream.aclose()
Expand All @@ -129,5 +100,24 @@ async def test_open_ssl_over_tcp_stream():
)
assert stream._https_compatible

# We've left abandoned server tasks behind; clean them up.
# Stop the echo server
nursery.cancel_scope.cancel()


async def test_open_ssl_over_tcp_listeners():
(listener,) = await open_ssl_over_tcp_listeners(
0, SERVER_CTX, host="127.0.0.1"
)
async with listener:
assert isinstance(listener, trio.ssl.SSLListener)
tl = listener.transport_listener
assert isinstance(tl, trio.SocketListener)
assert tl.socket.getsockname()[0] == "127.0.0.1"

assert not listener._https_compatible

(listener,) = await open_ssl_over_tcp_listeners(
0, SERVER_CTX, host="127.0.0.1", https_compatible=True
)
async with listener:
assert listener._https_compatible

0 comments on commit ba1b940

Please sign in to comment.