diff --git a/jaraco/context/__init__.py b/jaraco/context/__init__.py index 74457d9..f249be2 100644 --- a/jaraco/context/__init__.py +++ b/jaraco/context/__init__.py @@ -8,8 +8,9 @@ import subprocess import sys import tempfile +import types import urllib.request -from typing import Iterator +from typing import Iterator, TypeVar, Union, Optional, Callable, Type, Tuple if sys.version_info < (3, 12): @@ -17,9 +18,12 @@ else: import tarfile +PathLike = Union[str, os.PathLike] +T = TypeVar('T') + @contextlib.contextmanager -def pushd(dir: str | os.PathLike) -> Iterator[str | os.PathLike]: +def pushd(dir: PathLike) -> Iterator[PathLike]: """ >>> tmp_path = getfixture('tmp_path') >>> with pushd(tmp_path): @@ -37,8 +41,8 @@ def pushd(dir: str | os.PathLike) -> Iterator[str | os.PathLike]: @contextlib.contextmanager def tarball( - url, target_dir: str | os.PathLike | None = None -) -> Iterator[str | os.PathLike]: + url: str, target_dir: str | os.PathLike | None = None +) -> Iterator[PathLike]: """ Get a URL to a tarball, download, extract, yield, then clean up. @@ -89,7 +93,11 @@ def strip_first_component( return member -def _compose(*cmgrs): +CM = TypeVar('CM', bound=contextlib.AbstractContextManager) +"""Type var for context managers.""" + + +def _compose(*cmgrs: Callable[..., CM]) -> Callable[..., CM]: """ Compose any number of dependent context managers into a single one. @@ -126,7 +134,7 @@ def composed(*args, **kwargs): @contextlib.contextmanager -def temp_dir(remover=shutil.rmtree): +def temp_dir(remover: Callable[[str], None] = shutil.rmtree) -> Iterator[str]: """ Create a temporary directory context. Pass a custom remover to override the removal behavior. @@ -145,7 +153,12 @@ def temp_dir(remover=shutil.rmtree): @contextlib.contextmanager -def repo_context(url, branch: str | None = None, quiet: bool = True, dest_ctx=temp_dir): +def repo_context( + url, + branch: str | None = None, + quiet: bool = True, + dest_ctx: Callable[[], contextlib.AbstractContextManager[str]] = temp_dir, +): """ Check out the repo indicated by url. @@ -167,7 +180,7 @@ def repo_context(url, branch: str | None = None, quiet: bool = True, dest_ctx=te yield repo_dir -class ExceptionTrap: +class ExceptionTrap(contextlib.AbstractContextManager): """ A context manager that will catch certain exceptions and provide an indication they occurred. @@ -201,9 +214,13 @@ class ExceptionTrap: False """ - exc_info = None, None, None + exc_info: Tuple[ + Optional[Type[BaseException]], + Optional[BaseException], + Optional[types.TracebackType], + ] = (None, None, None) # Explicitly type the tuple - def __init__(self, exceptions=(Exception,)): + def __init__(self, exceptions: Tuple[Type[BaseException], ...] = (Exception,)): self.exceptions = exceptions def __enter__(self): @@ -231,7 +248,9 @@ def __exit__(self, *exc_info): def __bool__(self): return bool(self.type) - def raises(self, func, *, _test=bool): + def raises( + self, func: Callable[..., T], *, _test: Callable[[ExceptionTrap], bool] = bool + ): """ Wrap func and replace the result with the truth value of the trap (True if an exception occurred). @@ -258,7 +277,7 @@ def wrapper(*args, **kwargs): return wrapper - def passes(self, func): + def passes(self, func: Callable[..., T]) -> Callable[..., bool]: """ Wrap func and replace the result with the truth value of the trap (True if no exception).