Skip to content

Commit

Permalink
compile factory calls
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishka17 committed Mar 24, 2024
1 parent 69aa5a6 commit 84885fb
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 101 deletions.
56 changes: 9 additions & 47 deletions src/dishka/async_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
from dishka.entities.component import DEFAULT_COMPONENT, Component
from dishka.entities.key import DependencyKey
from dishka.entities.scope import BaseScope, Scope
from .dependency_source import Factory, FactoryType
from .dependency_source import FactoryType
from .exceptions import (
ExitError,
NoContextValueError,
NoFactoryError,
UnsupportedFactoryError,
)
from .provider import BaseProvider
from .registry import Registry, RegistryBuilder
Expand Down Expand Up @@ -84,47 +82,6 @@ def __call__(
raise ValueError("No child scopes found")
return AsyncContextWrapper(self._create_child(context, lock_factory))

async def _get_from_self(
self, factory: Factory, key: DependencyKey,
) -> T:
try:
sub_dependencies = [
await self._get_unlocked(dependency)
for dependency in factory.dependencies
]
except NoFactoryError as e:
e.add_path(factory)
raise

if factory.type is FactoryType.GENERATOR:
generator = factory.source(*sub_dependencies)
self._exits.append(Exit(factory.type, generator))
solved = next(generator)
elif factory.type is FactoryType.FACTORY:
solved = factory.source(*sub_dependencies)
elif factory.type is FactoryType.ASYNC_GENERATOR:
generator = factory.source(*sub_dependencies)
self._exits.append(Exit(factory.type, generator))
solved = await anext(generator)
elif factory.type is FactoryType.ASYNC_FACTORY:
solved = await factory.source(*sub_dependencies)
elif factory.type is FactoryType.VALUE:
solved = factory.source
elif factory.type is FactoryType.ALIAS:
solved = sub_dependencies[0]
elif factory.type is FactoryType.CONTEXT:
raise NoContextValueError(
f"Value for type {factory.provides.type_hint} is not found "
f"in container context with scope={factory.scope}",
)
else:
raise UnsupportedFactoryError(
f"Unsupported factory type {factory.type}.",
)
if factory.cache:
self.context[key] = solved
return solved

async def get(
self,
dependency_type: type[T],
Expand All @@ -140,14 +97,19 @@ async def get(
async def _get_unlocked(self, key: DependencyKey) -> Any:
if key in self.context:
return self.context[key]
factory = self.registry.get_factory(key)
if not factory:
compiled = self.registry.get_compiled_async(key)
if not compiled:
if not self.parent_container:
raise NoFactoryError(key)
return await self.parent_container.get(
key.type_hint, key.component,
)
return await self._get_from_self(factory, key)
try:
return await compiled(self._get_unlocked, self._exits, self.context)
except NoFactoryError as e:
e.add_path(self.registry.get_factory(key))
raise


async def close(self):
errors = []
Expand Down
60 changes: 9 additions & 51 deletions src/dishka/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
from dishka.entities.component import DEFAULT_COMPONENT, Component
from dishka.entities.key import DependencyKey
from dishka.entities.scope import BaseScope, Scope
from .dependency_source import Factory, FactoryType
from .dependency_source import FactoryType
from .exceptions import (
ExitError,
NoContextValueError,
NoFactoryError,
UnsupportedFactoryError,
)
from .provider import BaseProvider
from .registry import Registry, RegistryBuilder
Expand Down Expand Up @@ -83,51 +81,6 @@ def __call__(
raise ValueError("No child scopes found")
return ContextWrapper(self._create_child(context, lock_factory))

def _get_from_self(
self, factory: Factory, key: DependencyKey,
) -> T:
try:
sub_dependencies = [
self._get_unlocked(dependency)
for dependency in factory.dependencies
]
except NoFactoryError as e:
e.add_path(factory)
raise

if factory.type is FactoryType.GENERATOR:
generator = factory.source(*sub_dependencies)
self._exits.append(Exit(factory.type, generator))
solved = next(generator)
elif factory.type is FactoryType.FACTORY:
solved = factory.source(*sub_dependencies)
elif factory.type is FactoryType.ASYNC_GENERATOR:
raise UnsupportedFactoryError(
f"Unsupported factory type {factory.type}. "
f"Did you mean to use an async container?",
)
elif factory.type is FactoryType.ASYNC_FACTORY:
raise UnsupportedFactoryError(
f"Unsupported factory type {factory.type}. "
f"Did you mean to use an async container?",
)
elif factory.type is FactoryType.VALUE:
solved = factory.source
elif factory.type is FactoryType.ALIAS:
solved = sub_dependencies[0]
elif factory.type is FactoryType.CONTEXT:
raise NoContextValueError(
f"Value for type {factory.provides.type_hint} is not found "
f"in container context with scope={factory.scope}",
)
else:
raise UnsupportedFactoryError(
f"Unsupported factory type {factory.type}. ",
)
if factory.cache:
self.context[key] = solved
return solved

def get(
self,
dependency_type: type[T],
Expand All @@ -143,14 +96,19 @@ def get(
def _get_unlocked(self, key: DependencyKey) -> Any:
if key in self.context:
return self.context[key]
factory = self.registry.get_factory(key)
if not factory:
compiled = self.registry.get_compiled(key)
if not compiled:
if not self.parent_container:
raise NoFactoryError(key)
return self.parent_container.get(
key.type_hint, key.component,
)
return self._get_from_self(factory, key)
try:
return compiled(self._get_unlocked, self._exits, self.context)
except NoFactoryError as e:
e.add_path(self.registry.get_factory(key))
raise


def close(self) -> None:
errors = []
Expand Down
31 changes: 28 additions & 3 deletions src/dishka/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import Any, NewType, TypeVar, get_args, get_origin

from ._adaptix.type_tools.basic_utils import get_type_vars, is_generic
Expand All @@ -18,14 +18,17 @@
NoFactoryError,
UnknownScopeError,
)
from .factory_compiler import compile_factory
from .provider import BaseProvider


class Registry:
__slots__ = ("scope", "factories")
__slots__ = ("scope", "factories", "compiled", "compiled_async")

def __init__(self, scope: BaseScope):
self.factories: dict[DependencyKey, Factory] = {}
self.compiled: dict[DependencyKey, Callable] = {}
self.compiled_async: dict[DependencyKey, Callable] = {}
self.scope = scope

def add_factory(self, factory: Factory):
Expand All @@ -36,6 +39,28 @@ def add_factory(self, factory: Factory):
self.factories[origin_key] = factory
self.factories[factory.provides] = factory

def get_compiled(self, dependency: DependencyKey) -> Callable | None:
try:
return self.compiled[dependency]
except KeyError:
factory = self.get_factory(dependency)
if not factory:
return None
compiled = compile_factory(factory=factory, is_async=False)
self.compiled[dependency] = compiled
return compiled

def get_compiled_async(self, dependency: DependencyKey) -> Callable | None:
try:
return self.compiled[dependency]
except KeyError:
factory = self.get_factory(dependency)
if not factory:
return None
compiled = compile_factory(factory=factory, is_async=True)
self.compiled[dependency] = compiled
return compiled

def get_factory(self, dependency: DependencyKey) -> Factory | None:
try:
return self.factories[dependency]
Expand Down Expand Up @@ -302,4 +327,4 @@ def build(self):
registries = list(self.registries.values())
if not self.skip_validation:
GraphValidator(registries).validate()
return registries
return tuple(registries)

0 comments on commit 84885fb

Please sign in to comment.