diff --git a/README.rst b/README.rst
index 6c56b4ee20..21ae7263a3 100644
--- a/README.rst
+++ b/README.rst
@@ -101,7 +101,7 @@ asking for help in our `chat room
`__, `filing a bug
`__, or `posting a
question on StackOverflow
-`__, and
+`__, and
we'll do our best to help you out.
**Trio is awesome and I want to help make it more awesome!** You're
diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html
index 24b0369899..aa558724ea 100644
--- a/docs/source/_templates/layout.html
+++ b/docs/source/_templates/layout.html
@@ -20,5 +20,5 @@
Need help? Try chat or StackOverflow.
-{% endblock %}
\ No newline at end of file
+href="https://stackoverflow.com/questions/ask?tags=python+python-trio">StackOverflow.
+{% endblock %}
diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst
index e576dc4d3c..b0ea3d03fd 100644
--- a/docs/source/reference-io.rst
+++ b/docs/source/reference-io.rst
@@ -195,6 +195,8 @@ abstraction.
.. autofunction:: serve_ssl_over_tcp
+.. autofunction:: open_unix_socket
+
.. autoclass:: SocketStream
:members:
:undoc-members:
diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst
index c39f7eaff9..77d81e24ef 100644
--- a/docs/source/tutorial.rst
+++ b/docs/source/tutorial.rst
@@ -108,8 +108,8 @@ If you get lost or confused...
...then we want to know! We have a friendly `chat channel
`__, you can ask questions
-`using the "trio" tag on StackOverflow
-`__, or just
+`using the "python-trio" tag on StackOverflow
+`__, or just
`file a bug `__ (if
our documentation is confusing, that's our fault, and we want to fix
it!).
diff --git a/newsfragments/401.feature.rst b/newsfragments/401.feature.rst
new file mode 100644
index 0000000000..8800af4eb9
--- /dev/null
+++ b/newsfragments/401.feature.rst
@@ -0,0 +1 @@
+Add unix client socket support.
\ No newline at end of file
diff --git a/trio/__init__.py b/trio/__init__.py
index 6077c04ef7..1d10c58401 100644
--- a/trio/__init__.py
+++ b/trio/__init__.py
@@ -50,6 +50,9 @@
from ._highlevel_open_tcp_listeners import *
__all__ += _highlevel_open_tcp_listeners.__all__
+from ._highlevel_open_unix_stream import *
+__all__ += _highlevel_open_unix_stream.__all__
+
from ._highlevel_ssl_helpers import *
__all__ += _highlevel_ssl_helpers.__all__
diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py
index 1faa295644..968299404c 100644
--- a/trio/_core/_ki.py
+++ b/trio/_core/_ki.py
@@ -176,8 +176,10 @@ def wrapper(*args, **kwargs):
@contextmanager
def ki_manager(deliver_cb, restrict_keyboard_interrupt_to_checkpoints):
- if (threading.current_thread() != threading.main_thread()
- or signal.getsignal(signal.SIGINT) != signal.default_int_handler):
+ if (
+ threading.current_thread() != threading.main_thread()
+ or signal.getsignal(signal.SIGINT) != signal.default_int_handler
+ ):
yield
return
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 96179d3d54..ca1f0a9440 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -383,8 +383,10 @@ def _add_exc(self, exc):
self.cancel_scope.cancel()
def _check_nursery_closed(self):
- if (not self._nested_child_running and not self._children
- and not self._pending_starts):
+ if (
+ not self._nested_child_running and not self._children
+ and not self._pending_starts
+ ):
self._closed = True
if self._parent_waiting_in_aexit:
self._parent_waiting_in_aexit = False
@@ -1477,8 +1479,10 @@ async def checkpoint_if_cancelled():
"""
task = current_task()
- if (task._pending_cancel_scope() is not None or
- (task is task._runner.main_task and task._runner.ki_pending)):
+ if (
+ task._pending_cancel_scope() is not None or
+ (task is task._runner.main_task and task._runner.ki_pending)
+ ):
await _core.checkpoint()
assert False # pragma: no cover
task._cancel_points += 1
diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py
index bc6b5de2ef..00e972a1b4 100644
--- a/trio/_core/tests/test_run.py
+++ b/trio/_core/tests/test_run.py
@@ -1203,8 +1203,10 @@ def cb(x):
for i in range(100):
token.run_sync_soon(cb, i, idempotent=True)
await wait_all_tasks_blocked()
- if (sys.version_info < (3, 6)
- and platform.python_implementation() == "CPython"):
+ if (
+ sys.version_info < (3, 6)
+ and platform.python_implementation() == "CPython"
+ ):
# no order guarantees
record.sort()
# Otherwise, we guarantee FIFO
diff --git a/trio/_highlevel_open_unix_stream.py b/trio/_highlevel_open_unix_stream.py
new file mode 100644
index 0000000000..522ddb7104
--- /dev/null
+++ b/trio/_highlevel_open_unix_stream.py
@@ -0,0 +1,42 @@
+import trio
+from trio._highlevel_open_tcp_stream import close_on_error
+from trio.socket import socket, SOCK_STREAM
+
+try:
+ from trio.socket import AF_UNIX
+ has_unix = True
+except ImportError:
+ has_unix = False
+
+__all__ = ["open_unix_socket"]
+
+
+async def open_unix_socket(filename,):
+ """Opens a connection to the specified
+ `Unix domain socket `__.
+
+ You must have read/write permission on the specified file to connect.
+
+ Args:
+ filename (str or bytes): The filename to open the connection to.
+
+ Returns:
+ SocketStream: a :class:`~trio.abc.Stream` connected to the given file.
+
+ Raises:
+ OSError: If the socket file could not be connected to.
+ RuntimeError: If AF_UNIX sockets are not supported.
+ """
+ if not has_unix:
+ raise RuntimeError("Unix sockets are not supported on this platform")
+
+ if filename is None:
+ raise ValueError("Filename cannot be None")
+
+ # much more simplified logic vs tcp sockets - one socket type and only one
+ # possible location to connect to
+ sock = socket(AF_UNIX, SOCK_STREAM)
+ with close_on_error(sock):
+ await sock.connect(filename)
+
+ return trio.SocketStream(sock)
diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py
index 903f5ecabb..5d00e44c46 100644
--- a/trio/_highlevel_socket.py
+++ b/trio/_highlevel_socket.py
@@ -43,9 +43,9 @@ class SocketStream(HalfCloseableStream):
socket: The trio socket object to wrap. Must have type ``SOCK_STREAM``,
and be connected.
- By default, :class:`SocketStream` enables ``TCP_NODELAY``, and (on
- platforms where it's supported) enables ``TCP_NOTSENT_LOWAT`` with a
- reasonable buffer size (currently 16 KiB) – see `issue #72
+ By default for TCP sockets, :class:`SocketStream` enables ``TCP_NODELAY``,
+ and (on platforms where it's supported) enables ``TCP_NOTSENT_LOWAT`` with
+ a reasonable buffer size (currently 16 KiB) – see `issue #72
`__ for discussion. You can
of course override these defaults by calling :meth:`setsockopt`.
diff --git a/trio/_socket.py b/trio/_socket.py
index d2b0a1a8c6..4ad58b1493 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -1,15 +1,14 @@
-from functools import wraps as _wraps, partial as _partial
+import os as _os
import socket as _stdlib_socket
import sys as _sys
-import os as _os
-from contextlib import contextmanager as _contextmanager
-import errno as _errno
+from functools import wraps as _wraps
import idna as _idna
from . import _core
from ._deprecate import deprecated
from ._threads import run_sync_in_worker_thread
+from ._util import fspath
__all__ = []
@@ -462,8 +461,10 @@ def dup(self):
async def bind(self, address):
await _core.checkpoint()
address = await self._resolve_local_address(address)
- if (hasattr(_stdlib_socket, "AF_UNIX") and self.family == AF_UNIX
- and address[0]):
+ if (
+ hasattr(_stdlib_socket, "AF_UNIX") and self.family == AF_UNIX
+ and address[0]
+ ):
# Use a thread for the filesystem traversal (unless it's an
# abstract domain socket)
return await run_sync_in_worker_thread(self._sock.bind, address)
@@ -504,6 +505,11 @@ async def _resolve_address(self, address, flags):
"address should be a (host, port, [flowinfo, [scopeid]]) "
"tuple"
)
+ elif self._sock.family == AF_UNIX:
+ await _core.checkpoint()
+ # unwrap path-likes
+ return fspath(address)
+
else:
await _core.checkpoint()
return address
diff --git a/trio/_ssl.py b/trio/_ssl.py
index 1f89e880c9..8e1c337c2e 100644
--- a/trio/_ssl.py
+++ b/trio/_ssl.py
@@ -630,9 +630,11 @@ async def receive_some(self, max_bytes):
# For some reason, EOF before handshake sometimes raises
# SSLSyscallError instead of SSLEOFError (e.g. on my linux
# laptop, but not on appveyor). Thanks openssl.
- if (self._https_compatible
- and isinstance(exc.__cause__,
- (SSLEOFError, SSLSyscallError))):
+ if (
+ self._https_compatible and
+ isinstance(exc.__cause__,
+ (SSLEOFError, SSLSyscallError))
+ ):
return b""
else:
raise
@@ -647,8 +649,10 @@ async def receive_some(self, max_bytes):
# BROKEN. But that's actually fine, because after getting an
# EOF on TLS then the only thing you can do is close the
# stream, and closing doesn't care about the state.
- if (self._https_compatible
- and isinstance(exc.__cause__, SSLEOFError)):
+ if (
+ self._https_compatible
+ and isinstance(exc.__cause__, SSLEOFError)
+ ):
return b""
else:
raise
diff --git a/trio/_util.py b/trio/_util.py
index 7cae6d0fe9..e7e02a6df1 100644
--- a/trio/_util.py
+++ b/trio/_util.py
@@ -3,6 +3,7 @@
import os
import sys
from functools import wraps
+import typing as t
import async_generator
@@ -14,11 +15,8 @@
from . import _core
__all__ = [
- "signal_raise",
- "aiter_compat",
- "acontextmanager",
- "ConflictDetector",
- "fixup_module_metadata",
+ "signal_raise", "aiter_compat", "acontextmanager", "ConflictDetector",
+ "fixup_module_metadata", "fspath"
]
# Equivalent to the C function raise(), which Python doesn't wrap
@@ -59,6 +57,7 @@
# - generating synthetic signals for tests
# and for both of those purposes, 'raise' works fine.
import cffi
+
_ffi = cffi.FFI()
_ffi.cdef("int raise(int);")
_lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll")
@@ -133,8 +132,11 @@ async def __aexit__(self, type, value, traceback):
# Likewise, avoid suppressing if a StopIteration exception
# was passed to throw() and later wrapped into a RuntimeError
# (see PEP 479).
- if (isinstance(value, (StopIteration, StopAsyncIteration))
- and exc.__cause__ is value):
+ if (
+ isinstance(value,
+ (StopIteration, StopAsyncIteration))
+ and exc.__cause__ is value
+ ):
return False
raise
except:
@@ -253,3 +255,45 @@ def fix_one(obj):
for objname in namespace["__all__"]:
obj = namespace[objname]
fix_one(obj)
+
+
+# This is copied from PEP 519 as the implementation of os.fspath for
+# Python 3.5. See: https://www.python.org/dev/peps/pep-0519/#os
+# The input typehint is removed as there is no os.PathLike on 3.5.
+
+
+def fspath(path) -> t.Union[str, bytes]:
+ """Return the string representation of the path.
+
+ If str or bytes is passed in, it is returned unchanged. If __fspath__()
+ returns something other than str or bytes then TypeError is raised. If
+ this function is given something that is not str, bytes, or os.PathLike
+ then TypeError is raised.
+ """
+ if isinstance(path, (str, bytes)):
+ return path
+
+ # Work from the object's type to match method resolution of other magic
+ # methods.
+ path_type = type(path)
+ try:
+ path = path_type.__fspath__(path)
+ except AttributeError:
+ if hasattr(path_type, '__fspath__'):
+ raise
+ else:
+ if isinstance(path, (str, bytes)):
+ return path
+ else:
+ raise TypeError(
+ "expected __fspath__() to return str or bytes, "
+ "not " + type(path).__name__
+ )
+
+ raise TypeError(
+ "expected str, bytes or os.PathLike object, not " + path_type.__name__
+ )
+
+
+if hasattr(os, "fspath"):
+ fspath = os.fspath
diff --git a/trio/testing/_checkpoints.py b/trio/testing/_checkpoints.py
index e878d560e1..dbf14566e9 100644
--- a/trio/testing/_checkpoints.py
+++ b/trio/testing/_checkpoints.py
@@ -15,11 +15,19 @@ def _assert_yields_or_not(expected):
try:
yield
finally:
- if (expected and (task._cancel_points == orig_cancel
- or task._schedule_points == orig_schedule)):
+ if (
+ expected and (
+ task._cancel_points == orig_cancel
+ or task._schedule_points == orig_schedule
+ )
+ ):
raise AssertionError("assert_checkpoints block did not yield!")
- elif (not expected and (task._cancel_points != orig_cancel
- or task._schedule_points != orig_schedule)):
+ elif (
+ not expected and (
+ task._cancel_points != orig_cancel
+ or task._schedule_points != orig_schedule
+ )
+ ):
raise AssertionError("assert_no_yields block yielded!")
diff --git a/trio/tests/test_exports.py b/trio/tests/test_exports.py
index a3da988fb3..d5ad1b3381 100644
--- a/trio/tests/test_exports.py
+++ b/trio/tests/test_exports.py
@@ -11,8 +11,10 @@ def test_core_is_properly_reexported():
for symbol in _core.__all__:
found = 0
for source in sources:
- if (symbol in source.__all__
- and getattr(source, symbol) is getattr(_core, symbol)):
+ if (
+ symbol in source.__all__
+ and getattr(source, symbol) is getattr(_core, symbol)
+ ):
found += 1
print(symbol, found)
assert found == 1
diff --git a/trio/tests/test_highlevel_open_unix_stream.py b/trio/tests/test_highlevel_open_unix_stream.py
new file mode 100644
index 0000000000..6a5948b634
--- /dev/null
+++ b/trio/tests/test_highlevel_open_unix_stream.py
@@ -0,0 +1,57 @@
+import os
+import socket
+import tempfile
+
+import pytest
+
+from trio import open_unix_socket, Path
+
+try:
+ from socket import AF_UNIX
+except ImportError:
+ pytestmark = pytest.mark.skip("Needs unix socket support")
+
+
+async def get_server_socket():
+ name = Path() / tempfile.gettempdir() / "test.sock"
+ try:
+ await name.unlink()
+ except OSError:
+ pass
+
+ serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ serv_sock.bind(name.__fspath__()) # bpo-32562
+ serv_sock.listen(1)
+
+ return name, serv_sock
+
+
+async def _do_test_on_sock(serv_sock, unix_socket):
+ # shared code between some tests
+ client, _ = serv_sock.accept()
+ await unix_socket.send_all(b"test")
+ assert client.recv(2048) == b"test"
+
+ client.sendall(b"response")
+ received = await unix_socket.receive_some(2048)
+ assert received == b"response"
+
+
+async def test_open_bad_socket():
+ # mktemp is marked as insecure, but that's okay, we don't want the file to
+ # exist
+ name = os.path.join(tempfile.gettempdir(), tempfile.mktemp())
+ with pytest.raises(FileNotFoundError):
+ await open_unix_socket(name)
+
+
+async def test_open_unix_socket():
+ name, serv_sock = await get_server_socket()
+ unix_socket = await open_unix_socket(name.__fspath__())
+ await _do_test_on_sock(serv_sock, unix_socket)
+
+
+async def test_open_unix_socket_with_path():
+ name, serv_sock = await get_server_socket()
+ unix_socket = await open_unix_socket(name)
+ await _do_test_on_sock(serv_sock, unix_socket)
\ No newline at end of file
diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py
index fced9558b4..371e519910 100644
--- a/trio/tests/test_ssl.py
+++ b/trio/tests/test_ssl.py
@@ -66,9 +66,12 @@
# will hopefully be fixed soon
import sys
WORKAROUND_PYPY_BUG = False
-if (hasattr(sys, "pypy_version_info") and
- ((sys.pypy_version_info < (5, 9)) or
- (sys.pypy_version_info[:4] == (5, 9, 0, "alpha")))):
+if (
+ hasattr(sys, "pypy_version_info") and (
+ (sys.pypy_version_info < (5, 9)) or
+ (sys.pypy_version_info[:4] == (5, 9, 0, "alpha"))
+ )
+):
WORKAROUND_PYPY_BUG = True