diff --git a/.gitignore b/.gitignore index 149a9087..602f148a 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ __pycache__ .cache .local .pre-commit-config.yaml +.vscode/ +.python-version diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 68fac2ea..1c779137 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -233,6 +233,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False): self._deadline = deadline self._shield = shield self._parent_scope: Optional[CancelScope] = None + self._cancelled_caught = False self._cancel_called = False self._active = False self._timeout_handle: Optional[asyncio.TimerHandle] = None @@ -304,6 +305,8 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Ba if exc_val is not None: exceptions = exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val] if all(isinstance(exc, CancelledError) for exc in exceptions): + self._cancelled_caught = self._cancel_called and not self._parent_cancelled() + if self._timeout_expired: return True elif not self._cancel_called: @@ -312,6 +315,8 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Ba elif not self._parent_cancelled(): # This scope was directly cancelled return True + elif self._cancel_called: + self._cancelled_caught = True return None @@ -392,6 +397,10 @@ def deadline(self, value: float) -> None: if self._active and not self._cancel_called: self._timeout() + @property + def cancelled_caught(self) -> bool: + return self._cancelled_caught + @property def cancel_called(self) -> bool: return self._cancel_called diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 780d3657..0778715e 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -85,6 +85,10 @@ def deadline(self) -> float: def deadline(self, value: float) -> None: self.__original.deadline = value + @property + def cancelled_caught(self) -> bool: + return self.__original.cancelled_caught + @property def cancel_called(self) -> bool: return self.__original.cancel_called diff --git a/src/anyio/_core/_tasks.py b/src/anyio/_core/_tasks.py index 62230f04..5f24b434 100644 --- a/src/anyio/_core/_tasks.py +++ b/src/anyio/_core/_tasks.py @@ -45,6 +45,11 @@ def deadline(self) -> float: def deadline(self, value: float) -> None: raise NotImplementedError + @property + def cancelled_caught(self) -> bool: + """Records whether this scope caught a ``CancelledError``.""" + raise NotImplementedError + @property def cancel_called(self) -> bool: """``True`` if :meth:`cancel` has been called.""" diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 17abb566..130c89be 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -165,9 +165,11 @@ async def taskfunc(*, task_status): finished = True async def start_another(): + nonlocal tg async with create_task_group() as tg: await tg.start(taskfunc) + tg = None started = finished = False task = asyncio.get_event_loop().create_task(start_another()) await wait_all_tasks_blocked() @@ -179,6 +181,34 @@ async def start_another(): assert not finished +@pytest.mark.parametrize('anyio_backend', ['asyncio']) +async def test_start_native_host_cancelled_cancel_scope(): + async def taskfunc(): + nonlocal started, finished + started = True + await sleep(2) + finished = True + + async def start_another(): + nonlocal cancel_scope + async with CancelScope() as cancel_scope: + await taskfunc() + + cancel_scope = None + started = finished = False + task = asyncio.get_event_loop().create_task(start_another()) + await wait_all_tasks_blocked() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert started + assert not finished + + assert not cancel_scope.cancelled_caught + assert not cancel_scope.cancel_called + + @pytest.mark.parametrize('anyio_backend', ['asyncio']) async def test_start_native_child_cancelled(): async def taskfunc(*, task_status): @@ -395,6 +425,7 @@ async def test_fail_after(delay): with fail_after(delay) as scope: await sleep(1) + assert scope.cancelled_caught assert scope.cancel_called @@ -403,6 +434,7 @@ async def test_fail_after_no_timeout(): assert scope.deadline == float('inf') await sleep(0.1) + assert not scope.cancelled_caught assert not scope.cancel_called @@ -429,6 +461,7 @@ async def test_move_on_after(delay): result = True assert not result + assert scope.cancelled_caught assert scope.cancel_called @@ -440,6 +473,7 @@ async def test_move_on_after_no_timeout(): result = True assert result + assert not scope.cancelled_caught assert not scope.cancel_called @@ -456,6 +490,8 @@ async def test_nested_move_on_after(): assert not sleep_completed assert not inner_scope_completed + assert outer_scope.cancelled_caught + assert not inner_scope.cancelled_caught assert outer_scope.cancel_called assert not inner_scope.cancel_called @@ -478,6 +514,8 @@ async def cancel_when_ready(): assert inner_sleep_completed assert not outer_sleep_completed + assert tg.cancel_scope.cancelled_caught + assert not inner_scope.cancelled_caught assert tg.cancel_scope.cancel_called assert not inner_scope.cancel_called @@ -558,6 +596,7 @@ async def child(): host_done = True assert host_done + assert not tg.cancel_scope.cancelled_caught assert not tg.cancel_scope.cancel_called @@ -582,7 +621,11 @@ async def child(fail): async def test_cancel_cascade(): + tg2 = None + async def do_something(): + nonlocal tg2 + async with create_task_group() as tg2: tg2.start_soon(sleep, 1) @@ -593,6 +636,14 @@ async def do_something(): await wait_all_tasks_blocked() tg.cancel_scope.cancel() + assert not tg.cancel_scope.cancelled_caught + assert tg.cancel_scope.cancel_called + + assert tg.cancel_scope.cancelled_caught + assert tg.cancel_scope.cancel_called + assert not tg2.cancel_scope.cancelled_caught + assert tg2.cancel_scope.cancel_called + async def test_cancelled_parent(): async def child(): @@ -654,6 +705,7 @@ async def killer(scope): pytest.fail('Execution should also not reach this point') + assert scope.cancelled_caught assert scope.cancel_called @@ -884,3 +936,15 @@ async def exit_scope(scope): with pytest.raises(RuntimeError): async with create_task_group() as tg: tg.start_soon(exit_scope, scope) + + +async def test_raised(): + class SomeExc(Exception): + pass + + with pytest.raises(SomeExc): + async with CancelScope() as cancel_scope: + raise SomeExc + + assert not cancel_scope.cancelled_caught + assert not cancel_scope.cancel_called