Skip to content

Commit

Permalink
Add CancelScope.cancelled_caught
Browse files Browse the repository at this point in the history
Added `CancelScope.cancelled_caught` to match the Trio API. Relevant
Trio documentation
[here](https://trio.readthedocs.io/en/stable/reference-core.html#trio.CancelScope.cancelled_caught).

Closes agronholm#257
  • Loading branch information
johnzeringue committed May 7, 2021
1 parent 5169f1d commit 891d5f8
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ __pycache__
.cache
.local
.pre-commit-config.yaml
.vscode/
.python-version
9 changes: 9 additions & 0 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
self._deadline = deadline
self._shield = shield
self._parent_scope: Optional[CancelScope] = None
self._cancelled_caught = False
self._cancel_called = False
self._active = False
self._timeout_handle: Optional[asyncio.TimerHandle] = None
Expand Down Expand Up @@ -304,6 +305,8 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Ba
if exc_val is not None:
exceptions = exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val]
if all(isinstance(exc, CancelledError) for exc in exceptions):
self._cancelled_caught = self._cancel_called and not self._parent_cancelled()

if self._timeout_expired:
return True
elif not self._cancel_called:
Expand All @@ -312,6 +315,8 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Ba
elif not self._parent_cancelled():
# This scope was directly cancelled
return True
elif self._cancel_called:
self._cancelled_caught = True

return None

Expand Down Expand Up @@ -392,6 +397,10 @@ def deadline(self, value: float) -> None:
if self._active and not self._cancel_called:
self._timeout()

@property
def cancelled_caught(self) -> bool:
return self._cancelled_caught

@property
def cancel_called(self) -> bool:
return self._cancel_called
Expand Down
4 changes: 4 additions & 0 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def deadline(self) -> float:
def deadline(self, value: float) -> None:
self.__original.deadline = value

@property
def cancelled_caught(self) -> bool:
return self.__original.cancelled_caught

@property
def cancel_called(self) -> bool:
return self.__original.cancel_called
Expand Down
5 changes: 5 additions & 0 deletions src/anyio/_core/_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def deadline(self) -> float:
def deadline(self, value: float) -> None:
raise NotImplementedError

@property
def cancelled_caught(self) -> bool:
"""Records whether this scope caught a ``CancelledError``."""
raise NotImplementedError

@property
def cancel_called(self) -> bool:
"""``True`` if :meth:`cancel` has been called."""
Expand Down
64 changes: 64 additions & 0 deletions tests/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,11 @@ async def taskfunc(*, task_status):
finished = True

async def start_another():
nonlocal tg
async with create_task_group() as tg:
await tg.start(taskfunc)

tg = None
started = finished = False
task = asyncio.get_event_loop().create_task(start_another())
await wait_all_tasks_blocked()
Expand All @@ -179,6 +181,34 @@ async def start_another():
assert not finished


@pytest.mark.parametrize('anyio_backend', ['asyncio'])
async def test_start_native_host_cancelled_cancel_scope():
async def taskfunc():
nonlocal started, finished
started = True
await sleep(2)
finished = True

async def start_another():
nonlocal cancel_scope
async with CancelScope() as cancel_scope:
await taskfunc()

cancel_scope = None
started = finished = False
task = asyncio.get_event_loop().create_task(start_another())
await wait_all_tasks_blocked()
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task

assert started
assert not finished

assert not cancel_scope.cancelled_caught
assert not cancel_scope.cancel_called


@pytest.mark.parametrize('anyio_backend', ['asyncio'])
async def test_start_native_child_cancelled():
async def taskfunc(*, task_status):
Expand Down Expand Up @@ -395,6 +425,7 @@ async def test_fail_after(delay):
with fail_after(delay) as scope:
await sleep(1)

assert scope.cancelled_caught
assert scope.cancel_called


Expand All @@ -403,6 +434,7 @@ async def test_fail_after_no_timeout():
assert scope.deadline == float('inf')
await sleep(0.1)

assert not scope.cancelled_caught
assert not scope.cancel_called


Expand All @@ -429,6 +461,7 @@ async def test_move_on_after(delay):
result = True

assert not result
assert scope.cancelled_caught
assert scope.cancel_called


Expand All @@ -440,6 +473,7 @@ async def test_move_on_after_no_timeout():
result = True

assert result
assert not scope.cancelled_caught
assert not scope.cancel_called


Expand All @@ -456,6 +490,8 @@ async def test_nested_move_on_after():

assert not sleep_completed
assert not inner_scope_completed
assert outer_scope.cancelled_caught
assert not inner_scope.cancelled_caught
assert outer_scope.cancel_called
assert not inner_scope.cancel_called

Expand All @@ -478,6 +514,8 @@ async def cancel_when_ready():

assert inner_sleep_completed
assert not outer_sleep_completed
assert tg.cancel_scope.cancelled_caught
assert not inner_scope.cancelled_caught
assert tg.cancel_scope.cancel_called
assert not inner_scope.cancel_called

Expand Down Expand Up @@ -558,6 +596,7 @@ async def child():
host_done = True

assert host_done
assert not tg.cancel_scope.cancelled_caught
assert not tg.cancel_scope.cancel_called


Expand All @@ -582,7 +621,11 @@ async def child(fail):


async def test_cancel_cascade():
tg2 = None

async def do_something():
nonlocal tg2

async with create_task_group() as tg2:
tg2.start_soon(sleep, 1)

Expand All @@ -593,6 +636,14 @@ async def do_something():
await wait_all_tasks_blocked()
tg.cancel_scope.cancel()

assert not tg.cancel_scope.cancelled_caught
assert tg.cancel_scope.cancel_called

assert tg.cancel_scope.cancelled_caught
assert tg.cancel_scope.cancel_called
assert not tg2.cancel_scope.cancelled_caught
assert tg2.cancel_scope.cancel_called


async def test_cancelled_parent():
async def child():
Expand Down Expand Up @@ -654,6 +705,7 @@ async def killer(scope):

pytest.fail('Execution should also not reach this point')

assert scope.cancelled_caught
assert scope.cancel_called


Expand Down Expand Up @@ -884,3 +936,15 @@ async def exit_scope(scope):
with pytest.raises(RuntimeError):
async with create_task_group() as tg:
tg.start_soon(exit_scope, scope)


async def test_raised():
class SomeExc(Exception):
pass

with pytest.raises(SomeExc):
async with CancelScope() as cancel_scope:
raise SomeExc

assert not cancel_scope.cancelled_caught
assert not cancel_scope.cancel_called

0 comments on commit 891d5f8

Please sign in to comment.