Skip to content

Commit

Permalink
Preprocess async fixtures for each event loop
Browse files Browse the repository at this point in the history
Caching of fixture preprocessing is now also keyed by event loop id,
sync fixtures can be processed if they are wrapping a async fixture,
each async fixture has a mapping from root scope names to fixture id
that is now used to dynamically get the event loop fixture.

This fixes pytest-dev#862.
  • Loading branch information
cstruct committed Aug 5, 2024
1 parent cbb39a0 commit ee75d65
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 31 deletions.
92 changes: 61 additions & 31 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def _preprocess_async_fixtures(
for fixtures in fixturemanager._arg2fixturedefs.values():
for fixturedef in fixtures:
func = fixturedef.func
if fixturedef in processed_fixturedefs or not _is_coroutine_or_asyncgen(
func
if not _is_coroutine_or_asyncgen(func) and not getattr(
func, "_async_fixture", False
):
continue
if not _is_asyncio_fixture_function(func) and asyncio_mode == Mode.STRICT:
Expand All @@ -252,14 +252,21 @@ def _preprocess_async_fixtures(
or fixturedef.scope
)
if scope == "function":
event_loop_fixture_name = "function"
event_loop_fixture_id: Optional[str] = "event_loop"
else:
event_loop_node = _retrieve_scope_root(collector, scope)
try:
event_loop_node = _retrieve_scope_root(collector, scope)
except Exception:
continue
event_loop_fixture_name = event_loop_node.name
event_loop_fixture_id = event_loop_node.stash.get(
# Type ignored because of non-optimal mypy inference.
_event_loop_fixture_id, # type: ignore[arg-type]
None,
)
if (fixturedef, event_loop_fixture_id) in processed_fixturedefs:
continue
_make_asyncio_fixture_function(func, scope)
function_signature = inspect.signature(func)
if "event_loop" in function_signature.parameters:
Expand All @@ -272,42 +279,33 @@ def _preprocess_async_fixtures(
)
)
assert event_loop_fixture_id
_inject_fixture_argnames(
fixturedef,
event_loop_fixture_id,
)
if "request" not in fixturedef.argnames:
fixturedef.argnames += ("request",)
_synchronize_async_fixture(
fixturedef,
event_loop_fixture_name,
event_loop_fixture_id,
)
assert _is_asyncio_fixture_function(fixturedef.func)
processed_fixturedefs.add(fixturedef)


def _inject_fixture_argnames(
fixturedef: FixtureDef, event_loop_fixture_id: str
) -> None:
"""
Ensures that `request` and `event_loop` are arguments of the specified fixture.
"""
to_add = []
for name in ("request", event_loop_fixture_id):
if name not in fixturedef.argnames:
to_add.append(name)
if to_add:
fixturedef.argnames += tuple(to_add)
processed_fixturedefs.add((fixturedef, event_loop_fixture_id))


def _synchronize_async_fixture(
fixturedef: FixtureDef, event_loop_fixture_id: str
fixturedef: FixtureDef, event_loop_fixture_name: str, event_loop_fixture_id: str
) -> None:
"""
Wraps the fixture function of an async fixture in a synchronous function.
"""
if inspect.isasyncgenfunction(fixturedef.func):
_wrap_asyncgen_fixture(fixturedef, event_loop_fixture_id)
elif inspect.iscoroutinefunction(fixturedef.func):
_wrap_async_fixture(fixturedef, event_loop_fixture_id)
if inspect.isasyncgenfunction(fixturedef.func) or getattr(
fixturedef.func, "_async_fixture", False
):
_wrap_asyncgen_fixture(
fixturedef, event_loop_fixture_name, event_loop_fixture_id
)
elif inspect.iscoroutinefunction(fixturedef.func) or getattr(
fixturedef.func, "_async_fixture", False
):
_wrap_async_fixture(fixturedef, event_loop_fixture_name, event_loop_fixture_id)


def _add_kwargs(
Expand Down Expand Up @@ -345,14 +343,27 @@ def _perhaps_rebind_fixture_func(
return func


def _wrap_asyncgen_fixture(fixturedef: FixtureDef, event_loop_fixture_id: str) -> None:
def _wrap_asyncgen_fixture(
fixturedef: FixtureDef, event_loop_fixture_name: str, event_loop_fixture_id: str
) -> None:
fixture = fixturedef.func

event_loop_id_mapping = getattr(fixture, "_event_loop_id_mapping", {})
event_loop_id_mapping[event_loop_fixture_name] = event_loop_fixture_id
event_loop_id_mapping["function"] = event_loop_fixture_id

if getattr(fixture, "_async_fixture", False):
return fixture

@functools.wraps(fixture)
def _asyncgen_fixture_wrapper(request: FixtureRequest, **kwargs: Any):
unittest = fixturedef.unittest if hasattr(fixturedef, "unittest") else False
func = _perhaps_rebind_fixture_func(fixture, request.instance, unittest)
event_loop = kwargs.pop(event_loop_fixture_id)
event_loop_fixture_id = event_loop_id_mapping.get(
request.node.name, event_loop_id_mapping["function"]
)
event_loop = request.getfixturevalue(event_loop_fixture_id)
kwargs.pop(event_loop_fixture_id, None)
gen_obj = func(
**_add_kwargs(func, kwargs, event_loop_fixture_id, event_loop, request)
)
Expand Down Expand Up @@ -380,17 +391,33 @@ async def async_finalizer() -> None:
request.addfinalizer(finalizer)
return result

setattr(_asyncgen_fixture_wrapper, "_event_loop_id_mapping", event_loop_id_mapping)
setattr(_asyncgen_fixture_wrapper, "_async_fixture", True)

fixturedef.func = _asyncgen_fixture_wrapper


def _wrap_async_fixture(fixturedef: FixtureDef, event_loop_fixture_id: str) -> None:
def _wrap_async_fixture(
fixturedef: FixtureDef, event_loop_fixture_name: str, event_loop_fixture_id: str
) -> None:
fixture = fixturedef.func

event_loop_id_mapping = getattr(fixture, "_event_loop_id_mapping", {})
event_loop_id_mapping[event_loop_fixture_name] = event_loop_fixture_id
event_loop_id_mapping["function"] = event_loop_fixture_id

if getattr(fixture, "_async_fixture", False):
return fixture

@functools.wraps(fixture)
def _async_fixture_wrapper(request: FixtureRequest, **kwargs: Any):
unittest = False if pytest.version_tuple >= (8, 2) else fixturedef.unittest
func = _perhaps_rebind_fixture_func(fixture, request.instance, unittest)
event_loop = kwargs.pop(event_loop_fixture_id)
event_loop_fixture_id = event_loop_id_mapping.get(
request.node.name, event_loop_id_mapping["function"]
)
event_loop = request.getfixturevalue(event_loop_fixture_id)
kwargs.pop(event_loop_fixture_id, None)

async def setup():
res = await func(
Expand All @@ -400,6 +427,9 @@ async def setup():

return event_loop.run_until_complete(setup())

setattr(_async_fixture_wrapper, "_event_loop_id_mapping", event_loop_id_mapping)
setattr(_async_fixture_wrapper, "_async_fixture", True)

fixturedef.func = _async_fixture_wrapper


Expand Down
35 changes: 35 additions & 0 deletions tests/async_fixtures/test_shared_module_fixture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from textwrap import dedent

from pytest import Pytester


def test_asyncio_mark_provides_package_scoped_loop_strict_mode(pytester: Pytester):
pytester.makepyfile(
__init__="",
conftest=dedent(
"""\
import pytest_asyncio
@pytest_asyncio.fixture(scope="module")
async def async_shared_module_fixture():
return True
"""
),
test_module_one=dedent(
"""\
import pytest
@pytest.mark.asyncio
async def test_shared_module_fixture_use_a(async_shared_module_fixture):
assert async_shared_module_fixture is True
"""
),
test_module_two=dedent(
"""\
import pytest
@pytest.mark.asyncio
async def test_shared_module_fixture_use_b(async_shared_module_fixture):
assert async_shared_module_fixture is True
"""
),
)
result = pytester.runpytest("--asyncio-mode=strict")
result.assert_outcomes(passed=2)

0 comments on commit ee75d65

Please sign in to comment.