[WIP] Memory Streams #866

Closed · wants to merge 1 commit
4 changes: 3 additions & 1 deletion trio/__init__.py
@@ -42,6 +42,8 @@

from ._signals import catch_signals, open_signal_receiver

from ._highlevel_memory import memory_connect, MemoryStream, MemoryListener

from ._highlevel_socket import SocketStream, SocketListener

from ._file_io import open_file, wrap_file
@@ -54,7 +56,7 @@

from ._highlevel_serve_listeners import serve_listeners

-from ._highlevel_open_tcp_stream import open_tcp_stream
+from ._highlevel_open_tcp_stream import open_tcp_stream, format_host_port

from ._highlevel_open_tcp_listeners import open_tcp_listeners, serve_tcp

61 changes: 61 additions & 0 deletions trio/_highlevel_memory.py
@@ -0,0 +1,61 @@
# "High-level" in-memory networking interface

from . import _core, sleep, open_memory_channel, StapledStream
from ._sync import Event, Queue
from .abc import HalfCloseableStream, Listener

__all__ = ["memory_connect", "MemoryStream", "MemoryListener"]

from . import testing

# An efficient way to communicate between:
# - coroutines?  -> TODO: measure performance
# - event loops? -> TODO: measure performance
# - threads?     -> TODO: measure performance
# - processes?   -> TODO: measure performance
#: Mapping from endpoint name to a (send, receive) memory channel pair (one
#: per endpoint), used to hand client streams over once a connection has
#: been accepted.
memory_endpoints = {}


async def memory_connect(endpoint):
    # The listener may not have registered the endpoint yet, so poll the
    # dict until it shows up (re-checking on each iteration).
    rec_chan = memory_endpoints.get(endpoint)
    while rec_chan is None:
        await sleep(.1)
        rec_chan = memory_endpoints.get(endpoint)

    return await rec_chan[1].receive()

# A memory "stream" is just two unidirectional memory streams stapled together.
MemoryStream = StapledStream


################################################################
# MemoryListener
################################################################


class MemoryListener(Listener):

def __init__(self, endpoint, accept_hook=None):
self.accept_hook = accept_hook
self.accepted_streams = list()
self.endpoint = endpoint
memory_endpoints[self.endpoint] = open_memory_channel(1)
self.closed = False

async def accept(self):
await _core.checkpoint()
if self.closed:
raise _core.ClosedResourceError(self)
if self.accept_hook is not None:
await self.accept_hook()

client, server = testing.memory_stream_pair()
await memory_endpoints[self.endpoint][0].send(client)
self.accepted_streams.append(server)
return server

async def aclose(self):
self.closed = True
await _core.checkpoint()


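A minimal usage sketch (not part of the diff) showing how the pieces above fit together: the MemoryListener registers its endpoint name in memory_endpoints, accept() hands one half of a testing.memory_stream_pair() to a waiting memory_connect() call, and both sides then talk over ordinary in-memory streams. The endpoint name "demo" and the echo logic are illustrative only.

import trio
from trio._highlevel_memory import memory_connect, MemoryListener

async def main():
    listener = MemoryListener("demo")

    async def server():
        stream = await listener.accept()
        # Echo one message back to the client, then hang up.
        data = await stream.receive_some(1024)
        await stream.send_all(data)
        await stream.aclose()

    async with trio.open_nursery() as nursery:
        nursery.start_soon(server)
        client = await memory_connect("demo")
        await client.send_all(b"ping")
        assert await client.receive_some(1024) == b"ping"
        await client.aclose()

trio.run(main)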
16 changes: 15 additions & 1 deletion trio/_highlevel_open_tcp_listeners.py
@@ -4,6 +4,7 @@

import trio
from . import socket as tsocket
from ._highlevel_open_tcp_stream import format_host_port

__all__ = ["open_tcp_listeners", "serve_tcp"]

@@ -44,6 +45,15 @@ def _compute_backlog(backlog):
return min(backlog, 0xffff)


async def open_memory_listeners(endpoint):
    """Create a :class:`MemoryListener` for the given in-memory endpoint.

    Mirrors the API of :func:`open_tcp_listeners`.

    :param endpoint: hashable name identifying the in-memory endpoint
    :return: a one-element list of listeners
    """
    return [trio.MemoryListener(endpoint)]


async def open_tcp_listeners(port, *, host=None, backlog=None):
"""Create :class:`SocketListener` objects to listen for TCP connections.

@@ -162,6 +172,7 @@ async def serve_tcp(
*,
host=None,
backlog=None,
testing=False,
handler_nursery=None,
task_status=trio.TASK_STATUS_IGNORED
):
@@ -228,7 +239,10 @@ async def serve_tcp(
This function only returns when cancelled.

"""
-    listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog)
+    if testing:
+        listeners = await open_memory_listeners(format_host_port(host, port))
+    else:
+        listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog)
await trio.serve_listeners(
handler,
listeners,
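A hedged sketch of the testing flow wired up above: serve_tcp(..., testing=True) listens on an in-memory endpoint named by format_host_port(host, port), and open_tcp_stream() called with the same host/port and testing=True connects to that same name without touching any sockets. The handler and host/port values are illustrative; note that nursery.start() needs functools.partial to forward keyword arguments.

from functools import partial

import trio

async def handler(stream):
    # Echo one message, then hang up.
    data = await stream.receive_some(1024)
    await stream.send_all(data)
    await stream.aclose()

async def main():
    async with trio.open_nursery() as nursery:
        await nursery.start(
            partial(trio.serve_tcp, handler, 8000, host="127.0.0.1", testing=True)
        )
        stream = await trio.open_tcp_stream("127.0.0.1", 8000, testing=True)
        await stream.send_all(b"hello")
        assert await stream.receive_some(1024) == b"hello"
        await stream.aclose()
        nursery.cancel_scope.cancel()

trio.run(main)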
8 changes: 7 additions & 1 deletion trio/_highlevel_open_tcp_stream.py
@@ -2,8 +2,9 @@

import trio
from trio.socket import getaddrinfo, SOCK_STREAM, socket
from trio._highlevel_memory import memory_connect

__all__ = ["open_tcp_stream"]
__all__ = ["open_tcp_stream", "format_host_port"]

# Implementation of RFC 6555 "Happy eyeballs"
# https://tools.ietf.org/html/rfc6555
@@ -170,6 +171,7 @@ async def open_tcp_stream(
host,
port,
*,
    testing=False,  # might be nicer to implement as part of the protocol to connect (like zmq's inproc://)
# No trailing comma b/c bpo-9232 (fixed in py36)
happy_eyeballs_delay=DEFAULT_DELAY
):
@@ -236,6 +238,10 @@ async def open_tcp_stream(
if happy_eyeballs_delay is None:
happy_eyeballs_delay = DEFAULT_DELAY

    # Early return for the testing use case
    if testing:
        return await memory_connect(format_host_port(host, port))

targets = await getaddrinfo(host, port, type=SOCK_STREAM)

# I don't think this can actually happen -- if there are no results,
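For context, format_host_port() is a pre-existing helper in this module (not shown in the diff) that the memory path reuses as its rendezvous key; it renders "host:port", wrapping IPv6 hosts in brackets, roughly like this:

def format_host_port(host, port):
    host = host.decode("ascii") if isinstance(host, bytes) else host
    if ":" in host:
        return "[{}]:{}".format(host, port)
    else:
        return "{}:{}".format(host, port)

So the memory endpoint for host "127.0.0.1" and port 8000 is the string "127.0.0.1:8000".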
131 changes: 131 additions & 0 deletions trio/tests/test_highlevel_memory.py
@@ -0,0 +1,131 @@
import pytest

from .. import _core
from ..testing import (
check_half_closeable_stream, wait_all_tasks_blocked, assert_checkpoints
)
from .._highlevel_memory import *

from .. import testing


async def test_MemoryStream_send_all():
BIG = 10000000

a, b = testing.memory_stream_pair()

# Check a send_all that has to be split into multiple parts (on most
# platforms... on Windows every send() either succeeds or fails as a
# whole)
async def sender():
data = bytearray(BIG)
await a.send_all(data)
# send_all uses memoryviews internally, which temporarily "lock"
# the object they view. If it doesn't clean them up properly, then
# some bytearray operations might raise an error afterwards, which
# would be a pretty weird and annoying side-effect to spring on
# users. So test that this doesn't happen, by forcing the
# bytearray's underlying buffer to be realloc'ed:
data += bytes(BIG)
# (Note: the above line of code doesn't do a very good job at
# testing anything, because:
# - on CPython, the refcount GC generally cleans up memoryviews
# for us even if we're sloppy.
# - on PyPy3, at least as of 5.7.0, the memoryview code and the
# bytearray code conspire so that resizing never fails – if
# resizing forces the bytearray's internal buffer to move, then
# all memoryview references are automagically updated (!!).
# See:
# https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
# But I'm leaving the test here in hopes that if this ever changes
# and we break our implementation of send_all, then we'll get some
# early warning...)

async def receiver():
        # Let the sender finish (or block) first -- memory streams have no
        # kernel buffers to fill, unlike the socket version of this test
await wait_all_tasks_blocked()
nbytes = 0
while nbytes < BIG:
nbytes += len(await b.receive_some(BIG))
assert nbytes == BIG

async with _core.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(receiver)

# We know that we received BIG bytes of NULs so far. Make sure that
# was all the data in there.
await a.send_all(b"e")
assert await b.receive_some(10) == b"e"
await a.send_eof()
assert await b.receive_some(10) == b""


async def fill_stream(s):
async def sender():
while True:
await s.send_all(b"x" * 10000)

async def waiter(nursery):
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()

async with _core.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(waiter, nursery)


async def test_MemoryStream_generic():
async def stream_maker():
left, right = testing.memory_stream_pair()
return left, right

async def clogged_stream_maker():
left, right = await stream_maker()
await fill_stream(left)
await fill_stream(right)
return left, right

await check_half_closeable_stream(stream_maker, clogged_stream_maker)


async def test_MemoryListener():

async def listener(endpoint, nursery):
listener = MemoryListener(endpoint)
# Only wait for one client
with assert_checkpoints():
server_stream = await listener.accept()
assert isinstance(server_stream, MemoryStream)
# and closes
with assert_checkpoints():
await listener.aclose()

with assert_checkpoints():
await listener.aclose()

# Check that we cannot accept after closing
with assert_checkpoints():
with pytest.raises(_core.ClosedResourceError):
await listener.accept()

await server_stream.aclose()
nursery.cancel_scope.cancel()

async def client(endpoint):
client_stream = await memory_connect(endpoint)
# client disconnecting immediately
await client_stream.aclose()

async with _core.open_nursery() as nursery:
nursery.start_soon(client, "test_endpoint")
nursery.start_soon(listener, "test_endpoint", nursery)



async def test_memory_stream_works_when_peer_has_already_closed():
stream_a, stream_b = testing.memory_stream_pair()
await stream_b.send_all(b"x")
await stream_b.aclose()
assert await stream_a.receive_some(1) == b"x"
assert await stream_a.receive_some(1) == b""
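The tests above lean on trio.testing.memory_stream_pair(), which returns two connected StapledStreams with automatic pumping between them. A quick sanity sketch of the behavior the last test exercises (illustrative, not part of the diff):

import trio
import trio.testing

async def demo():
    a, b = trio.testing.memory_stream_pair()
    await a.send_all(b"hi")
    assert await b.receive_some(10) == b"hi"
    await a.send_eof()  # half-close: the peer sees EOF
    assert await b.receive_some(10) == b""

trio.run(demo)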