diff --git a/docs/integrations/index.rst b/docs/integrations/index.rst index 6fdc377a..c4eee1c4 100644 --- a/docs/integrations/index.rst +++ b/docs/integrations/index.rst @@ -113,7 +113,7 @@ With some frameworks we provide an option to inject dependencies in handlers wit setup_dishka(container, app) -* For **FasStream** (**0.5.0** version and higher) you need to provide ``auto_inject=True`` when calling ``setup_dishka``. It is important here to call it before registering any subscribers or router include: +* For **FasStream** (**0.5.0** version and higher) you need to provide ``auto_inject=True`` when calling ``setup_dishka``. E.g: .. code-block:: python diff --git a/src/dishka/integrations/faststream.py b/src/dishka/integrations/faststream.py index 02cfd842..16388871 100644 --- a/src/dishka/integrations/faststream.py +++ b/src/dishka/integrations/faststream.py @@ -108,6 +108,18 @@ def setup_dishka( *app.broker._middlewares, # noqa: SLF001 ) + for subscriber in app.broker._subscribers.values(): # noqa: SLF001 + subscriber._broker_middlewares = ( # noqa: SLF001 + DishkaMiddleware(container), + *subscriber._broker_middlewares, # noqa: SLF001 + ) + + for publisher in app.broker._publishers.values(): # noqa: SLF001 + publisher._broker_middlewares = ( # noqa: SLF001 + DishkaMiddleware(container), + *publisher._broker_middlewares, # noqa: SLF001 + ) + if auto_inject: app.broker._call_decorators = ( # noqa: SLF001 inject, diff --git a/tests/integrations/faststream/test_faststream.py b/tests/integrations/faststream/test_faststream.py index f32f103b..b8cb7395 100644 --- a/tests/integrations/faststream/test_faststream.py +++ b/tests/integrations/faststream/test_faststream.py @@ -8,6 +8,7 @@ from dishka import make_async_container from dishka.integrations.faststream import ( + FASTSTREAM_OLD_MIDDLEWARES, FromDishka, inject, setup_dishka, @@ -63,7 +64,6 @@ async def get_with_request( return "passed" - @pytest.mark.asyncio async def test_request_dependency(app_provider: AppProvider): async with dishka_app(get_with_request, app_provider) as client: @@ -71,3 +71,49 @@ async def test_request_dependency(app_provider: AppProvider): app_provider.mock.assert_called_with(REQUEST_DEP_VALUE) app_provider.request_released.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.skipif( + FASTSTREAM_OLD_MIDDLEWARES, + reason="Requires FastStream 0.5.0+", +) +async def test_autoinject_before_subscriber(app_provider: AppProvider): + broker = NatsBroker() + app = FastStream(broker) + + container = make_async_container(app_provider) + setup_dishka(container, app=app, auto_inject=True) + + broker.subscriber("test")(get_with_request) + + async with TestNatsBroker(broker) as br: + assert await br.publish("", "test", rpc=True) == "passed" + + app_provider.mock.assert_called_with(REQUEST_DEP_VALUE) + app_provider.request_released.assert_called_once() + + await container.close() + + +@pytest.mark.asyncio +@pytest.mark.skipif( + FASTSTREAM_OLD_MIDDLEWARES, + reason="Requires FastStream 0.5.0+", +) +async def test_autoinject_after_subscriber(app_provider: AppProvider): + broker = NatsBroker() + app = FastStream(broker) + + broker.subscriber("test")(get_with_request) + + container = make_async_container(app_provider) + setup_dishka(container, app=app, auto_inject=True) + + async with TestNatsBroker(broker) as br: + assert await br.publish("", "test", rpc=True) == "passed" + + app_provider.mock.assert_called_with(REQUEST_DEP_VALUE) + app_provider.request_released.assert_called_once() + + await container.close()