Skip to content

Commit

Permalink
Add CancelScope.cancelled_caught
Browse files Browse the repository at this point in the history
Added `CancelScope.cancelled_caught` to match the Trio API. Relevant
Trio documentation
[here](https://trio.readthedocs.io/en/stable/reference-core.html#trio.CancelScope.cancelled_caught).

Closes agronholm#257
  • Loading branch information
johnzeringue committed May 10, 2021
1 parent 5169f1d commit f481e1c
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ __pycache__
.cache
.local
.pre-commit-config.yaml
.vscode/
15 changes: 12 additions & 3 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
self._deadline = deadline
self._shield = shield
self._parent_scope: Optional[CancelScope] = None
self._cancelled_caught = False
self._cancel_called = False
self._active = False
self._timeout_handle: Optional[asyncio.TimerHandle] = None
Expand Down Expand Up @@ -304,6 +305,8 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Ba
if exc_val is not None:
exceptions = exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val]
if all(isinstance(exc, CancelledError) for exc in exceptions):
self._cancelled_caught = self._cancel_called and not self._parent_cancelled()

if self._timeout_expired:
return True
elif not self._cancel_called:
Expand Down Expand Up @@ -392,6 +395,10 @@ def deadline(self, value: float) -> None:
if self._active and not self._cancel_called:
self._timeout()

@property
def cancelled_caught(self) -> bool:
return self._cancelled_caught

@property
def cancel_called(self) -> bool:
return self._cancel_called
Expand Down Expand Up @@ -504,17 +511,19 @@ async def __aenter__(self):
async def __aexit__(self, exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> Optional[bool]:
ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
if exc_val is not None:
self.cancel_scope.cancel()
self._exceptions.append(exc_val)

while self.cancel_scope._tasks:
while self.cancel_scope._tasks - set([self.cancel_scope._host_task]):
try:
await asyncio.wait(self.cancel_scope._tasks)
await asyncio.wait(self.cancel_scope._tasks - set([self.cancel_scope._host_task]))
except asyncio.CancelledError:
self.cancel_scope.cancel()

print(self._exceptions)
ignore_exception = self.cancel_scope.__exit__(exc_type, ExceptionGroup(self._exceptions), exc_tb)

self._active = False
if not self.cancel_scope._parent_cancelled():
exceptions = self._filter_cancellation_errors(self._exceptions)
Expand Down
4 changes: 4 additions & 0 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def deadline(self) -> float:
def deadline(self, value: float) -> None:
self.__original.deadline = value

@property
def cancelled_caught(self) -> bool:
return self.__original.cancelled_caught

@property
def cancel_called(self) -> bool:
return self.__original.cancel_called
Expand Down
5 changes: 5 additions & 0 deletions src/anyio/_core/_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def deadline(self) -> float:
def deadline(self, value: float) -> None:
raise NotImplementedError

@property
def cancelled_caught(self) -> bool:
"""Records whether this scope caught a ``CancelledError``."""
raise NotImplementedError

@property
def cancel_called(self) -> bool:
"""``True`` if :meth:`cancel` has been called."""
Expand Down
74 changes: 74 additions & 0 deletions tests/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,34 @@ async def start_another():
assert not finished


@pytest.mark.parametrize('anyio_backend', ['asyncio'])
async def test_start_native_host_cancelled_cancel_scope():
async def taskfunc():
nonlocal started, finished
started = True
await sleep(2)
finished = True

async def start_another():
nonlocal cancel_scope
async with CancelScope() as cancel_scope:
await taskfunc()

cancel_scope = None
started = finished = False
task = asyncio.get_event_loop().create_task(start_another())
await wait_all_tasks_blocked()
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task

assert started
assert not finished

assert not cancel_scope.cancelled_caught
assert not cancel_scope.cancel_called


@pytest.mark.parametrize('anyio_backend', ['asyncio'])
async def test_start_native_child_cancelled():
async def taskfunc(*, task_status):
Expand Down Expand Up @@ -395,6 +423,7 @@ async def test_fail_after(delay):
with fail_after(delay) as scope:
await sleep(1)

assert scope.cancelled_caught
assert scope.cancel_called


Expand All @@ -403,6 +432,7 @@ async def test_fail_after_no_timeout():
assert scope.deadline == float('inf')
await sleep(0.1)

assert not scope.cancelled_caught
assert not scope.cancel_called


Expand All @@ -429,6 +459,7 @@ async def test_move_on_after(delay):
result = True

assert not result
assert scope.cancelled_caught
assert scope.cancel_called


Expand All @@ -440,6 +471,7 @@ async def test_move_on_after_no_timeout():
result = True

assert result
assert not scope.cancelled_caught
assert not scope.cancel_called


Expand All @@ -456,7 +488,9 @@ async def test_nested_move_on_after():

assert not sleep_completed
assert not inner_scope_completed
assert outer_scope.cancelled_caught
assert outer_scope.cancel_called
assert not inner_scope.cancelled_caught
assert not inner_scope.cancel_called


Expand All @@ -478,7 +512,9 @@ async def cancel_when_ready():

assert inner_sleep_completed
assert not outer_sleep_completed
assert tg.cancel_scope.cancelled_caught
assert tg.cancel_scope.cancel_called
assert not inner_scope.cancelled_caught
assert not inner_scope.cancel_called


Expand Down Expand Up @@ -558,6 +594,7 @@ async def child():
host_done = True

assert host_done
assert not tg.cancel_scope.cancelled_caught
assert not tg.cancel_scope.cancel_called


Expand All @@ -583,16 +620,26 @@ async def child(fail):

async def test_cancel_cascade():
async def do_something():
nonlocal tg2
async with create_task_group() as tg2:
tg2.start_soon(sleep, 1)

raise Exception('foo')

tg2 = None
async with create_task_group() as tg:
tg.start_soon(do_something)
await wait_all_tasks_blocked()
tg.cancel_scope.cancel()

assert not tg.cancel_scope.cancelled_caught
assert tg.cancel_scope.cancel_called

assert tg.cancel_scope.cancelled_caught
assert tg.cancel_scope.cancel_called
assert not tg2.cancel_scope.cancelled_caught
assert tg2.cancel_scope.cancel_called


async def test_cancelled_parent():
async def child():
Expand Down Expand Up @@ -654,6 +701,7 @@ async def killer(scope):

pytest.fail('Execution should also not reach this point')

assert scope.cancelled_caught
assert scope.cancel_called


Expand Down Expand Up @@ -884,3 +932,29 @@ async def exit_scope(scope):
with pytest.raises(RuntimeError):
async with create_task_group() as tg:
tg.start_soon(exit_scope, scope)


async def test_raised():
with pytest.raises(ValueError):
async with CancelScope() as cancel_scope:
raise ValueError

assert not cancel_scope.cancelled_caught
assert not cancel_scope.cancel_called


async def test_cancel_called_but_not_caught():
async def task(*, task_status):
with CancelScope() as scope:
task_status.started(scope)
try:
await sleep(1)
except (asyncio.CancelledError, trio.Cancelled):
pass

async with create_task_group() as tg:
scope = await tg.start(task)
scope.cancel()

assert scope.cancel_called
assert not scope.cancelled_caught

0 comments on commit f481e1c

Please sign in to comment.