diff --git a/aiohttp/web.py b/aiohttp/web.py index e9116507f4e..8708f1fcbec 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -6,8 +6,6 @@ import warnings from argparse import ArgumentParser from collections.abc import Iterable -from contextlib import suppress -from functools import partial from importlib import import_module from typing import ( Any, @@ -21,7 +19,6 @@ Union, cast, ) -from weakref import WeakSet from .abc import AbstractAccessLogger from .helpers import AppKey as AppKey @@ -320,23 +317,6 @@ async def _run_app( reuse_port: Optional[bool] = None, handler_cancellation: bool = False, ) -> None: - async def wait( - starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float - ) -> None: - # Wait for pending tasks for a given time limit. - t = asyncio.current_task() - assert t is not None - starting_tasks.add(t) - with suppress(asyncio.TimeoutError): - await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout) - - async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None: - t = asyncio.current_task() - assert t is not None - exclude.add(t) - while tasks := asyncio.all_tasks().difference(exclude): - await asyncio.wait(tasks) - # An internal function to actually do all dirty job for application running if asyncio.iscoroutine(app): app = await app @@ -355,12 +335,6 @@ async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None: ) await runner.setup() - # On shutdown we want to avoid waiting on tasks which run forever. - # It's very likely that all tasks which run forever will have been created by - # the time we have completed the application startup (in runner.setup()), - # so we just record all running tasks here and exclude them later. - starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks()) - runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout) sites: List[BaseSite] = [] diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 88df4b31d24..d4ddbba55eb 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -262,7 +262,12 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None: if self._waiter: self._waiter.cancel() - # wait for handlers + # Wait for graceful disconnection + if self._current_request is not None: + with suppress(asyncio.CancelledError, asyncio.TimeoutError): + async with ceil_timeout(timeout): + await self._current_request.wait_for_disconnection() + # Then cancel handler and wait with suppress(asyncio.CancelledError, asyncio.TimeoutError): async with ceil_timeout(timeout): if self._current_request is not None: @@ -445,7 +450,6 @@ async def _handle_request( start_time: float, request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]], ) -> Tuple[StreamResponse, bool]: - assert self._request_handler is not None try: try: self._current_request = request diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index 19a4441658f..2fe229c4e50 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -3,7 +3,7 @@ import socket import warnings from abc import ABC, abstractmethod -from typing import Any, Awaitable, Callable, List, Optional, Set +from typing import Any, List, Optional, Set from yarl import URL @@ -238,14 +238,7 @@ async def start(self) -> None: class BaseRunner(ABC): - __slots__ = ( - "shutdown_callback", - "_handle_signals", - "_kwargs", - "_server", - "_sites", - "_shutdown_timeout", - ) + __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout") def __init__( self, @@ -254,7 +247,6 @@ def __init__( shutdown_timeout: float = 60.0, **kwargs: Any, ) -> None: - self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None self._handle_signals = handle_signals self._kwargs = kwargs self._server: Optional[Server] = None @@ -312,10 +304,6 @@ async def cleanup(self) -> None: await asyncio.sleep(0) self._server.pre_shutdown() await self.shutdown() - - if self.shutdown_callback: - await self.shutdown_callback() - await self._server.shutdown(self._shutdown_timeout) await self._cleanup_server() diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index f6bbdb89a77..ffc198d5780 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -43,7 +43,12 @@ def connection_lost( self, handler: RequestHandler, exc: Optional[BaseException] = None ) -> None: if handler in self._connections: - del self._connections[handler] + if handler._task_handler: + handler._task_handler.add_done_callback( + lambda f: self._connections.pop(handler, None) + ) + else: + del self._connections[handler] def _make_request( self, diff --git a/tests/test_run_app.py b/tests/test_run_app.py index 5696928b219..eb69d620ced 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -15,7 +15,7 @@ import pytest -from aiohttp import ClientConnectorError, ClientSession, WSCloseCode, web +from aiohttp import ClientConnectorError, ClientSession, ClientTimeout, WSCloseCode, web from aiohttp.test_utils import make_mocked_coro from aiohttp.web_runner import BaseRunner @@ -920,8 +920,12 @@ async def test() -> None: async with ClientSession() as sess: for _ in range(5): # pragma: no cover try: - async with sess.get(f"http://localhost:{port}/"): - pass + with pytest.raises(asyncio.TimeoutError): + async with sess.get( + f"http://localhost:{port}/", + timeout=ClientTimeout(total=0.1), + ): + pass except ClientConnectorError: await asyncio.sleep(0.5) else: @@ -941,6 +945,7 @@ async def run_test(app: web.Application) -> None: async def handler(request: web.Request) -> web.Response: nonlocal t t = asyncio.create_task(task()) + await t return web.Response(text="FOO") t = test_task = None @@ -953,7 +958,7 @@ async def handler(request: web.Request) -> web.Response: assert test_task.exception() is None return t - def test_shutdown_wait_for_task( + def test_shutdown_wait_for_handler( self, aiohttp_unused_port: Callable[[], int] ) -> None: port = aiohttp_unused_port() @@ -970,7 +975,7 @@ async def task(): assert t.done() assert not t.cancelled() - def test_shutdown_timeout_task( + def test_shutdown_timeout_handler( self, aiohttp_unused_port: Callable[[], int] ) -> None: port = aiohttp_unused_port() @@ -987,34 +992,6 @@ async def task(): assert t.done() assert t.cancelled() - def test_shutdown_wait_for_spawned_task( - self, aiohttp_unused_port: Callable[[], int] - ) -> None: - port = aiohttp_unused_port() - finished = False - finished_sub = False - sub_t = None - - async def sub_task(): - nonlocal finished_sub - await asyncio.sleep(1.5) - finished_sub = True - - async def task(): - nonlocal finished, sub_t - await asyncio.sleep(0.5) - sub_t = asyncio.create_task(sub_task()) - finished = True - - t = self.run_app(port, 3, task) - - assert finished is True - assert t.done() - assert not t.cancelled() - assert finished_sub is True - assert sub_t.done() - assert not sub_t.cancelled() - def test_shutdown_timeout_not_reached( self, aiohttp_unused_port: Callable[[], int] ) -> None: diff --git a/tests/test_web_request_handler.py b/tests/test_web_request_handler.py index 06f99be76c0..4837cab030e 100644 --- a/tests/test_web_request_handler.py +++ b/tests/test_web_request_handler.py @@ -22,19 +22,21 @@ async def test_connections() -> None: manager = web.Server(serve) assert manager.connections == [] - handler = object() + handler = mock.Mock(spec_set=web.RequestHandler) + handler._task_handler = None transport = object() manager.connection_made(handler, transport) # type: ignore[arg-type] assert manager.connections == [handler] - manager.connection_lost(handler, None) # type: ignore[arg-type] + manager.connection_lost(handler, None) assert manager.connections == [] async def test_shutdown_no_timeout() -> None: manager = web.Server(serve) - handler = mock.Mock() + handler = mock.Mock(spec_set=web.RequestHandler) + handler._task_handler = None handler.shutdown = make_mocked_coro(mock.Mock()) transport = mock.Mock() manager.connection_made(handler, transport)