diff --git a/README.rst b/README.rst index 6c56b4ee20..21ae7263a3 100644 --- a/README.rst +++ b/README.rst @@ -101,7 +101,7 @@ asking for help in our `chat room `__, `filing a bug `__, or `posting a question on StackOverflow -`__, and +`__, and we'll do our best to help you out. **Trio is awesome and I want to help make it more awesome!** You're diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index 24b0369899..aa558724ea 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -20,5 +20,5 @@

Need help? Try chat or StackOverflow.

-{% endblock %} \ No newline at end of file +href="https://stackoverflow.com/questions/ask?tags=python+python-trio">StackOverflow.

+{% endblock %} diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index e576dc4d3c..b0ea3d03fd 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -195,6 +195,8 @@ abstraction. .. autofunction:: serve_ssl_over_tcp +.. autofunction:: open_unix_socket + .. autoclass:: SocketStream :members: :undoc-members: diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index c39f7eaff9..77d81e24ef 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -108,8 +108,8 @@ If you get lost or confused... ...then we want to know! We have a friendly `chat channel `__, you can ask questions -`using the "trio" tag on StackOverflow -`__, or just +`using the "python-trio" tag on StackOverflow +`__, or just `file a bug `__ (if our documentation is confusing, that's our fault, and we want to fix it!). diff --git a/newsfragments/401.feature.rst b/newsfragments/401.feature.rst new file mode 100644 index 0000000000..8800af4eb9 --- /dev/null +++ b/newsfragments/401.feature.rst @@ -0,0 +1 @@ +Add unix client socket support. \ No newline at end of file diff --git a/trio/__init__.py b/trio/__init__.py index 6077c04ef7..1d10c58401 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -50,6 +50,9 @@ from ._highlevel_open_tcp_listeners import * __all__ += _highlevel_open_tcp_listeners.__all__ +from ._highlevel_open_unix_stream import * +__all__ += _highlevel_open_unix_stream.__all__ + from ._highlevel_ssl_helpers import * __all__ += _highlevel_ssl_helpers.__all__ diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py index 1faa295644..968299404c 100644 --- a/trio/_core/_ki.py +++ b/trio/_core/_ki.py @@ -176,8 +176,10 @@ def wrapper(*args, **kwargs): @contextmanager def ki_manager(deliver_cb, restrict_keyboard_interrupt_to_checkpoints): - if (threading.current_thread() != threading.main_thread() - or signal.getsignal(signal.SIGINT) != signal.default_int_handler): + if ( + threading.current_thread() != threading.main_thread() + or signal.getsignal(signal.SIGINT) != signal.default_int_handler + ): yield return diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 96179d3d54..ca1f0a9440 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -383,8 +383,10 @@ def _add_exc(self, exc): self.cancel_scope.cancel() def _check_nursery_closed(self): - if (not self._nested_child_running and not self._children - and not self._pending_starts): + if ( + not self._nested_child_running and not self._children + and not self._pending_starts + ): self._closed = True if self._parent_waiting_in_aexit: self._parent_waiting_in_aexit = False @@ -1477,8 +1479,10 @@ async def checkpoint_if_cancelled(): """ task = current_task() - if (task._pending_cancel_scope() is not None or - (task is task._runner.main_task and task._runner.ki_pending)): + if ( + task._pending_cancel_scope() is not None or + (task is task._runner.main_task and task._runner.ki_pending) + ): await _core.checkpoint() assert False # pragma: no cover task._cancel_points += 1 diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index bc6b5de2ef..00e972a1b4 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -1203,8 +1203,10 @@ def cb(x): for i in range(100): token.run_sync_soon(cb, i, idempotent=True) await wait_all_tasks_blocked() - if (sys.version_info < (3, 6) - and platform.python_implementation() == "CPython"): + if ( + sys.version_info < (3, 6) + and platform.python_implementation() == "CPython" + ): # no order guarantees record.sort() # Otherwise, we guarantee FIFO diff --git a/trio/_highlevel_open_unix_stream.py b/trio/_highlevel_open_unix_stream.py new file mode 100644 index 0000000000..522ddb7104 --- /dev/null +++ b/trio/_highlevel_open_unix_stream.py @@ -0,0 +1,42 @@ +import trio +from trio._highlevel_open_tcp_stream import close_on_error +from trio.socket import socket, SOCK_STREAM + +try: + from trio.socket import AF_UNIX + has_unix = True +except ImportError: + has_unix = False + +__all__ = ["open_unix_socket"] + + +async def open_unix_socket(filename,): + """Opens a connection to the specified + `Unix domain socket `__. + + You must have read/write permission on the specified file to connect. + + Args: + filename (str or bytes): The filename to open the connection to. + + Returns: + SocketStream: a :class:`~trio.abc.Stream` connected to the given file. + + Raises: + OSError: If the socket file could not be connected to. + RuntimeError: If AF_UNIX sockets are not supported. + """ + if not has_unix: + raise RuntimeError("Unix sockets are not supported on this platform") + + if filename is None: + raise ValueError("Filename cannot be None") + + # much more simplified logic vs tcp sockets - one socket type and only one + # possible location to connect to + sock = socket(AF_UNIX, SOCK_STREAM) + with close_on_error(sock): + await sock.connect(filename) + + return trio.SocketStream(sock) diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index 903f5ecabb..5d00e44c46 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -43,9 +43,9 @@ class SocketStream(HalfCloseableStream): socket: The trio socket object to wrap. Must have type ``SOCK_STREAM``, and be connected. - By default, :class:`SocketStream` enables ``TCP_NODELAY``, and (on - platforms where it's supported) enables ``TCP_NOTSENT_LOWAT`` with a - reasonable buffer size (currently 16 KiB) – see `issue #72 + By default for TCP sockets, :class:`SocketStream` enables ``TCP_NODELAY``, + and (on platforms where it's supported) enables ``TCP_NOTSENT_LOWAT`` with + a reasonable buffer size (currently 16 KiB) – see `issue #72 `__ for discussion. You can of course override these defaults by calling :meth:`setsockopt`. diff --git a/trio/_socket.py b/trio/_socket.py index d2b0a1a8c6..4ad58b1493 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -1,15 +1,14 @@ -from functools import wraps as _wraps, partial as _partial +import os as _os import socket as _stdlib_socket import sys as _sys -import os as _os -from contextlib import contextmanager as _contextmanager -import errno as _errno +from functools import wraps as _wraps import idna as _idna from . import _core from ._deprecate import deprecated from ._threads import run_sync_in_worker_thread +from ._util import fspath __all__ = [] @@ -462,8 +461,10 @@ def dup(self): async def bind(self, address): await _core.checkpoint() address = await self._resolve_local_address(address) - if (hasattr(_stdlib_socket, "AF_UNIX") and self.family == AF_UNIX - and address[0]): + if ( + hasattr(_stdlib_socket, "AF_UNIX") and self.family == AF_UNIX + and address[0] + ): # Use a thread for the filesystem traversal (unless it's an # abstract domain socket) return await run_sync_in_worker_thread(self._sock.bind, address) @@ -504,6 +505,11 @@ async def _resolve_address(self, address, flags): "address should be a (host, port, [flowinfo, [scopeid]]) " "tuple" ) + elif self._sock.family == AF_UNIX: + await _core.checkpoint() + # unwrap path-likes + return fspath(address) + else: await _core.checkpoint() return address diff --git a/trio/_ssl.py b/trio/_ssl.py index 1f89e880c9..8e1c337c2e 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -630,9 +630,11 @@ async def receive_some(self, max_bytes): # For some reason, EOF before handshake sometimes raises # SSLSyscallError instead of SSLEOFError (e.g. on my linux # laptop, but not on appveyor). Thanks openssl. - if (self._https_compatible - and isinstance(exc.__cause__, - (SSLEOFError, SSLSyscallError))): + if ( + self._https_compatible and + isinstance(exc.__cause__, + (SSLEOFError, SSLSyscallError)) + ): return b"" else: raise @@ -647,8 +649,10 @@ async def receive_some(self, max_bytes): # BROKEN. But that's actually fine, because after getting an # EOF on TLS then the only thing you can do is close the # stream, and closing doesn't care about the state. - if (self._https_compatible - and isinstance(exc.__cause__, SSLEOFError)): + if ( + self._https_compatible + and isinstance(exc.__cause__, SSLEOFError) + ): return b"" else: raise diff --git a/trio/_util.py b/trio/_util.py index 7cae6d0fe9..e7e02a6df1 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -3,6 +3,7 @@ import os import sys from functools import wraps +import typing as t import async_generator @@ -14,11 +15,8 @@ from . import _core __all__ = [ - "signal_raise", - "aiter_compat", - "acontextmanager", - "ConflictDetector", - "fixup_module_metadata", + "signal_raise", "aiter_compat", "acontextmanager", "ConflictDetector", + "fixup_module_metadata", "fspath" ] # Equivalent to the C function raise(), which Python doesn't wrap @@ -59,6 +57,7 @@ # - generating synthetic signals for tests # and for both of those purposes, 'raise' works fine. import cffi + _ffi = cffi.FFI() _ffi.cdef("int raise(int);") _lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll") @@ -133,8 +132,11 @@ async def __aexit__(self, type, value, traceback): # Likewise, avoid suppressing if a StopIteration exception # was passed to throw() and later wrapped into a RuntimeError # (see PEP 479). - if (isinstance(value, (StopIteration, StopAsyncIteration)) - and exc.__cause__ is value): + if ( + isinstance(value, + (StopIteration, StopAsyncIteration)) + and exc.__cause__ is value + ): return False raise except: @@ -253,3 +255,45 @@ def fix_one(obj): for objname in namespace["__all__"]: obj = namespace[objname] fix_one(obj) + + +# This is copied from PEP 519 as the implementation of os.fspath for +# Python 3.5. See: https://www.python.org/dev/peps/pep-0519/#os +# The input typehint is removed as there is no os.PathLike on 3.5. + + +def fspath(path) -> t.Union[str, bytes]: + """Return the string representation of the path. + + If str or bytes is passed in, it is returned unchanged. If __fspath__() + returns something other than str or bytes then TypeError is raised. If + this function is given something that is not str, bytes, or os.PathLike + then TypeError is raised. + """ + if isinstance(path, (str, bytes)): + return path + + # Work from the object's type to match method resolution of other magic + # methods. + path_type = type(path) + try: + path = path_type.__fspath__(path) + except AttributeError: + if hasattr(path_type, '__fspath__'): + raise + else: + if isinstance(path, (str, bytes)): + return path + else: + raise TypeError( + "expected __fspath__() to return str or bytes, " + "not " + type(path).__name__ + ) + + raise TypeError( + "expected str, bytes or os.PathLike object, not " + path_type.__name__ + ) + + +if hasattr(os, "fspath"): + fspath = os.fspath diff --git a/trio/testing/_checkpoints.py b/trio/testing/_checkpoints.py index e878d560e1..dbf14566e9 100644 --- a/trio/testing/_checkpoints.py +++ b/trio/testing/_checkpoints.py @@ -15,11 +15,19 @@ def _assert_yields_or_not(expected): try: yield finally: - if (expected and (task._cancel_points == orig_cancel - or task._schedule_points == orig_schedule)): + if ( + expected and ( + task._cancel_points == orig_cancel + or task._schedule_points == orig_schedule + ) + ): raise AssertionError("assert_checkpoints block did not yield!") - elif (not expected and (task._cancel_points != orig_cancel - or task._schedule_points != orig_schedule)): + elif ( + not expected and ( + task._cancel_points != orig_cancel + or task._schedule_points != orig_schedule + ) + ): raise AssertionError("assert_no_yields block yielded!") diff --git a/trio/tests/test_exports.py b/trio/tests/test_exports.py index a3da988fb3..d5ad1b3381 100644 --- a/trio/tests/test_exports.py +++ b/trio/tests/test_exports.py @@ -11,8 +11,10 @@ def test_core_is_properly_reexported(): for symbol in _core.__all__: found = 0 for source in sources: - if (symbol in source.__all__ - and getattr(source, symbol) is getattr(_core, symbol)): + if ( + symbol in source.__all__ + and getattr(source, symbol) is getattr(_core, symbol) + ): found += 1 print(symbol, found) assert found == 1 diff --git a/trio/tests/test_highlevel_open_unix_stream.py b/trio/tests/test_highlevel_open_unix_stream.py new file mode 100644 index 0000000000..6a5948b634 --- /dev/null +++ b/trio/tests/test_highlevel_open_unix_stream.py @@ -0,0 +1,57 @@ +import os +import socket +import tempfile + +import pytest + +from trio import open_unix_socket, Path + +try: + from socket import AF_UNIX +except ImportError: + pytestmark = pytest.mark.skip("Needs unix socket support") + + +async def get_server_socket(): + name = Path() / tempfile.gettempdir() / "test.sock" + try: + await name.unlink() + except OSError: + pass + + serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + serv_sock.bind(name.__fspath__()) # bpo-32562 + serv_sock.listen(1) + + return name, serv_sock + + +async def _do_test_on_sock(serv_sock, unix_socket): + # shared code between some tests + client, _ = serv_sock.accept() + await unix_socket.send_all(b"test") + assert client.recv(2048) == b"test" + + client.sendall(b"response") + received = await unix_socket.receive_some(2048) + assert received == b"response" + + +async def test_open_bad_socket(): + # mktemp is marked as insecure, but that's okay, we don't want the file to + # exist + name = os.path.join(tempfile.gettempdir(), tempfile.mktemp()) + with pytest.raises(FileNotFoundError): + await open_unix_socket(name) + + +async def test_open_unix_socket(): + name, serv_sock = await get_server_socket() + unix_socket = await open_unix_socket(name.__fspath__()) + await _do_test_on_sock(serv_sock, unix_socket) + + +async def test_open_unix_socket_with_path(): + name, serv_sock = await get_server_socket() + unix_socket = await open_unix_socket(name) + await _do_test_on_sock(serv_sock, unix_socket) \ No newline at end of file diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index fced9558b4..371e519910 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -66,9 +66,12 @@ # will hopefully be fixed soon import sys WORKAROUND_PYPY_BUG = False -if (hasattr(sys, "pypy_version_info") and - ((sys.pypy_version_info < (5, 9)) or - (sys.pypy_version_info[:4] == (5, 9, 0, "alpha")))): +if ( + hasattr(sys, "pypy_version_info") and ( + (sys.pypy_version_info < (5, 9)) or + (sys.pypy_version_info[:4] == (5, 9, 0, "alpha")) + ) +): WORKAROUND_PYPY_BUG = True