From b56468c27aadf90da890ef006d563b1acf495145 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Fri, 11 Nov 2022 18:48:40 -0700 Subject: [PATCH] Fix get_type_hints() on x-module inherited TypedDict in 3.9 and 3.10 (#94) --- src/_typed_dict_test_helper.py | 10 ++++++++++ src/test_typing_extensions.py | 18 +++++++++++++++++- src/typing_extensions.py | 7 ++++++- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/_typed_dict_test_helper.py b/src/_typed_dict_test_helper.py index 396a94fe..7ffc5e1d 100644 --- a/src/_typed_dict_test_helper.py +++ b/src/_typed_dict_test_helper.py @@ -4,5 +4,15 @@ from typing_extensions import TypedDict +# this class must not be imported into test_typing_extensions.py at top level, otherwise +# the test_get_type_hints_cross_module_subclass test will pass for the wrong reason +class _DoNotImport: + pass + + +class Foo(TypedDict): + a: _DoNotImport + + class FooGeneric(TypedDict, Generic[T]): a: Optional[T] diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index f7c68101..770389b1 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -30,7 +30,7 @@ from typing_extensions import clear_overloads, get_overloads, overload from typing_extensions import NamedTuple from typing_extensions import override -from _typed_dict_test_helper import FooGeneric +from _typed_dict_test_helper import Foo, FooGeneric # Flags used to mark tests that only apply after a specific # version of the typing module. @@ -41,6 +41,10 @@ # 3.11 makes runtime type checks (_type_check) more lenient. TYPING_3_11_0 = sys.version_info[:3] >= (3, 11, 0) +# https://github.com/python/cpython/pull/27017 was backported into some 3.9 and 3.10 +# versions, but not all +HAS_FORWARD_MODULE = "module" in inspect.signature(typing._type_check).parameters + class BaseTestCase(TestCase): def assertIsSubclass(self, cls, class_or_tuple, msg=None): @@ -1774,6 +1778,10 @@ class Point2DGeneric(Generic[T], TypedDict): b: T +class Bar(Foo): + b: int + + class BarGeneric(FooGeneric[T], total=False): b: int @@ -1978,6 +1986,14 @@ class PointDict3D(PointDict2D, total=False): assert is_typeddict(PointDict2D) is True assert is_typeddict(PointDict3D) is True + @skipUnless(HAS_FORWARD_MODULE, "ForwardRef.__forward_module__ was added in 3.9") + def test_get_type_hints_cross_module_subclass(self): + self.assertNotIn("_DoNotImport", globals()) + self.assertEqual( + {k: v.__name__ for k, v in get_type_hints(Bar).items()}, + {'a': "_DoNotImport", 'b': "int"} + ) + def test_get_type_hints_generic(self): self.assertEqual( get_type_hints(BarGeneric), diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 2d7a82fc..9553cdfa 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -2,6 +2,7 @@ import collections import collections.abc import functools +import inspect import operator import sys import types as _types @@ -728,6 +729,8 @@ def _typeddict_new(*args, total=True, **kwargs): _typeddict_new.__text_signature__ = ('($cls, _typename, _fields=None,' ' /, *, total=True, **kwargs)') + _TAKES_MODULE = "module" in inspect.signature(typing._type_check).parameters + class _TypedDictMeta(type): def __init__(cls, name, bases, ns, total=True): super().__init__(name, bases, ns) @@ -753,8 +756,10 @@ def __new__(cls, name, bases, ns, total=True): annotations = {} own_annotations = ns.get('__annotations__', {}) msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" + kwds = {"module": tp_dict.__module__} if _TAKES_MODULE else {} own_annotations = { - n: typing._type_check(tp, msg) for n, tp in own_annotations.items() + n: typing._type_check(tp, msg, **kwds) + for n, tp in own_annotations.items() } required_keys = set() optional_keys = set()