Skip to content

Commit

Permalink
Generic typing for register methods in pkg_resources
Browse files Browse the repository at this point in the history
  • Loading branch information
Avasam authored and abravalheri committed May 24, 2024
1 parent 2aff6e4 commit 3ce6d3f
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions pkg_resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
)


T = TypeVar("T")
_T = TypeVar("_T")
# Type aliases
_NestedStr = Union[str, Iterable[Union[str, Iterable["_NestedStr"]]]]
_InstallerType = Callable[["Requirement"], Optional["Distribution"]]
Expand All @@ -118,7 +118,12 @@
_MetadataType = Optional["IResourceProvider"]
# Any object works, but let's indicate we expect something like a module (optionally has __loader__ or __file__)
_ModuleLike = Union[object, types.ModuleType]
_AdapterType = Callable[..., Any] # Incomplete
_ProviderFactoryType = Callable[[_ModuleLike], "IResourceProvider"]
_DistFinderType = Callable[[_T, str, bool], Iterable["Distribution"]]
_NSHandlerType = Callable[[_T, str, str, types.ModuleType], Optional[str]]
_AdapterT = TypeVar(
"_AdapterT", _DistFinderType[Any], _ProviderFactoryType, _NSHandlerType[Any]
)


# Use _typeshed.importlib.LoaderProtocol once available https://github.com/python/typeshed/pull/11890
Expand All @@ -142,7 +147,7 @@ class PEP440Warning(RuntimeWarning):
_state_vars: Dict[str, str] = {}


def _declare_state(vartype: str, varname: str, initial_value: T) -> T:
def _declare_state(vartype: str, varname: str, initial_value: _T) -> _T:
_state_vars[varname] = vartype
return initial_value

Expand Down Expand Up @@ -377,7 +382,7 @@ class UnknownExtra(ResolutionError):
"""Distribution doesn't have an "extra feature" of the given name"""


_provider_factories: Dict[Type[_ModuleLike], _AdapterType] = {}
_provider_factories: Dict[Type[_ModuleLike], _ProviderFactoryType] = {}

PY_MAJOR = '{}.{}'.format(*sys.version_info)
EGG_DIST = 3
Expand All @@ -388,7 +393,7 @@ class UnknownExtra(ResolutionError):


def register_loader_type(
loader_type: Type[_ModuleLike], provider_factory: _AdapterType
loader_type: Type[_ModuleLike], provider_factory: _ProviderFactoryType
):
"""Register `provider_factory` to make providers for `loader_type`
Expand Down Expand Up @@ -2097,12 +2102,12 @@ def __init__(self, importer: zipimport.zipimporter):
self._setup_prefix()


_distribution_finders: Dict[
type, Callable[[object, str, bool], Iterable["Distribution"]]
] = _declare_state('dict', '_distribution_finders', {})
_distribution_finders: Dict[type, _DistFinderType[Any]] = _declare_state(
'dict', '_distribution_finders', {}
)


def register_finder(importer_type: type, distribution_finder: _AdapterType):
def register_finder(importer_type: Type[_T], distribution_finder: _DistFinderType[_T]):
"""Register `distribution_finder` to find distributions in sys.path items
`importer_type` is the type or class of a PEP 302 "Importer" (sys.path item
Expand Down Expand Up @@ -2276,15 +2281,17 @@ def resolve_egg_link(path):

register_finder(importlib.machinery.FileFinder, find_on_path)

_namespace_handlers: Dict[
type, Callable[[object, str, str, types.ModuleType], Optional[str]]
] = _declare_state('dict', '_namespace_handlers', {})
_namespace_handlers: Dict[type, _NSHandlerType[Any]] = _declare_state(
'dict', '_namespace_handlers', {}
)
_namespace_packages: Dict[Optional[str], List[str]] = _declare_state(
'dict', '_namespace_packages', {}
)


def register_namespace_handler(importer_type: type, namespace_handler: _AdapterType):
def register_namespace_handler(
importer_type: Type[_T], namespace_handler: _NSHandlerType[_T]
):
"""Register `namespace_handler` to declare namespace packages
`importer_type` is the type or class of a PEP 302 "Importer" (sys.path item
Expand Down Expand Up @@ -2429,9 +2436,9 @@ def fixup_namespace_packages(path_item: str, parent: Optional[str] = None):


def file_ns_handler(
importer: Optional[importlib.abc.PathEntryFinder],
path_item,
packageName,
importer: object,
path_item: "StrPath",
packageName: str,
module: types.ModuleType,
):
"""Compute an ns-package subpath for a filesystem or zipfile importer"""
Expand All @@ -2454,7 +2461,7 @@ def file_ns_handler(


def null_ns_handler(
importer: Optional[importlib.abc.PathEntryFinder],
importer: object,
path_item: Optional[str],
packageName: Optional[str],
module: Optional[_ModuleLike],
Expand Down Expand Up @@ -3321,7 +3328,7 @@ def _always_object(classes):
return classes


def _find_adapter(registry: Mapping[type, _AdapterType], ob: object):
def _find_adapter(registry: Mapping[type, _AdapterT], ob: object) -> _AdapterT:
"""Return an adapter factory for `ob` from `registry`"""
types = _always_object(inspect.getmro(getattr(ob, '__class__', type(ob))))
for t in types:
Expand Down

0 comments on commit 3ce6d3f

Please sign in to comment.