diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 7c94067..f7820ec 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -20,7 +20,7 @@ from unittest import TestCase, main, skipUnless, skipIf from unittest.mock import patch import typing -from typing import TypeVar, Optional, Union, AnyStr +from typing import Optional, Union, AnyStr from typing import T, KT, VT # Not in __all__. from typing import Tuple, List, Set, Dict, Iterable, Iterator, Callable from typing import Generic @@ -36,7 +36,7 @@ from typing_extensions import assert_type, get_type_hints, get_origin, get_args, get_original_bases from typing_extensions import clear_overloads, get_overloads, overload from typing_extensions import NamedTuple -from typing_extensions import override, deprecated, Buffer, TypeAliasType +from typing_extensions import override, deprecated, Buffer, TypeAliasType, TypeVar from _typed_dict_test_helper import Foo, FooGeneric # Flags used to mark tests that only apply after a specific @@ -3306,6 +3306,7 @@ def test_basic_plain(self): P = ParamSpec('P') self.assertEqual(P, P) self.assertIsInstance(P, ParamSpec) + self.assertEqual(P.__name__, 'P') # Should be hashable hash(P) @@ -4375,10 +4376,153 @@ class GenericNamedTuple(NamedTuple, Generic[T]): self.assertEqual(CallNamedTuple.__orig_bases__, (NamedTuple,)) +class TypeVarTests(BaseTestCase): + def test_basic_plain(self): + T = TypeVar('T') + # T equals itself. + self.assertEqual(T, T) + # T is an instance of TypeVar + self.assertIsInstance(T, TypeVar) + self.assertEqual(T.__name__, 'T') + self.assertEqual(T.__constraints__, ()) + self.assertIs(T.__bound__, None) + self.assertIs(T.__covariant__, False) + self.assertIs(T.__contravariant__, False) + self.assertIs(T.__infer_variance__, False) + + def test_attributes(self): + T_bound = TypeVar('T_bound', bound=int) + self.assertEqual(T_bound.__name__, 'T_bound') + self.assertEqual(T_bound.__constraints__, ()) + self.assertIs(T_bound.__bound__, int) + + T_constraints = TypeVar('T_constraints', int, str) + self.assertEqual(T_constraints.__name__, 'T_constraints') + self.assertEqual(T_constraints.__constraints__, (int, str)) + self.assertIs(T_constraints.__bound__, None) + + T_co = TypeVar('T_co', covariant=True) + self.assertEqual(T_co.__name__, 'T_co') + self.assertIs(T_co.__covariant__, True) + self.assertIs(T_co.__contravariant__, False) + self.assertIs(T_co.__infer_variance__, False) + + T_contra = TypeVar('T_contra', contravariant=True) + self.assertEqual(T_contra.__name__, 'T_contra') + self.assertIs(T_contra.__covariant__, False) + self.assertIs(T_contra.__contravariant__, True) + self.assertIs(T_contra.__infer_variance__, False) + + T_infer = TypeVar('T_infer', infer_variance=True) + self.assertEqual(T_infer.__name__, 'T_infer') + self.assertIs(T_infer.__covariant__, False) + self.assertIs(T_infer.__contravariant__, False) + self.assertIs(T_infer.__infer_variance__, True) + + def test_typevar_instance_type_error(self): + T = TypeVar('T') + with self.assertRaises(TypeError): + isinstance(42, T) + + def test_typevar_subclass_type_error(self): + T = TypeVar('T') + with self.assertRaises(TypeError): + issubclass(int, T) + with self.assertRaises(TypeError): + issubclass(T, int) + + def test_constrained_error(self): + with self.assertRaises(TypeError): + X = TypeVar('X', int) + X + + def test_union_unique(self): + X = TypeVar('X') + Y = TypeVar('Y') + self.assertNotEqual(X, Y) + self.assertEqual(Union[X], X) + self.assertNotEqual(Union[X], Union[X, Y]) + self.assertEqual(Union[X, X], X) + self.assertNotEqual(Union[X, int], Union[X]) + self.assertNotEqual(Union[X, int], Union[int]) + self.assertEqual(Union[X, int].__args__, (X, int)) + self.assertEqual(Union[X, int].__parameters__, (X,)) + self.assertIs(Union[X, int].__origin__, Union) + + if hasattr(types, "UnionType"): + def test_or(self): + X = TypeVar('X') + # use a string because str doesn't implement + # __or__/__ror__ itself + self.assertEqual(X | "x", Union[X, "x"]) + self.assertEqual("x" | X, Union["x", X]) + # make sure the order is correct + self.assertEqual(get_args(X | "x"), (X, typing.ForwardRef("x"))) + self.assertEqual(get_args("x" | X), (typing.ForwardRef("x"), X)) + + def test_union_constrained(self): + A = TypeVar('A', str, bytes) + self.assertNotEqual(Union[A, str], Union[A]) + + def test_repr(self): + self.assertEqual(repr(T), '~T') + self.assertEqual(repr(KT), '~KT') + self.assertEqual(repr(VT), '~VT') + self.assertEqual(repr(AnyStr), '~AnyStr') + T_co = TypeVar('T_co', covariant=True) + self.assertEqual(repr(T_co), '+T_co') + T_contra = TypeVar('T_contra', contravariant=True) + self.assertEqual(repr(T_contra), '-T_contra') + + def test_no_redefinition(self): + self.assertNotEqual(TypeVar('T'), TypeVar('T')) + self.assertNotEqual(TypeVar('T', int, str), TypeVar('T', int, str)) + + def test_cannot_subclass(self): + with self.assertRaises(TypeError): + class V(TypeVar): pass + T = TypeVar("T") + with self.assertRaises(TypeError): + class V(T): pass + + def test_cannot_instantiate_vars(self): + with self.assertRaises(TypeError): + TypeVar('A')() + + def test_bound_errors(self): + with self.assertRaises(TypeError): + TypeVar('X', bound=Union) + with self.assertRaises(TypeError): + TypeVar('X', str, float, bound=Employee) + with self.assertRaisesRegex(TypeError, + r"Bound must be a type\. Got \(1, 2\)\."): + TypeVar('X', bound=(1, 2)) + + # Technically we could run it on later versions of 3.7 and 3.8, + # but that's not worth the effort. + @skipUnless(TYPING_3_9_0, "Fix was not backported") + def test_missing__name__(self): + # See bpo-39942 + code = ("import typing\n" + "T = typing.TypeVar('T')\n" + ) + exec(code, {}) + + def test_no_bivariant(self): + with self.assertRaises(ValueError): + TypeVar('T', covariant=True, contravariant=True) + + def test_cannot_combine_explicit_and_infer(self): + with self.assertRaises(ValueError): + TypeVar('T', covariant=True, infer_variance=True) + with self.assertRaises(ValueError): + TypeVar('T', contravariant=True, infer_variance=True) + + class TypeVarLikeDefaultsTests(BaseTestCase): def test_typevar(self): T = typing_extensions.TypeVar('T', default=int) - typing_T = TypeVar('T') + typing_T = typing.TypeVar('T') self.assertEqual(T.__default__, int) self.assertIsInstance(T, typing_extensions.TypeVar) self.assertIsInstance(T, typing.TypeVar) diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 6fd0f24..ff5aefe 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -1374,6 +1374,8 @@ def __call__(self, name, *constraints, bound=None, else: typevar = typing.TypeVar(name, *constraints, bound=bound, covariant=covariant, contravariant=contravariant) + if infer_variance and (covariant or contravariant): + raise ValueError("Variance cannot be specified with infer_variance.") typevar.__infer_variance__ = infer_variance _set_default(typevar, default) @@ -1392,6 +1394,9 @@ class TypeVar(metaclass=_TypeVarMeta): __module__ = 'typing' + def __init_subclass__(cls) -> None: + raise TypeError(f"type '{__name__}.TypeVar' is not an acceptable base type") + # Python 3.10+ has PEP 612 if hasattr(typing, 'ParamSpecArgs'): @@ -1481,6 +1486,9 @@ class ParamSpec(metaclass=_ParamSpecMeta): __module__ = 'typing' + def __init_subclass__(cls) -> None: + raise TypeError(f"type '{__name__}.ParamSpec' is not an acceptable base type") + # 3.7-3.9 else: