Skip to content

Commit

Permalink
[3.11] gh-111085: Fix invalid state handling in TaskGroup and Timeout (
Browse files Browse the repository at this point in the history
…GH-111111) (GH-111172)

asyncio.TaskGroup and asyncio.Timeout classes now raise proper RuntimeError
if they are improperly used.

* When they are used without entering the context manager.
* When they are used after finishing.
* When the context manager is entered more than once (simultaneously or
  sequentially).
* If there is no current task when entering the context manager.

They now remain in a consistent state after an exception is thrown,
so subsequent operations can be performed correctly (if they are allowed).

(cherry picked from commit 6c23635)

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
  • Loading branch information
3 people committed Oct 21, 2023
1 parent cf28c61 commit cf77739
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 9 deletions.
6 changes: 2 additions & 4 deletions Lib/asyncio/taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,14 @@ def __repr__(self):
async def __aenter__(self):
if self._entered:
raise RuntimeError(
f"TaskGroup {self!r} has been already entered")
self._entered = True

f"TaskGroup {self!r} has already been entered")
if self._loop is None:
self._loop = events.get_running_loop()

self._parent_task = tasks.current_task(self._loop)
if self._parent_task is None:
raise RuntimeError(
f'TaskGroup {self!r} cannot determine the parent task')
self._entered = True

return self

Expand Down
12 changes: 8 additions & 4 deletions Lib/asyncio/timeouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def when(self) -> Optional[float]:

def reschedule(self, when: Optional[float]) -> None:
"""Reschedule the timeout."""
assert self._state is not _State.CREATED
if self._state is not _State.ENTERED:
if self._state is _State.CREATED:
raise RuntimeError("Timeout has not been entered")
raise RuntimeError(
f"Cannot change state of {self._state.value} Timeout",
)
Expand Down Expand Up @@ -82,11 +83,14 @@ def __repr__(self) -> str:
return f"<Timeout [{self._state.value}]{info_str}>"

async def __aenter__(self) -> "Timeout":
if self._state is not _State.CREATED:
raise RuntimeError("Timeout has already been entered")
task = tasks.current_task()
if task is None:
raise RuntimeError("Timeout should be used inside a task")
self._state = _State.ENTERED
self._task = tasks.current_task()
self._task = task
self._cancelling = self._task.cancelling()
if self._task is None:
raise RuntimeError("Timeout should be used inside a task")
self.reschedule(self._when)
return self

Expand Down
45 changes: 45 additions & 0 deletions Lib/test/test_asyncio/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from asyncio import taskgroups
import unittest

from test.test_asyncio.utils import await_without_task


# To prevent a warning "test altered the execution environment"
def tearDownModule():
Expand Down Expand Up @@ -779,6 +781,49 @@ async def main():

await asyncio.create_task(main())

async def test_taskgroup_already_entered(self):
tg = taskgroups.TaskGroup()
async with tg:
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with tg:
pass

async def test_taskgroup_double_enter(self):
tg = taskgroups.TaskGroup()
async with tg:
pass
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with tg:
pass

async def test_taskgroup_finished(self):
tg = taskgroups.TaskGroup()
async with tg:
pass
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "is finished"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro

async def test_taskgroup_not_entered(self):
tg = taskgroups.TaskGroup()
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro

async def test_taskgroup_without_parent_task(self):
tg = taskgroups.TaskGroup()
with self.assertRaisesRegex(RuntimeError, "parent task"):
await await_without_task(tg.__aenter__())
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro


if __name__ == "__main__":
unittest.main()
48 changes: 47 additions & 1 deletion Lib/test/test_asyncio/test_timeouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
import asyncio
from asyncio import tasks

from test.test_asyncio.utils import await_without_task


def tearDownModule():
asyncio.set_event_loop_policy(None)


class TimeoutTests(unittest.IsolatedAsyncioTestCase):

async def test_timeout_basic(self):
Expand Down Expand Up @@ -258,6 +259,51 @@ async def test_timeout_exception_cause (self):
cause = exc.exception.__cause__
assert isinstance(cause, asyncio.CancelledError)

async def test_timeout_already_entered(self):
async with asyncio.timeout(0.01) as cm:
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with cm:
pass

async def test_timeout_double_enter(self):
async with asyncio.timeout(0.01) as cm:
pass
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with cm:
pass

async def test_timeout_finished(self):
async with asyncio.timeout(0.01) as cm:
pass
with self.assertRaisesRegex(RuntimeError, "finished"):
cm.reschedule(0.02)

async def test_timeout_expired(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01) as cm:
await asyncio.sleep(1)
with self.assertRaisesRegex(RuntimeError, "expired"):
cm.reschedule(0.02)

async def test_timeout_expiring(self):
async with asyncio.timeout(0.01) as cm:
with self.assertRaises(asyncio.CancelledError):
await asyncio.sleep(1)
with self.assertRaisesRegex(RuntimeError, "expiring"):
cm.reschedule(0.02)

async def test_timeout_not_entered(self):
cm = asyncio.timeout(0.01)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
cm.reschedule(0.02)

async def test_timeout_without_task(self):
cm = asyncio.timeout(0.01)
with self.assertRaisesRegex(RuntimeError, "task"):
await await_without_task(cm.__aenter__())
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
cm.reschedule(0.02)


if __name__ == '__main__':
unittest.main()
15 changes: 15 additions & 0 deletions Lib/test/test_asyncio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,18 @@ def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
sock.family = family
sock.gettimeout.return_value = 0.0
return sock


async def await_without_task(coro):
exc = None
def func():
try:
for _ in coro.__await__():
pass
except BaseException as err:
nonlocal exc
exc = err
asyncio.get_running_loop().call_soon(func)
await asyncio.sleep(0)
if exc is not None:
raise exc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix invalid state handling in :class:`asyncio.TaskGroup` and
:class:`asyncio.Timeout`. They now raise proper RuntimeError if they are
improperly used and are left in consistent state after this.

0 comments on commit cf77739

Please sign in to comment.