Skip to content

Commit

Permalink
Move the socket type fixup code out of the class and into a function
Browse files Browse the repository at this point in the history
I don't want this in the "public" socket interface, because it
interferes with python-triogh-170.
  • Loading branch information
njsmith committed Jul 25, 2017
1 parent a1ad278 commit 6b0eeb6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion trio/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class SocketStream(HalfCloseableStream):
def __init__(self, sock):
if not tsocket.is_trio_socket(sock):
raise TypeError("SocketStream requires trio socket object")
if sock._real_type != tsocket.SOCK_STREAM:
if tsocket._real_type(sock.type) != tsocket.SOCK_STREAM:
raise ValueError("SocketStream requires a SOCK_STREAM socket")
try:
sock.getpeername()
Expand Down
17 changes: 9 additions & 8 deletions trio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,13 @@ def is_trio_socket(obj):
getattr(_stdlib_socket, "SOCK_NONBLOCK", 0)
| getattr(_stdlib_socket, "SOCK_CLOEXEC", 0))

# Hopefully Python will eventually make something like this public
# (see bpo-21327) but I don't want to make it public myself and then
# find out they picked a different name... this is used internally in
# this file and also elsewhere in trio.
def _real_type(type_num):
return type_num & _SOCK_TYPE_MASK

class _SocketType:
def __init__(self, sock):
if type(sock) is not _stdlib_socket.socket:
Expand All @@ -241,12 +248,6 @@ def __init__(self, sock):
self._sock.setblocking(False)
self._did_SHUT_WR = False

# Hopefully Python will eventually make something like this public
# (see bpo-21327) but I don't want to make it public myself and then
# find out they picked a different name... this is used internally in
# this file and also elsewhere in trio.
self._real_type = sock.type & _SOCK_TYPE_MASK

# Defaults:
if self._sock.family == AF_INET6:
try:
Expand Down Expand Up @@ -362,7 +363,7 @@ def _check_address(self, address, *, require_resolved):
_stdlib_socket.getaddrinfo(
address[0], address[1],
self._sock.family,
self._real_type,
_real_type(self._sock.type),
self._sock.proto,
flags=_NUMERIC_ONLY)
except gaierror as exc:
Expand Down Expand Up @@ -399,7 +400,7 @@ async def _resolve_address(self, address, flags):
gai_res = await getaddrinfo(
address[0], address[1],
self._sock.family,
self._real_type,
_real_type(self._sock.type),
self._sock.proto,
flags)
# AFAICT from the spec it's not possible for getaddrinfo to return an
Expand Down

0 comments on commit 6b0eeb6

Please sign in to comment.