Skip to content

Commit

Permalink
[PR #8495/549c95b9 backport][3.10] Shutdown logic: Only wait on handl…
Browse files Browse the repository at this point in the history
…ers (#8530)

Co-authored-by: pre-commit-ci[bot]
Co-authored-by: J. Nick Koston <nick@koston.org>
Co-authored-by: Sam Bull <git@sambull.org>
  • Loading branch information
bdraco and Dreamsorcerer authored Jul 22, 2024
1 parent be2a8bf commit 6f17a67
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 79 deletions.
26 changes: 0 additions & 26 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import warnings
from argparse import ArgumentParser
from collections.abc import Iterable
from contextlib import suppress
from functools import partial
from importlib import import_module
from typing import (
Any,
Expand All @@ -21,7 +19,6 @@
Union,
cast,
)
from weakref import WeakSet

from .abc import AbstractAccessLogger
from .helpers import AppKey as AppKey
Expand Down Expand Up @@ -320,23 +317,6 @@ async def _run_app(
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
) -> None:
async def wait(
starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float
) -> None:
# Wait for pending tasks for a given time limit.
t = asyncio.current_task()
assert t is not None
starting_tasks.add(t)
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout)

async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None:
t = asyncio.current_task()
assert t is not None
exclude.add(t)
while tasks := asyncio.all_tasks().difference(exclude):
await asyncio.wait(tasks)

# An internal function to actually do all dirty job for application running
if asyncio.iscoroutine(app):
app = await app
Expand All @@ -355,12 +335,6 @@ async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None:
)

await runner.setup()
# On shutdown we want to avoid waiting on tasks which run forever.
# It's very likely that all tasks which run forever will have been created by
# the time we have completed the application startup (in runner.setup()),
# so we just record all running tasks here and exclude them later.
starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks())
runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout)

sites: List[BaseSite] = []

Expand Down
8 changes: 6 additions & 2 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,12 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
if self._waiter:
self._waiter.cancel()

# wait for handlers
# Wait for graceful disconnection
if self._current_request is not None:
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
async with ceil_timeout(timeout):
await self._current_request.wait_for_disconnection()
# Then cancel handler and wait
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
async with ceil_timeout(timeout):
if self._current_request is not None:
Expand Down Expand Up @@ -445,7 +450,6 @@ async def _handle_request(
start_time: float,
request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
) -> Tuple[StreamResponse, bool]:
assert self._request_handler is not None
try:
try:
self._current_request = request
Expand Down
16 changes: 2 additions & 14 deletions aiohttp/web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import socket
import warnings
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable, List, Optional, Set
from typing import Any, List, Optional, Set

from yarl import URL

Expand Down Expand Up @@ -238,14 +238,7 @@ async def start(self) -> None:


class BaseRunner(ABC):
__slots__ = (
"shutdown_callback",
"_handle_signals",
"_kwargs",
"_server",
"_sites",
"_shutdown_timeout",
)
__slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout")

def __init__(
self,
Expand All @@ -254,7 +247,6 @@ def __init__(
shutdown_timeout: float = 60.0,
**kwargs: Any,
) -> None:
self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None
self._handle_signals = handle_signals
self._kwargs = kwargs
self._server: Optional[Server] = None
Expand Down Expand Up @@ -312,10 +304,6 @@ async def cleanup(self) -> None:
await asyncio.sleep(0)
self._server.pre_shutdown()
await self.shutdown()

if self.shutdown_callback:
await self.shutdown_callback()

await self._server.shutdown(self._shutdown_timeout)
await self._cleanup_server()

Expand Down
7 changes: 6 additions & 1 deletion aiohttp/web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ def connection_lost(
self, handler: RequestHandler, exc: Optional[BaseException] = None
) -> None:
if handler in self._connections:
del self._connections[handler]
if handler._task_handler:
handler._task_handler.add_done_callback(
lambda f: self._connections.pop(handler, None)
)
else:
del self._connections[handler]

def _make_request(
self,
Expand Down
43 changes: 10 additions & 33 deletions tests/test_run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import pytest

from aiohttp import ClientConnectorError, ClientSession, WSCloseCode, web
from aiohttp import ClientConnectorError, ClientSession, ClientTimeout, WSCloseCode, web
from aiohttp.test_utils import make_mocked_coro
from aiohttp.web_runner import BaseRunner

Expand Down Expand Up @@ -920,8 +920,12 @@ async def test() -> None:
async with ClientSession() as sess:
for _ in range(5): # pragma: no cover
try:
async with sess.get(f"http://localhost:{port}/"):
pass
with pytest.raises(asyncio.TimeoutError):
async with sess.get(
f"http://localhost:{port}/",
timeout=ClientTimeout(total=0.1),
):
pass
except ClientConnectorError:
await asyncio.sleep(0.5)
else:
Expand All @@ -941,6 +945,7 @@ async def run_test(app: web.Application) -> None:
async def handler(request: web.Request) -> web.Response:
nonlocal t
t = asyncio.create_task(task())
await t
return web.Response(text="FOO")

t = test_task = None
Expand All @@ -953,7 +958,7 @@ async def handler(request: web.Request) -> web.Response:
assert test_task.exception() is None
return t

def test_shutdown_wait_for_task(
def test_shutdown_wait_for_handler(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
Expand All @@ -970,7 +975,7 @@ async def task():
assert t.done()
assert not t.cancelled()

def test_shutdown_timeout_task(
def test_shutdown_timeout_handler(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
Expand All @@ -987,34 +992,6 @@ async def task():
assert t.done()
assert t.cancelled()

def test_shutdown_wait_for_spawned_task(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
finished = False
finished_sub = False
sub_t = None

async def sub_task():
nonlocal finished_sub
await asyncio.sleep(1.5)
finished_sub = True

async def task():
nonlocal finished, sub_t
await asyncio.sleep(0.5)
sub_t = asyncio.create_task(sub_task())
finished = True

t = self.run_app(port, 3, task)

assert finished is True
assert t.done()
assert not t.cancelled()
assert finished_sub is True
assert sub_t.done()
assert not sub_t.cancelled()

def test_shutdown_timeout_not_reached(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
Expand Down
8 changes: 5 additions & 3 deletions tests/test_web_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,21 @@ async def test_connections() -> None:
manager = web.Server(serve)
assert manager.connections == []

handler = object()
handler = mock.Mock(spec_set=web.RequestHandler)
handler._task_handler = None
transport = object()
manager.connection_made(handler, transport) # type: ignore[arg-type]
assert manager.connections == [handler]

manager.connection_lost(handler, None) # type: ignore[arg-type]
manager.connection_lost(handler, None)
assert manager.connections == []


async def test_shutdown_no_timeout() -> None:
manager = web.Server(serve)

handler = mock.Mock()
handler = mock.Mock(spec_set=web.RequestHandler)
handler._task_handler = None
handler.shutdown = make_mocked_coro(mock.Mock())
transport = mock.Mock()
manager.connection_made(handler, transport)
Expand Down

0 comments on commit 6f17a67

Please sign in to comment.