Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#98 Add 'WithParents' #195

Merged
merged 21 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ lint.ignore = [
"PLR0913",
"UP038",
"TCH001",
"SIM103",
"PLR1704",
Tishka17 marked this conversation as resolved.
Show resolved Hide resolved
]

[lint.per-file-ignores]
Expand Down
21 changes: 20 additions & 1 deletion docs/provider/provide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,23 @@ Also, if an error occurs during process handling (inside the ``with`` block), it
def p(self) -> AnyOf[A, AProtocol]:
return A()

It works similar to :ref:`alias`.
It works similar to :ref:`alias`.

* Do you want to get dependencies by parents? Use ``WithParents`` as a result hint:

.. code-block:: python

from dishka import WithParents, provide, Provider, Scope

class A(Protocol): ...
class AImpl(A): ...

class MyProvider(Provider):
scope=Scope.APP

@provide
def a(self) -> WithParents[A]:
return A()

This is similar to ``AnyOf[AImpl, A]``. The following parents are ignored: ``type``, ``object``, ``Enum``, ``ABC``, ``ABCMeta``, ``Generic``, ``Protocol``, ``Exception``, ``BaseException``

2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,3 @@ dependencies = [
"Homepage" = "https://github.com/reagento/dishka"
"Documentation" = "https://dishka.readthedocs.io/en/stable/"
"Bug Tracker" = "https://github.com/reagento/dishka/issues"


192 changes: 192 additions & 0 deletions src/dishka/entities/with_parents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from abc import ABC, ABCMeta
from enum import Enum
from types import GenericAlias
from typing import (
TYPE_CHECKING,
Final,
Generic,
Protocol,
TypeAlias,
TypeVar,
TypeVarTuple,
cast,
)

from dishka._adaptix.common import TypeHint
from dishka._adaptix.type_tools import (
get_generic_args,
get_type_vars,
is_generic,
is_named_tuple_class,
is_typed_dict_class,
strip_alias,
)
from dishka.entities.provides_marker import ProvideMultiple

IGNORE_TYPES: Final = (
type,
object,
Enum,
ABC,
ABCMeta,
Generic,
Protocol,
Exception,
BaseException,
)
TypeVarsMap: TypeAlias = dict[TypeVar | TypeVarTuple, TypeHint]


def has_orig_bases(obj: TypeHint) -> bool:
return hasattr(obj, "__orig_bases__")


def is_type_var_tuple(obj: TypeHint) -> bool:
return getattr(obj, "__typing_is_unpacked_typevartuple__", False)


def is_ignore_type(origin_obj: TypeHint) -> bool:
Tishka17 marked this conversation as resolved.
Show resolved Hide resolved
if origin_obj in IGNORE_TYPES:
return True
if is_named_tuple_class(origin_obj):
return True
if is_typed_dict_class(origin_obj):
return True
return False


def get_filled_arguments(obj: TypeHint) -> list[TypeHint]:
filled_arguments = []
for arg in get_generic_args(obj):
if isinstance(arg, (TypeVar, TypeVarTuple)):
continue
if is_type_var_tuple(arg):
continue
filled_arguments.append(arg)
return filled_arguments


def create_type_vars_map(obj: TypeHint) -> TypeVarsMap:
origin_obj = strip_alias(obj)
if not get_type_vars(origin_obj):
return {}

type_vars_map = {}
type_vars = list(get_type_vars(origin_obj))
filled_arguments = get_filled_arguments(obj)

if not filled_arguments or not type_vars:
return {}

reversed_arguments = False
while True:
if len(type_vars) == 0:
break
type_var = type_vars[0]
if isinstance(type_var, TypeVar):
del type_vars[0]
type_vars_map[type_var] = filled_arguments.pop(0)
else:
if len(type_vars) == 1:
if reversed_arguments:
filled_arguments.reverse()
type_vars_map[type_var] = filled_arguments
break
type_vars.reverse()
filled_arguments.reverse()
reversed_arguments = not reversed_arguments

return cast(TypeVarsMap, type_vars_map)


def create_generic_class(
origin_obj: TypeHint,
type_vars_map: TypeVarsMap,
) -> TypeHint | None:
if not is_generic(origin_obj):
return None

generic_args = []
for type_var in get_type_vars(origin_obj):
arg = type_vars_map[type_var]
if isinstance(arg, list):
generic_args.extend(arg)
else:
generic_args.append(arg)
return origin_obj[*generic_args]


def recursion_get_parents_for_generic_class(
obj: TypeHint,
parents: list[TypeHint],
type_vars_map: TypeVarsMap,
) -> None:
origin_obj = strip_alias(obj)
if is_ignore_type(origin_obj):
return

if not has_orig_bases(origin_obj):
parents.extend(get_parents_for_mro(origin_obj))
return

for obj in origin_obj.__orig_bases__:
origin_obj = strip_alias(obj)
if is_ignore_type(origin_obj):
continue

type_vars_map.update(create_type_vars_map(obj))
parents.append(create_generic_class(origin_obj, type_vars_map) or obj)
recursion_get_parents_for_generic_class(
obj,
parents,
type_vars_map.copy(),
)


def get_parents_for_mro(obj: TypeHint) -> list[TypeHint]:
return [
obj_ for obj_ in obj.mro()
if not is_ignore_type(strip_alias(obj_))
]


def get_parents(obj: TypeHint) -> list[TypeHint]:
if is_ignore_type(strip_alias(obj)):
raise ValueError(f"The starting class {obj!r} is in ignored types")

if isinstance(obj, GenericAlias):
type_vars_map = create_type_vars_map(obj)
parents = [
create_generic_class(
origin_obj=strip_alias(obj),
type_vars_map=type_vars_map,
) or obj,
]
recursion_get_parents_for_generic_class(
obj=obj,
parents=parents,
type_vars_map=type_vars_map,
)
elif has_orig_bases(obj):
parents = [obj]
recursion_get_parents_for_generic_class(
obj=obj,
parents=parents,
type_vars_map={},
)
else:
parents = get_parents_for_mro(obj)
return parents


if TYPE_CHECKING:
from typing import Union as WithParents
else:
class WithParents:
Tishka17 marked this conversation as resolved.
Show resolved Hide resolved
def __class_getitem__(
cls, item: TypeHint,
) -> TypeHint | ProvideMultiple:
parents = get_parents(item)
if len(parents) > 1:
return ProvideMultiple(parents)
return parents[0]
141 changes: 141 additions & 0 deletions tests/unit/container/test_with_parents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from abc import ABC
from typing import Generic, NamedTuple, Protocol, TypeVar, TypeVarTuple

from dishka import make_container, Provider, Scope
from dishka.entities.with_parents import WithParents
from dishka.exceptions import NoFactoryError


def test_simple_inheritance() -> None:
class A1: ...
class A2(A1): ...
class A3(A2): ...

provider = Provider(scope=Scope.APP)
provider.provide(lambda: 1, provides=WithParents[A3])
Tishka17 marked this conversation as resolved.
Show resolved Hide resolved
container = make_container(provider)
assert container.get(A3) == 1
assert container.get(A2) == 1
assert container.get(A1) == 1


def test_ignore_parent_type() -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parametrize

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

class A1(Protocol): ...
class A2(ABC): ...
class A3(NamedTuple): ...

provider = Provider(scope=Scope.APP)
provider.provide(lambda: 1, provides=WithParents[A1])
provider.provide(lambda: 1, provides=WithParents[A2])
try:
provider.provide(lambda: 1, provides=WithParents[A3])
except ValueError:
pass
else:
assert False
container = make_container(provider)

try:
container.get(Protocol)
except NoFactoryError:
pass
else:
assert False

try:
container.get(ABC)
except NoFactoryError:
pass
else:
assert False



def test_type_var() -> None:
T = TypeVar('T')

class A1(Generic[T]): ...
class A2(A1[str]): ...

provider = Provider(scope=Scope.APP)
provider.provide(lambda: 1, provides=WithParents[A2])

container = make_container(provider)

assert container.get(A2) == 1
assert container.get(A1[str]) == 1


def test_type_var_tuple() -> None:
Ts = TypeVarTuple('Ts')

class A1(Generic[*Ts]): ...
class A2(A1[str, int, type]): ...

provider = Provider(scope=Scope.APP)
provider.provide(lambda: 1, provides=WithParents[A2])

container = make_container(provider)

assert container.get(A2) == 1
assert container.get(A1[str, int, type]) == 1


def test_type_var_and_type_var_tuple() -> None:
Tishka17 marked this conversation as resolved.
Show resolved Hide resolved
T = TypeVar('T')
B = TypeVar('B')
Ts = TypeVarTuple('Ts')

class A1(Generic[T, *Ts]): ...
class A2(A1[str, int, type]): ...

class B1(Generic[*Ts, T], str): ...
class B2(B1[int, tuple[str, ...], type]): ...

class C1(Generic[B, *Ts, T]): ...
class C2(C1[int, type, str, tuple[str, ...]]): ...


provider = Provider(scope=Scope.APP)
provider.provide(lambda: 1, provides=WithParents[A2])
provider.provide(lambda: 1, provides=WithParents[B2])
provider.provide(lambda: 1, provides=WithParents[C2])

container = make_container(provider)

assert container.get(A2) == 1
assert container.get(A1[str, int, type]) == 1

assert container.get(B2) == 1
assert container.get(B1[int, tuple[str, ...], type]) == 1

assert container.get(C2) == 1
assert container.get(C1[int, type, str, tuple[str, ...]]) == 1


def test_deep_inheritance() -> None:
T = TypeVar('T')
Ts = TypeVarTuple('Ts')

class A1(Generic[*Ts]): ...
class A2(A1[*Ts], Generic[*Ts, T]): ...

class B1: ...
class B2(B1): ...
class B3(B2): ...

class C1(Generic[T], B3): ...
class D1(A2[int, type, str], C1[str]): ...

provider = Provider(scope=Scope.APP)
provider.provide(lambda: 1, provides=WithParents[D1])
container = make_container(provider)

assert container.get(D1) == 1
assert container.get(A2[int, type, str]) == 1
assert container.get(A1[int, type]) == 1
assert container.get(C1[str]) == 1
assert container.get(B3) == 1
assert container.get(B2) == 1
assert container.get(B1) == 1
assert container.get(D1) == 1
Loading