Skip to content

Commit

Permalink
fix: Fixes a bug that caused module-scoped async fixtures to fail whe…
Browse files Browse the repository at this point in the history
…n reused in other modules.

Async fixture synchronizers now choose the event loop for the async fixutre at runtime rather than relying on collection-time information.

This fixes #862.
  • Loading branch information
cstruct authored and seifertm committed Aug 9, 2024
1 parent f45aa18 commit 6f33fed
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 49 deletions.
92 changes: 43 additions & 49 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,8 @@ def _preprocess_async_fixtures(
or default_loop_scope
or fixturedef.scope
)
if scope == "function":
event_loop_fixture_id: Optional[str] = "event_loop"
else:
event_loop_node = _retrieve_scope_root(collector, scope)
event_loop_fixture_id = event_loop_node.stash.get(
# Type ignored because of non-optimal mypy inference.
_event_loop_fixture_id, # type: ignore[arg-type]
None,
)
if scope == "function" and "event_loop" not in fixturedef.argnames:
fixturedef.argnames += ("event_loop",)
_make_asyncio_fixture_function(func, scope)
function_signature = inspect.signature(func)
if "event_loop" in function_signature.parameters:
Expand All @@ -271,58 +264,35 @@ def _preprocess_async_fixtures(
f"instead."
)
)
assert event_loop_fixture_id
_inject_fixture_argnames(
fixturedef,
event_loop_fixture_id,
)
_synchronize_async_fixture(
fixturedef,
event_loop_fixture_id,
)
if "request" not in fixturedef.argnames:
fixturedef.argnames += ("request",)
_synchronize_async_fixture(fixturedef)
assert _is_asyncio_fixture_function(fixturedef.func)
processed_fixturedefs.add(fixturedef)


def _inject_fixture_argnames(
fixturedef: FixtureDef, event_loop_fixture_id: str
) -> None:
"""
Ensures that `request` and `event_loop` are arguments of the specified fixture.
"""
to_add = []
for name in ("request", event_loop_fixture_id):
if name not in fixturedef.argnames:
to_add.append(name)
if to_add:
fixturedef.argnames += tuple(to_add)


def _synchronize_async_fixture(
fixturedef: FixtureDef, event_loop_fixture_id: str
) -> None:
def _synchronize_async_fixture(fixturedef: FixtureDef) -> None:
"""
Wraps the fixture function of an async fixture in a synchronous function.
"""
if inspect.isasyncgenfunction(fixturedef.func):
_wrap_asyncgen_fixture(fixturedef, event_loop_fixture_id)
_wrap_asyncgen_fixture(fixturedef)
elif inspect.iscoroutinefunction(fixturedef.func):
_wrap_async_fixture(fixturedef, event_loop_fixture_id)
_wrap_async_fixture(fixturedef)


def _add_kwargs(
func: Callable[..., Any],
kwargs: Dict[str, Any],
event_loop_fixture_id: str,
event_loop: asyncio.AbstractEventLoop,
request: FixtureRequest,
) -> Dict[str, Any]:
sig = inspect.signature(func)
ret = kwargs.copy()
if "request" in sig.parameters:
ret["request"] = request
if event_loop_fixture_id in sig.parameters:
ret[event_loop_fixture_id] = event_loop
if "event_loop" in sig.parameters:
ret["event_loop"] = event_loop
return ret


Expand All @@ -345,17 +315,19 @@ def _perhaps_rebind_fixture_func(
return func


def _wrap_asyncgen_fixture(fixturedef: FixtureDef, event_loop_fixture_id: str) -> None:
def _wrap_asyncgen_fixture(fixturedef: FixtureDef) -> None:
fixture = fixturedef.func

@functools.wraps(fixture)
def _asyncgen_fixture_wrapper(request: FixtureRequest, **kwargs: Any):
unittest = fixturedef.unittest if hasattr(fixturedef, "unittest") else False
func = _perhaps_rebind_fixture_func(fixture, request.instance, unittest)
event_loop = kwargs.pop(event_loop_fixture_id)
gen_obj = func(
**_add_kwargs(func, kwargs, event_loop_fixture_id, event_loop, request)
event_loop_fixture_id = _get_event_loop_fixture_id_for_async_fixture(
request, func
)
event_loop = request.getfixturevalue(event_loop_fixture_id)
kwargs.pop(event_loop_fixture_id, None)
gen_obj = func(**_add_kwargs(func, kwargs, event_loop, request))

async def setup():
res = await gen_obj.__anext__()
Expand Down Expand Up @@ -383,26 +355,48 @@ async def async_finalizer() -> None:
fixturedef.func = _asyncgen_fixture_wrapper


def _wrap_async_fixture(fixturedef: FixtureDef, event_loop_fixture_id: str) -> None:
def _wrap_async_fixture(fixturedef: FixtureDef) -> None:
fixture = fixturedef.func

@functools.wraps(fixture)
def _async_fixture_wrapper(request: FixtureRequest, **kwargs: Any):
unittest = False if pytest.version_tuple >= (8, 2) else fixturedef.unittest
func = _perhaps_rebind_fixture_func(fixture, request.instance, unittest)
event_loop = kwargs.pop(event_loop_fixture_id)
event_loop_fixture_id = _get_event_loop_fixture_id_for_async_fixture(
request, func
)
event_loop = request.getfixturevalue(event_loop_fixture_id)
kwargs.pop(event_loop_fixture_id, None)

async def setup():
res = await func(
**_add_kwargs(func, kwargs, event_loop_fixture_id, event_loop, request)
)
res = await func(**_add_kwargs(func, kwargs, event_loop, request))
return res

return event_loop.run_until_complete(setup())

fixturedef.func = _async_fixture_wrapper


def _get_event_loop_fixture_id_for_async_fixture(
request: FixtureRequest, func: Any
) -> str:
default_loop_scope = request.config.getini("asyncio_default_fixture_loop_scope")
loop_scope = (
getattr(func, "_loop_scope", None) or default_loop_scope or request.scope
)
if loop_scope == "function":
event_loop_fixture_id = "event_loop"
else:
event_loop_node = _retrieve_scope_root(request._pyfuncitem, loop_scope)
event_loop_fixture_id = event_loop_node.stash.get(
# Type ignored because of non-optimal mypy inference.
_event_loop_fixture_id, # type: ignore[arg-type]
"",
)
assert event_loop_fixture_id
return event_loop_fixture_id


class PytestAsyncioFunction(Function):
"""Base class for all test functions managed by pytest-asyncio."""

Expand Down
35 changes: 35 additions & 0 deletions tests/async_fixtures/test_shared_module_fixture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from textwrap import dedent

from pytest import Pytester


def test_asyncio_mark_provides_package_scoped_loop_strict_mode(pytester: Pytester):
pytester.makepyfile(
__init__="",
conftest=dedent(
"""\
import pytest_asyncio
@pytest_asyncio.fixture(scope="module")
async def async_shared_module_fixture():
return True
"""
),
test_module_one=dedent(
"""\
import pytest
@pytest.mark.asyncio
async def test_shared_module_fixture_use_a(async_shared_module_fixture):
assert async_shared_module_fixture is True
"""
),
test_module_two=dedent(
"""\
import pytest
@pytest.mark.asyncio
async def test_shared_module_fixture_use_b(async_shared_module_fixture):
assert async_shared_module_fixture is True
"""
),
)
result = pytester.runpytest("--asyncio-mode=strict")
result.assert_outcomes(passed=2)

0 comments on commit 6f33fed

Please sign in to comment.