From be27f7926f0da9b4129d98cf628eea99dc9e3411 Mon Sep 17 00:00:00 2001 From: "patchback[bot]" <45432694+patchback[bot]@users.noreply.github.com> Date: Sun, 1 Sep 2024 22:18:00 +0100 Subject: [PATCH] [PR #8966/f569894c backport][3.11] Fix auth reset logic during redirects to different origin when _base_url set (#8976) **This is a backport of PR #8966 as merged into master (f569894caa7cfbc2ec03fb5eed6021b9899dc4b4).** --------- Co-authored-by: Maxim Zemskov Co-authored-by: Sam Bull --- CHANGES/8966.feature.rst | 1 + aiohttp/client.py | 5 +- docs/client_reference.rst | 8 +- tests/test_client_functional.py | 132 ++++++++++++++++++++++++++++++++ 4 files changed, 142 insertions(+), 4 deletions(-) create mode 100644 CHANGES/8966.feature.rst diff --git a/CHANGES/8966.feature.rst b/CHANGES/8966.feature.rst new file mode 100644 index 00000000000..ab1dc45b60e --- /dev/null +++ b/CHANGES/8966.feature.rst @@ -0,0 +1 @@ +Updated ClientSession's auth logic to include default auth only if the request URL's origin matches _base_url; otherwise, the auth will not be included -- by :user:`MaximZemskov` diff --git a/aiohttp/client.py b/aiohttp/client.py index f9e3a5c5f65..f3c60d31f08 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -597,7 +597,10 @@ async def _request( if auth is None: auth = auth_from_url - if auth is None: + + if auth is None and ( + not self._base_url or self._base_url.origin() == url.origin() + ): auth = self._default_auth # It would be confusing if we support explicit # Authorization header with auth argument diff --git a/docs/client_reference.rst b/docs/client_reference.rst index f64df336755..afad40e2d83 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -100,9 +100,11 @@ The client session supports the context manager protocol for self closing. :param aiohttp.BasicAuth auth: an object that represents HTTP Basic Authorization (optional). It will be included - with any request to any origin and will not be - removed, event during redirect to a different - origin. + with any request. However, if the + ``_base_url`` parameter is set, the request + URL's origin must match the base URL's origin; + otherwise, the default auth will not be + included. :param version: supported HTTP version, ``HTTP 1.1`` by default. diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 1f9173bd3f7..c7c31c739b1 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -2905,6 +2905,138 @@ async def close(self): assert resp.status == 200 +async def test_auth_persist_on_redirect_to_other_host_with_global_auth( + create_server_for_url_and_handler, +) -> None: + url_from = URL("http://host1.com/path1") + url_to = URL("http://host2.com/path2") + + async def srv_from(request: web.Request): + assert request.host == url_from.host + assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz" + raise web.HTTPFound(url_to) + + async def srv_to(request: web.Request) -> web.Response: + assert request.host == url_to.host + assert "Authorization" in request.headers, "Header was dropped" + return web.Response() + + server_from = await create_server_for_url_and_handler(url_from, srv_from) + server_to = await create_server_for_url_and_handler(url_to, srv_to) + + assert ( + url_from.host != url_to.host or server_from.scheme != server_to.scheme + ), "Invalid test case, host or scheme must differ" + + protocol_port_map = { + "http": 80, + "https": 443, + } + etc_hosts = { + (url_from.host, protocol_port_map[server_from.scheme]): server_from, + (url_to.host, protocol_port_map[server_to.scheme]): server_to, + } + + class FakeResolver(AbstractResolver): + async def resolve( + self, + host: str, + port: int = 0, + family: socket.AddressFamily = socket.AF_INET, + ): + server = etc_hosts[(host, port)] + assert server.port is not None + + return [ + { + "hostname": host, + "host": server.host, + "port": server.port, + "family": socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + ] + + async def close(self) -> None: + """Dummy""" + + connector = aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False) + + async with aiohttp.ClientSession( + connector=connector, auth=aiohttp.BasicAuth("user", "pass") + ) as client: + resp = await client.get(url_from) + assert resp.status == 200 + + +async def test_drop_auth_on_redirect_to_other_host_with_global_auth_and_base_url( + create_server_for_url_and_handler, +) -> None: + url_from = URL("http://host1.com/path1") + url_to = URL("http://host2.com/path2") + + async def srv_from(request: web.Request): + assert request.host == url_from.host + assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz" + raise web.HTTPFound(url_to) + + async def srv_to(request: web.Request) -> web.Response: + assert request.host == url_to.host + assert "Authorization" not in request.headers, "Header was not dropped" + return web.Response() + + server_from = await create_server_for_url_and_handler(url_from, srv_from) + server_to = await create_server_for_url_and_handler(url_to, srv_to) + + assert ( + url_from.host != url_to.host or server_from.scheme != server_to.scheme + ), "Invalid test case, host or scheme must differ" + + protocol_port_map = { + "http": 80, + "https": 443, + } + etc_hosts = { + (url_from.host, protocol_port_map[server_from.scheme]): server_from, + (url_to.host, protocol_port_map[server_to.scheme]): server_to, + } + + class FakeResolver(AbstractResolver): + async def resolve( + self, + host: str, + port: int = 0, + family: socket.AddressFamily = socket.AF_INET, + ): + server = etc_hosts[(host, port)] + assert server.port is not None + + return [ + { + "hostname": host, + "host": server.host, + "port": server.port, + "family": socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + ] + + async def close(self) -> None: + """Dummy""" + + connector = aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False) + + async with aiohttp.ClientSession( + connector=connector, + base_url="http://host1.com", + auth=aiohttp.BasicAuth("user", "pass"), + ) as client: + resp = await client.get("/path1") + assert resp.status == 200 + + async def test_async_with_session() -> None: async with aiohttp.ClientSession() as session: pass