diff --git a/trio/_channel.py b/trio/_channel.py index 1cecc55621..f3fef1699e 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -1,6 +1,19 @@ +from __future__ import annotations + from collections import deque, OrderedDict +from collections.abc import Callable from math import inf +from types import TracebackType +from typing import ( + Any, + Generic, + NoReturn, + TypeVar, + TYPE_CHECKING, + Tuple, # only needed for typechecking on <3.9 +) + import attr from outcome import Error, Value @@ -8,11 +21,29 @@ from ._util import generic_function, NoPublicConstructor import trio -from ._core import enable_ki_protection +from ._core import enable_ki_protection, Task, Abort +from ._core._traps import _RaiseCancelT + +# A regular invariant generic type +T = TypeVar("T") + +# The type of object produced by a ReceiveChannel (covariant because +# ReceiveChannel[Derived] can be passed to someone expecting +# ReceiveChannel[Base]) +ReceiveType = TypeVar("ReceiveType", covariant=True) + +# The type of object accepted by a SendChannel (contravariant because +# SendChannel[Base] can be passed to someone expecting +# SendChannel[Derived]) +SendType = TypeVar("SendType", contravariant=True) +# Temporary TypeVar needed until mypy release supports Self as a type +SelfT = TypeVar("SelfT") -@generic_function -def open_memory_channel(max_buffer_size): + +def _open_memory_channel( + max_buffer_size: int, +) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: """Open a channel for passing objects between tasks within a process. Memory channels are lightweight, cheap to allocate, and entirely @@ -68,36 +99,57 @@ def open_memory_channel(max_buffer_size): raise TypeError("max_buffer_size must be an integer or math.inf") if max_buffer_size < 0: raise ValueError("max_buffer_size must be >= 0") - state = MemoryChannelState(max_buffer_size) + state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size) return ( - MemorySendChannel._create(state), - MemoryReceiveChannel._create(state), + MemorySendChannel[T]._create(state), + MemoryReceiveChannel[T]._create(state), ) +# This workaround requires python3.9+, once older python versions are not supported +# or there's a better way of achieving type-checking on a generic factory function, +# it could replace the normal function header +if TYPE_CHECKING: + # written as a class so you can say open_memory_channel[int](5) + # Need to use Tuple instead of tuple due to CI check running on 3.8 + class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]): + def __new__( # type: ignore[misc] # "must return a subtype" + cls, max_buffer_size: int + ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: + return _open_memory_channel(max_buffer_size) + + def __init__(self, max_buffer_size: int): + ... + +else: + # apply the generic_function decorator to make open_memory_channel indexable + # so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime + open_memory_channel = generic_function(_open_memory_channel) + + @attr.s(frozen=True, slots=True) class MemoryChannelStats: - current_buffer_used = attr.ib() - max_buffer_size = attr.ib() - open_send_channels = attr.ib() - open_receive_channels = attr.ib() - tasks_waiting_send = attr.ib() - tasks_waiting_receive = attr.ib() + current_buffer_used: int = attr.ib() + max_buffer_size: int = attr.ib() + open_send_channels: int = attr.ib() + open_receive_channels: int = attr.ib() + tasks_waiting_send: int = attr.ib() + tasks_waiting_receive: int = attr.ib() @attr.s(slots=True) -class MemoryChannelState: - max_buffer_size = attr.ib() - data = attr.ib(factory=deque) +class MemoryChannelState(Generic[T]): + max_buffer_size: int = attr.ib() + data: deque[T] = attr.ib(factory=deque) # Counts of open endpoints using this state - open_send_channels = attr.ib(default=0) - open_receive_channels = attr.ib(default=0) + open_send_channels: int = attr.ib(default=0) + open_receive_channels: int = attr.ib(default=0) # {task: value} - send_tasks = attr.ib(factory=OrderedDict) + send_tasks: OrderedDict[Task, T] = attr.ib(factory=OrderedDict) # {task: None} - receive_tasks = attr.ib(factory=OrderedDict) + receive_tasks: OrderedDict[Task, None] = attr.ib(factory=OrderedDict) - def statistics(self): + def statistics(self) -> MemoryChannelStats: return MemoryChannelStats( current_buffer_used=len(self.data), max_buffer_size=self.max_buffer_size, @@ -109,28 +161,28 @@ def statistics(self): @attr.s(eq=False, repr=False) -class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor): - _state = attr.ib() - _closed = attr.ib(default=False) +class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor): + _state: MemoryChannelState[SendType] = attr.ib() + _closed: bool = attr.ib(default=False) # This is just the tasks waiting on *this* object. As compared to # self._state.send_tasks, which includes tasks from this object and # all clones. - _tasks = attr.ib(factory=set) + _tasks: set[Task] = attr.ib(factory=set) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._state.open_send_channels += 1 - def __repr__(self): + def __repr__(self) -> str: return "".format( id(self), id(self._state) ) - def statistics(self): + def statistics(self) -> MemoryChannelStats: # XX should we also report statistics specific to this object? return self._state.statistics() @enable_ki_protection - def send_nowait(self, value): + def send_nowait(self, value: SendType) -> None: """Like `~trio.abc.SendChannel.send`, but if the channel's buffer is full, raises `WouldBlock` instead of blocking. @@ -150,7 +202,7 @@ def send_nowait(self, value): raise trio.WouldBlock @enable_ki_protection - async def send(self, value): + async def send(self, value: SendType) -> None: """See `SendChannel.send `. Memory channels allow multiple tasks to call `send` at the same time. @@ -170,15 +222,16 @@ async def send(self, value): self._state.send_tasks[task] = value task.custom_sleep_data = self - def abort_fn(_): + def abort_fn(_: _RaiseCancelT) -> Abort: self._tasks.remove(task) del self._state.send_tasks[task] return trio.lowlevel.Abort.SUCCEEDED await trio.lowlevel.wait_task_rescheduled(abort_fn) + # Return type must be stringified or use a TypeVar @enable_ki_protection - def clone(self): + def clone(self) -> "MemorySendChannel[SendType]": """Clone this send channel object. This returns a new `MemorySendChannel` object, which acts as a @@ -206,14 +259,19 @@ def clone(self): raise trio.ClosedResourceError return MemorySendChannel._create(self._state) - def __enter__(self): + def __enter__(self: SelfT) -> SelfT: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: self.close() @enable_ki_protection - def close(self): + def close(self) -> None: """Close this send channel object synchronously. All channel objects have an asynchronous `~.AsyncResource.aclose` method. @@ -241,30 +299,30 @@ def close(self): self._state.receive_tasks.clear() @enable_ki_protection - async def aclose(self): + async def aclose(self) -> None: self.close() await trio.lowlevel.checkpoint() @attr.s(eq=False, repr=False) -class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor): - _state = attr.ib() - _closed = attr.ib(default=False) - _tasks = attr.ib(factory=set) +class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor): + _state: MemoryChannelState[ReceiveType] = attr.ib() + _closed: bool = attr.ib(default=False) + _tasks: set[trio._core._run.Task] = attr.ib(factory=set) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._state.open_receive_channels += 1 - def statistics(self): + def statistics(self) -> MemoryChannelStats: return self._state.statistics() - def __repr__(self): + def __repr__(self) -> str: return "".format( id(self), id(self._state) ) @enable_ki_protection - def receive_nowait(self): + def receive_nowait(self) -> ReceiveType: """Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing ready to receive, raises `WouldBlock` instead of blocking. @@ -284,7 +342,7 @@ def receive_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def receive(self): + async def receive(self) -> ReceiveType: """See `ReceiveChannel.receive `. Memory channels allow multiple tasks to call `receive` at the same @@ -306,15 +364,17 @@ async def receive(self): self._state.receive_tasks[task] = None task.custom_sleep_data = self - def abort_fn(_): + def abort_fn(_: _RaiseCancelT) -> Abort: self._tasks.remove(task) del self._state.receive_tasks[task] return trio.lowlevel.Abort.SUCCEEDED - return await trio.lowlevel.wait_task_rescheduled(abort_fn) + # Not strictly guaranteed to return ReceiveType, but will do so unless + # you intentionally reschedule with a bad value. + return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore[no-any-return] @enable_ki_protection - def clone(self): + def clone(self) -> "MemoryReceiveChannel[ReceiveType]": """Clone this receive channel object. This returns a new `MemoryReceiveChannel` object, which acts as a @@ -345,14 +405,19 @@ def clone(self): raise trio.ClosedResourceError return MemoryReceiveChannel._create(self._state) - def __enter__(self): + def __enter__(self: SelfT) -> SelfT: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: self.close() @enable_ki_protection - def close(self): + def close(self) -> None: """Close this receive channel object synchronously. All channel objects have an asynchronous `~.AsyncResource.aclose` method. @@ -381,6 +446,6 @@ def close(self): self._state.data.clear() @enable_ki_protection - async def aclose(self): + async def aclose(self) -> None: self.close() await trio.lowlevel.checkpoint() diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py index 95cf46de9b..7611894ce2 100644 --- a/trio/_core/_traps.py +++ b/trio/_core/_traps.py @@ -8,6 +8,7 @@ from . import _run +from typing import Callable, NoReturn, Any # Helper for the bottommost 'yield'. You can't use 'yield' inside an async # function, but you can inside a generator, and if you decorate your generator @@ -64,7 +65,11 @@ class WaitTaskRescheduled: abort_func = attr.ib() -async def wait_task_rescheduled(abort_func): +_RaiseCancelT = Callable[[], NoReturn] # TypeAlias + +# Should always return the type a Task "expects", unless you willfully reschedule it +# with a bad value. +async def wait_task_rescheduled(abort_func: Callable[[_RaiseCancelT], Abort]) -> Any: """Put the current task to sleep, with cancellation support. This is the lowest-level API for blocking in Trio. Every time a