diff --git a/src/sphinx_autodoc_typehints/__init__.py b/src/sphinx_autodoc_typehints/__init__.py index 2311133d..997c93e1 100644 --- a/src/sphinx_autodoc_typehints/__init__.py +++ b/src/sphinx_autodoc_typehints/__init__.py @@ -4,6 +4,7 @@ import re import sys import textwrap +import types from ast import FunctionDef, Module, stmt from typing import Any, AnyStr, Callable, ForwardRef, NewType, TypeVar, get_type_hints @@ -21,8 +22,16 @@ _LOGGER = logging.getLogger(__name__) _PYDATA_ANNOTATIONS = {"Any", "AnyStr", "Callable", "ClassVar", "Literal", "NoReturn", "Optional", "Tuple", "Union"} +# types has a bunch of things like ModuleType where ModuleType.__module__ is +# "builtins" and ModuleType.__name__ is "module", so we have to check for this. +_TYPES_DICT = {getattr(types, name): name for name in types.__all__} +# Prefer FunctionType to LambdaType (they are synonymous) +_TYPES_DICT[types.FunctionType] = "FunctionType" + def get_annotation_module(annotation: Any) -> str: + if annotation in _TYPES_DICT: + return "types" if annotation is None: return "builtins" is_new_type = sys.version_info >= (3, 10) and isinstance(annotation, NewType) @@ -35,17 +44,22 @@ def get_annotation_module(annotation: Any) -> str: raise ValueError(f"Cannot determine the module of {annotation}") +def _is_newtype(annotation: Any) -> bool: + if sys.version_info < (3, 10): + return inspect.isfunction(annotation) and hasattr(annotation, "__supertype__") + else: + return isinstance(annotation, NewType) + + def get_annotation_class_name(annotation: Any, module: str) -> str: # Special cases if annotation is None: return "None" - elif annotation is Any: - return "Any" - elif annotation is AnyStr: + if annotation is AnyStr: return "AnyStr" - elif (sys.version_info < (3, 10) and inspect.isfunction(annotation) and hasattr(annotation, "__supertype__")) or ( - sys.version_info >= (3, 10) and isinstance(annotation, NewType) - ): + if annotation in _TYPES_DICT: + return _TYPES_DICT[annotation] + if _is_newtype(annotation): return "NewType" if getattr(annotation, "__qualname__", None): diff --git a/tests/test_sphinx_autodoc_typehints.py b/tests/test_sphinx_autodoc_typehints.py index 2e213dd1..6d2728b1 100644 --- a/tests/test_sphinx_autodoc_typehints.py +++ b/tests/test_sphinx_autodoc_typehints.py @@ -4,6 +4,7 @@ import pathlib import re import sys +import types import typing from functools import cmp_to_key from io import StringIO @@ -134,6 +135,10 @@ def __getitem__(self, params): [ pytest.param(str, "builtins", "str", (), id="str"), pytest.param(None, "builtins", "None", (), id="None"), + pytest.param(ModuleType, "types", "ModuleType", (), id="ModuleType"), + pytest.param(FunctionType, "types", "FunctionType", (), id="FunctionType"), + pytest.param(types.CodeType, "types", "CodeType", (), id="CodeType"), + pytest.param(types.CoroutineType, "types", "CoroutineType", (), id="CoroutineType"), pytest.param(Any, "typing", "Any", (), id="Any"), pytest.param(AnyStr, "typing", "AnyStr", (), id="AnyStr"), pytest.param(Dict, "typing", "Dict", (), id="Dict"), @@ -170,9 +175,10 @@ def __getitem__(self, params): ], ) def test_parse_annotation(annotation: Any, module: str, class_name: str, args: tuple[Any, ...]) -> None: - assert get_annotation_module(annotation) == module - assert get_annotation_class_name(annotation, module) == class_name - assert get_annotation_args(annotation, module, class_name) == args + got_mod = get_annotation_module(annotation) + got_cls = get_annotation_class_name(annotation, module) + got_args = get_annotation_args(annotation, module, class_name) + assert (got_mod, got_cls, got_args) == (module, class_name, args) @pytest.mark.parametrize( @@ -181,6 +187,8 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t (str, ":py:class:`str`"), (int, ":py:class:`int`"), (StringIO, ":py:class:`~io.StringIO`"), + (FunctionType, ":py:class:`~types.FunctionType`"), + (ModuleType, ":py:class:`~types.ModuleType`"), (type(None), ":py:obj:`None`"), (type, ":py:class:`type`"), (collections.abc.Callable, ":py:class:`~collections.abc.Callable`"), diff --git a/whitelist.txt b/whitelist.txt index be15dbcd..d32e390f 100644 --- a/whitelist.txt +++ b/whitelist.txt @@ -4,6 +4,7 @@ autouse backfill conf contravariant +Coroutine cpython csv dedent @@ -31,7 +32,9 @@ iterdir kwonlyargs libs metaclass +ModuleType multiline +newtype nptyping param parametrized