diff --git a/docs/source/reference-hazmat.rst b/docs/source/reference-hazmat.rst index d78a04d835..bb612ee2fa 100644 --- a/docs/source/reference-hazmat.rst +++ b/docs/source/reference-hazmat.rst @@ -492,6 +492,13 @@ this does serve to illustrate the basic structure of the trio.hazmat.reschedule(woken_task) +Low-level cancellation control +------------------------------ + +.. autofunction:: batch_cancellations() + :with: + + Task API -------- diff --git a/newsfragments/897.feature.rst b/newsfragments/897.feature.rst new file mode 100644 index 0000000000..66553051a0 --- /dev/null +++ b/newsfragments/897.feature.rst @@ -0,0 +1,4 @@ +Added :func:`trio.hazmat.batch_cancellations`, allowing user-defined +cancellation abstractions to cancel multiple cancel scopes "at the +same time" with the same semantics (outermost one wins) as when +Trio does this itself. diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 98b44039b8..f2c8a5f256 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -342,15 +342,6 @@ def shield(self, new_value): for task in self._tasks: task._attempt_delivery_of_any_pending_cancel() - def _cancel_no_notify(self): - # returns the affected tasks - if not self.cancel_called: - with self._might_change_effective_deadline(): - self.cancel_called = True - return self._tasks - else: - return set() - @enable_ki_protection def cancel(self): """Cancels this scope immediately. @@ -358,7 +349,11 @@ def cancel(self): This method is idempotent, i.e., if the scope was already cancelled then this method silently does nothing. """ - for task in self._cancel_no_notify(): + if self.cancel_called: + return + with self._might_change_effective_deadline(): + self.cancel_called = True + for task in self._tasks: task._attempt_delivery_of_any_pending_cancel() def _add_task(self, task): @@ -756,6 +751,9 @@ def _attempt_abort(self, raise_cancel): self._runner.reschedule(self, capture(raise_cancel)) def _attempt_delivery_of_any_pending_cancel(self): + if self._runner.cancel_batch is not None: + self._runner.cancel_batch.add(self) + return if self._abort_func is None: return pending_scope = self._pending_cancel_scope() @@ -811,6 +809,10 @@ class Runner: # attached to at least one task deadlines = attr.ib(default=attr.Factory(SortedDict)) + # If not None, we're in a batch_cancellations() scope and this + # collects all the tasks that become cancelled + cancel_batch = attr.ib(default=None) + init_task = attr.ib(default=None) system_nursery = attr.ib(default=None) system_context = attr.ib(default=None) @@ -1156,6 +1158,88 @@ async def init(self, async_fn, args): system_nursery.cancel_scope.cancel() self.entry_queue.spawn() + ################ + # Cancellation + ################ + + @_public + @contextmanager + def batch_cancellations(self): + """A context manager which defers all cancellation delivery + until the context is exited. + + Suppose some task is sleeping within multiple cancel scopes:: + + async def some_task(*, task_status): + with trio.CancelScope() as outer: + with trio.CancelScope() as inner: + task_status.started((outer, inner)) + await trio.sleep_forever() + print("inner scope cancelled") + return + print("outer scope cancelled") + + If ``outer`` and ``inner`` both become cancelled "simultaneously", + there's a question of which one the cancellation should propagate to. + If the cancellations occur due to deadline expiry, the outer scope + wins:: + + async with trio.open_nursery() as nursery: + outer, inner = await nursery.start(some_task) + now = trio.current_time() + inner.deadline = now - 0.1 + outer.deadline = now + # prints: outer scope cancelled + + But if the cancellations occur due to explicit calls to + :meth:`trio.CancelScope.cancel`, whichever one was called first wins:: + + async with trio.open_nursery() as nursery: + outer, inner = await nursery.start(some_task) + inner.cancel() + outer.cancel() + # prints: inner scope cancelled + + This is because Trio doesn't know that you'll also be calling + ``outer.cancel()`` when it wakes up ``some_task`` as a result of + your call to ``inner.cancel()``. + + If you use :func:`batch_cancellations`, all + cancellation-related task wakeups made within the + :func:`batch_cancellations` context get buffered up and + applied as a unit once the context is exited, with outer + scopes once again taking precedence over inner ones. + + :: + + async with trio.open_nursery() as nursery: + outer, inner = await nursery.start(some_task) + with trio.hazmat.batch_cancellations(): + inner.cancel() + outer.cancel() + # prints: outer scope cancelled + + .. warning:: This is a low-level interface intended to aid in the + creation of higher-level cancellation utilities. Nesting of + :meth:`batch_cancellations` contexts is not supported, and + executing any checkpoints within a :meth:`batch_cancellations` + context is liable to crash or deadlock your program. + + """ + if self.cancel_batch is not None: + raise RuntimeError( + "can't nest calls to batch_cancellations() -- did you " + "execute a checkpoint within one?" + ) + self.cancel_batch = set() + try: + yield + finally: + tasks = self.cancel_batch + self.cancel_batch = None + for task in tasks: + task._attempt_delivery_of_any_pending_cancel() + ################ # Outside context problems ################ @@ -1527,17 +1611,15 @@ def run_impl(runner, async_fn, args): # We process all timeouts in a batch and then notify tasks at the end # to ensure that if multiple timeouts occur at once, then it's the # outermost one that gets delivered. - cancelled_tasks = set() - while runner.deadlines: - (deadline, _), cancel_scope = runner.deadlines.peekitem(0) - if deadline <= now: - # This removes the given scope from runner.deadlines: - cancelled_tasks.update(cancel_scope._cancel_no_notify()) - idle_primed = False - else: - break - for task in cancelled_tasks: - task._attempt_delivery_of_any_pending_cancel() + with runner.batch_cancellations(): + while runner.deadlines: + (deadline, _), cancel_scope = runner.deadlines.peekitem(0) + if deadline <= now: + # This removes the given scope from runner.deadlines: + cancel_scope.cancel() + idle_primed = False + else: + break if not runner.runq and idle_primed: while runner.waiting_for_idle: diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index c1d2aa0b8f..a867af2b9e 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -769,6 +769,57 @@ async def sleeper(): assert record == ["sleeping", "cancelled"] +async def test_batch_cancellations(): + record = [] + + async def some_task(*, task_status): + with _core.CancelScope() as outer: + with _core.CancelScope() as inner: + task_status.started((outer, inner)) + await sleep_forever() + record.append("inner") + return + record.append("outer") + + async with _core.open_nursery() as nursery: + outer, inner = await nursery.start(some_task) + now = _core.current_time() + inner.deadline = now - 0.1 + outer.deadline = now + assert record[-1] == "outer" + + async with _core.open_nursery() as nursery: + outer, inner = await nursery.start(some_task) + inner.cancel() + outer.cancel() + assert record[-1] == "inner" + + async with _core.open_nursery() as nursery: + outer, inner = await nursery.start(some_task) + with _core.batch_cancellations(): + inner.cancel() + outer.cancel() + assert record[-1] == "outer" + + with _core.batch_cancellations(): + with pytest.raises(RuntimeError): + with _core.batch_cancellations(): + pass # pragma: no cover + + +def test_batch_cancellations_with_improper_yield(): + async def evil(): + with _core.batch_cancellations(): + await _core.checkpoint() + + with pytest.raises(_core.TrioInternalError) as exc_info: + _core.run(evil) + message = str(exc_info.value.__cause__) + assert "batch_cancellations" in message and "checkpoint" in message + + gc_collect_harder() + + async def test_basic_timeout(mock_clock): start = _core.current_time() with _core.CancelScope() as scope: diff --git a/trio/hazmat.py b/trio/hazmat.py index 2c19bd81ad..3b72f12141 100644 --- a/trio/hazmat.py +++ b/trio/hazmat.py @@ -20,7 +20,7 @@ current_statistics, reschedule, remove_instrument, add_instrument, current_clock, current_root_task, checkpoint_if_cancelled, spawn_system_task, wait_socket_readable, wait_socket_writable, - notify_socket_close + notify_socket_close, batch_cancellations ) # Unix-specific symbols