Skip to content

Commit

Permalink
Add CancelScope.cancelled_caught
Browse files Browse the repository at this point in the history
Added `CancelScope.cancelled_caught` to match the Trio API. Relevant
Trio documentation
[here](https://trio.readthedocs.io/en/stable/reference-core.html#trio.CancelScope.cancelled_caught).

Refactored `TaskGroup.__aexit__` to match Trio's structured concurrency
model.

Closes agronholm#257
  • Loading branch information
johnzeringue committed May 12, 2021
1 parent b8ab9b6 commit ec3d18c
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ __pycache__
.cache
.local
.pre-commit-config.yaml
.vscode/
46 changes: 34 additions & 12 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
self._deadline = deadline
self._shield = shield
self._parent_scope: Optional[CancelScope] = None
self._cancelled_caught = False
self._cancel_called = False
self._active = False
self._timeout_handle: Optional[asyncio.TimerHandle] = None
Expand All @@ -262,6 +263,9 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
self._host_task: Optional[asyncio.Task] = None
self._timeout_expired = False

def _child_tasks(self):
return self._tasks - set([self._host_task])

def __enter__(self):
if self._active:
raise RuntimeError(
Expand All @@ -284,8 +288,7 @@ def __enter__(self):
self._active = True
return self

def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> Optional[bool]:
def _close(self, exc_val: Optional[BaseException]) -> Optional[bool]:
if not self._active:
raise RuntimeError('This cancel scope is not active')
if current_task() is not self._host_task:
Expand Down Expand Up @@ -315,6 +318,8 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Ba
if exc_val is not None:
exceptions = exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val]
if all(isinstance(exc, CancelledError) for exc in exceptions):
self._cancelled_caught = self._cancel_called and not self._parent_cancelled()

if self._timeout_expired:
return True
elif not self._cancel_called:
Expand All @@ -326,6 +331,10 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Ba

return None

def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> Optional[bool]:
return self._close(exc_val)

def _timeout(self):
if self._deadline != math.inf:
loop = get_running_loop()
Expand Down Expand Up @@ -391,17 +400,25 @@ def _parent_cancelled(self) -> bool:

return False

def cancel(self) -> DeprecatedAwaitable:
def _do_cancel(self, *, cancel_called) -> DeprecatedAwaitable:
if not self._cancel_called:
if self._timeout_handle:
self._timeout_handle.cancel()
self._timeout_handle = None

self._cancel_called = True
self._cancel_called |= cancel_called
self._deliver_cancellation()

return DeprecatedAwaitable(self.cancel)

def cancel(self) -> DeprecatedAwaitable:
return self._do_cancel(cancel_called=True)

def _cancel_exc(self, exc_val: BaseException) -> None:
maybe_native_cancellation = isinstance(
exc_val, CancelledError) and not self._parent_cancelled()
self._do_cancel(cancel_called=not maybe_native_cancellation)

@property
def deadline(self) -> float:
return self._deadline
Expand All @@ -416,6 +433,10 @@ def deadline(self, value: float) -> None:
if self._active and not self._cancel_called:
self._timeout()

@property
def cancelled_caught(self) -> bool:
return self._cancelled_caught

@property
def cancel_called(self) -> bool:
return self._cancel_called
Expand Down Expand Up @@ -535,16 +556,17 @@ async def __aenter__(self):
async def __aexit__(self, exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> Optional[bool]:
ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
if exc_val is not None:
self.cancel_scope.cancel()
self.cancel_scope._cancel_exc(exc_val)
self._exceptions.append(exc_val)

while self.cancel_scope._tasks:
while self.cancel_scope._child_tasks():
try:
await asyncio.wait(self.cancel_scope._tasks)
except asyncio.CancelledError:
self.cancel_scope.cancel()
await asyncio.wait(self.cancel_scope._child_tasks())
except asyncio.CancelledError as exc:
self.cancel_scope._cancel_exc(exc)

ignore_exception = self.cancel_scope._close(ExceptionGroup(self._exceptions))

self._active = False
if not self.cancel_scope._parent_cancelled():
Expand Down Expand Up @@ -601,7 +623,7 @@ async def _run_wrapped_task(
except BaseException as exc:
if task_status_future is None or task_status_future.done():
self._exceptions.append(exc)
self.cancel_scope.cancel()
self.cancel_scope._cancel_exc(exc)
else:
task_status_future.set_exception(exc)
else:
Expand Down Expand Up @@ -632,7 +654,7 @@ def task_done(_task: asyncio.Task) -> None:
if exc is not None:
if task_status_future is None or task_status_future.done():
self._exceptions.append(exc)
self.cancel_scope.cancel()
self.cancel_scope._cancel_exc(exc)
else:
task_status_future.set_exception(exc)
elif task_status_future is not None and not task_status_future.done():
Expand Down
4 changes: 4 additions & 0 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def deadline(self) -> float:
def deadline(self, value: float) -> None:
self.__original.deadline = value

@property
def cancelled_caught(self) -> bool:
return self.__original.cancelled_caught

@property
def cancel_called(self) -> bool:
return self.__original.cancel_called
Expand Down
5 changes: 5 additions & 0 deletions src/anyio/_core/_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def deadline(self) -> float:
def deadline(self, value: float) -> None:
raise NotImplementedError

@property
def cancelled_caught(self) -> bool:
"""Records whether this scope caught a ``CancelledError``."""
raise NotImplementedError

@property
def cancel_called(self) -> bool:
"""``True`` if :meth:`cancel` has been called."""
Expand Down
84 changes: 84 additions & 0 deletions tests/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,27 @@ async def start_another():
assert not finished


@pytest.mark.parametrize('anyio_backend', ['asyncio'])
async def test_start_native_host_cancelled_cancel_scope():
async def taskfunc():
await sleep(2)

async def start_another():
nonlocal cancel_scope
with CancelScope() as cancel_scope:
await taskfunc()

cancel_scope = None
task = asyncio.get_event_loop().create_task(start_another())
await wait_all_tasks_blocked()
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task

assert not cancel_scope.cancelled_caught
assert not cancel_scope.cancel_called


@pytest.mark.parametrize('anyio_backend', ['asyncio'])
async def test_start_native_child_cancelled():
async def taskfunc(*, task_status):
Expand Down Expand Up @@ -395,6 +416,7 @@ async def test_fail_after(delay):
with fail_after(delay) as scope:
await sleep(1)

assert scope.cancelled_caught
assert scope.cancel_called


Expand All @@ -403,6 +425,7 @@ async def test_fail_after_no_timeout():
assert scope.deadline == float('inf')
await sleep(0.1)

assert not scope.cancelled_caught
assert not scope.cancel_called


Expand All @@ -429,6 +452,7 @@ async def test_move_on_after(delay):
result = True

assert not result
assert scope.cancelled_caught
assert scope.cancel_called


Expand All @@ -440,6 +464,7 @@ async def test_move_on_after_no_timeout():
result = True

assert result
assert not scope.cancelled_caught
assert not scope.cancel_called


Expand All @@ -456,7 +481,9 @@ async def test_nested_move_on_after():

assert not sleep_completed
assert not inner_scope_completed
assert outer_scope.cancelled_caught
assert outer_scope.cancel_called
assert not inner_scope.cancelled_caught
assert not inner_scope.cancel_called


Expand All @@ -478,7 +505,9 @@ async def cancel_when_ready():

assert inner_sleep_completed
assert not outer_sleep_completed
assert tg.cancel_scope.cancelled_caught
assert tg.cancel_scope.cancel_called
assert not inner_scope.cancelled_caught
assert not inner_scope.cancel_called


Expand Down Expand Up @@ -581,6 +610,7 @@ async def child():
host_done = True

assert host_done
assert not tg.cancel_scope.cancelled_caught
assert not tg.cancel_scope.cancel_called


Expand Down Expand Up @@ -617,6 +647,29 @@ async def do_something():
tg.cancel_scope.cancel()


async def test_cancel_nested_scopes():
async def do_something():
nonlocal tg2
async with create_task_group() as tg2:
tg2.start_soon(sleep, 1)

raise Exception('foo')

tg2 = None
async with create_task_group() as tg:
tg.start_soon(do_something)
await wait_all_tasks_blocked()
tg.cancel_scope.cancel()

assert not tg.cancel_scope.cancelled_caught
assert tg.cancel_scope.cancel_called

assert tg.cancel_scope.cancelled_caught
assert tg.cancel_scope.cancel_called
assert not tg2.cancel_scope.cancelled_caught
assert tg2.cancel_scope.cancel_called


async def test_cancelled_parent():
async def child():
with CancelScope():
Expand Down Expand Up @@ -677,6 +730,7 @@ async def killer(scope):

pytest.fail('Execution should also not reach this point')

assert scope.cancelled_caught
assert scope.cancel_called


Expand Down Expand Up @@ -948,3 +1002,33 @@ async def taskfunc(*, task_status):
tg.start_soon(sleep, 5)
await wait_all_tasks_blocked()
task.cancel('blah')


async def test_cancel_called_but_not_caught():
async def task(*, task_status):
with CancelScope() as scope:
task_status.started(scope)
try:
await sleep(1)
except (asyncio.CancelledError, trio.Cancelled):
pass

async with create_task_group() as tg:
scope = await tg.start(task)
scope.cancel()

assert scope.cancel_called
assert not scope.cancelled_caught


async def test_scope_active_until_children_finish():
async def task(scope):
await sleep(1)

# should raise RuntimeError since scope is already active
with scope:
...

with pytest.raises(RuntimeError):
async with create_task_group() as tg:
tg.start_soon(task, tg.cancel_scope)

0 comments on commit ec3d18c

Please sign in to comment.