Skip to content

Commit

Permalink
Make did_shutdown_SHUT_WR public API on trio socket objects
Browse files Browse the repository at this point in the history
This is more prep for python-triogh-170. As of this commit SocketType no longer
has any secret-but-quasi-public APIs.
  • Loading branch information
njsmith committed Jul 25, 2017
1 parent 6b0eeb6 commit a900240
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 4 deletions.
9 changes: 9 additions & 0 deletions docs/source/reference-io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,15 @@ Socket objects

`Not implemented yet! <https://github.com/python-trio/trio/issues/45>`__

We also keep track of an extra bit of state, because it turns out
to be useful for :class:`trio.SocketStream`:

.. attribute:: did_shutdown_SHUT_WR

This :class:`bool` attribute it True if you've called
``sock.shutdown(SHUT_WR)`` or ``sock.shutdown(SHUT_RDWR)``, and
False otherwise.

The following methods are identical to their equivalents in
:func:`socket.socket`, except async, and the ones that take address
arguments require pre-resolved addresses:
Expand Down
4 changes: 2 additions & 2 deletions trio/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(self, sock):
pass

async def send_all(self, data):
if self.socket._did_SHUT_WR:
if self.socket.did_shutdown_SHUT_WR:
await _core.yield_briefly()
raise ClosedStreamError("can't send data after sending EOF")
with self._send_lock.sync:
Expand All @@ -112,7 +112,7 @@ async def send_eof(self):
async with self._send_lock:
# On MacOS, calling shutdown a second time raises ENOTCONN, but
# send_eof needs to be idempotent.
if self.socket._did_SHUT_WR:
if self.socket.did_shutdown_SHUT_WR:
return
with _translate_socket_errors_to_stream_errors():
self.socket.shutdown(tsocket.SHUT_WR)
Expand Down
8 changes: 6 additions & 2 deletions trio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def __init__(self, sock):
.format(type(sock).__name__))
self._sock = sock
self._sock.setblocking(False)
self._did_SHUT_WR = False
self._did_shutdown_SHUT_WR = False

# Defaults:
if self._sock.family == AF_INET6:
Expand Down Expand Up @@ -307,6 +307,10 @@ def type(self):
def proto(self):
return self._sock.proto

@property
def did_shutdown_SHUT_WR(self):
return self._did_shutdown_SHUT_WR

def __repr__(self):
return repr(self._sock).replace("socket.socket", "trio.socket.socket")

Expand All @@ -325,7 +329,7 @@ def shutdown(self, flag):
self._sock.shutdown(flag)
# only do this if the call succeeded:
if flag in [SHUT_WR, SHUT_RDWR]:
self._did_SHUT_WR = True
self._did_shutdown_SHUT_WR = True

async def wait_writable(self):
await _core.wait_socket_writable(self._sock)
Expand Down
9 changes: 9 additions & 0 deletions trio/tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,20 @@ async def test_SocketType_shutdown():
with a, b:
await a.sendall(b"xxx")
assert await b.recv(3) == b"xxx"
assert not a.did_shutdown_SHUT_WR
assert not b.did_shutdown_SHUT_WR
a.shutdown(tsocket.SHUT_WR)
assert a.did_shutdown_SHUT_WR
assert not b.did_shutdown_SHUT_WR
assert await b.recv(3) == b""
await b.sendall(b"yyy")
assert await a.recv(3) == b"yyy"

b.shutdown(tsocket.SHUT_RD)
assert not b.did_shutdown_SHUT_WR
b.shutdown(tsocket.SHUT_RDWR)
assert b.did_shutdown_SHUT_WR


@pytest.mark.parametrize("address, socket_type", [('127.0.0.1', tsocket.AF_INET), ('::1', tsocket.AF_INET6)])
async def test_SocketType_simple_server(address, socket_type):
Expand Down

0 comments on commit a900240

Please sign in to comment.