From 77cfdc25d9e6857c2ab0380924a1dcf72cc1c037 Mon Sep 17 00:00:00 2001 From: Gobot1234 Date: Mon, 20 Feb 2023 23:27:49 +0000 Subject: [PATCH] Backport https://github.com/python/cpython/pull/31628 --- CHANGELOG.md | 6 + src/test_typing_extensions.py | 28 ++++- src/typing_extensions.py | 215 ++++++++++++++++++++-------------- 3 files changed, 163 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d330a0f..97828b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# Unreleased + +- Backport `Protocol.__init__` behaviour from python 3.11. (see + python/cpython#31628, by Adrian Garcia Badaracco). Patch by James + Hilton-Balfe (@Gobot1234). + # Release 4.5.0 (February 14, 2023) - Runtime support for PEP 702, adding `typing_extensions.deprecated`. Patch diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 208382a..09825d7 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -1427,6 +1427,32 @@ class PG(Protocol[T]): pass class CG(PG[T]): pass self.assertIsInstance(CG[int](), CG) + def test_protocol_defining_init_does_not_get_overridden(self): + # check that P.__init__ doesn't get clobbered + # see https://bugs.python.org/issue44807 + + class P(Protocol): + x: int + def __init__(self, x: int) -> None: + self.x = x + class C: pass + + c = C() + P.__init__(c, 1) + self.assertEqual(c.x, 1) + + def test_concrete_class_inheriting_init_from_protocol(self): + class P(Protocol): + x: int + def __init__(self, x: int) -> None: + self.x = x + + class C(P): pass + + c = C(1) + self.assertIsInstance(c, C) + self.assertEqual(c.x, 1) + def test_cannot_instantiate_abstract(self): @runtime class P(Protocol): @@ -3302,7 +3328,7 @@ def test_typing_extensions_defers_when_possible(self): if sys.version_info < (3, 10): exclude |= {'get_args', 'get_origin'} if sys.version_info < (3, 11): - exclude |= {'final', 'NamedTuple', 'Any'} + exclude |= {'final', 'NamedTuple', 'Any', 'Protocol'} for item in typing_extensions.__all__: if item not in exclude and hasattr(typing, item): self.assertIs( diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 6ae0c34..d0f967b 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -93,6 +93,8 @@ def _check_generic(cls, parameters, elen=_marker): """Check correct count for parameters of a generic cls (internal helper). This gives a nice error message in case of count mismatch. """ + if cls is Protocol and not elen: + return if not elen: raise TypeError(f"{cls} is not a generic class") if elen is _marker: @@ -143,6 +145,11 @@ def _collect_type_vars(types, typevar_types=None): tvars.extend([t for t in t.__parameters__ if t not in tvars]) return tuple(tvars) +def _caller(depth=1, default='__main__'): + try: + return sys._getframe(depth + 1).f_globals.get('__name__', default) + except (AttributeError, ValueError): # For platforms without _getframe() + return None NoReturn = typing.NoReturn @@ -457,162 +464,206 @@ def _maybe_adjust_parameters(cls): cls.__parameters__ = tuple(tvars) -# 3.8+ -if hasattr(typing, 'Protocol'): +# 3.11+ +if sys.version_info >= (3, 11): # 3.8 has Protocol but it doesn't preserve __init__ Protocol = typing.Protocol -# 3.7 + else: + _TYPING_INTERNALS = ['__parameters__', '__orig_bases__', '__orig_class__', + '_is_protocol', '_is_runtime_protocol'] + + _SPECIAL_NAMES = ['__abstractmethods__', '__annotations__', '__dict__', '__doc__', + '__init__', '__module__', '__new__', '__slots__', + '__subclasshook__', '__weakref__', '__class_getitem__'] + + # These special attributes will be not collected as protocol members. + _EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS + _SPECIAL_NAMES + ['_MutableMapping__marker'] + + + def _get_protocol_attrs(cls): + """Collect protocol members from a protocol class objects. + This includes names actually defined in the class dictionary, as well + as names that appear in annotations. Special names (above) are skipped. + """ + attrs = set() + for base in cls.__mro__[:-1]: # without object + if base.__name__ in ('Protocol', 'Generic'): + continue + annotations = getattr(base, '__annotations__', {}) + for attr in list(base.__dict__.keys()) + list(annotations.keys()): + if not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRIBUTES: + attrs.add(attr) + return attrs + + + def _is_callable_members_only(cls): + # PEP 544 prohibits using issubclass() with protocols that have non-method members. + return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls)) - def _no_init(self, *args, **kwargs): - if type(self)._is_protocol: + + def _no_init_or_replace_init(self, *args, **kwargs): + cls = type(self) + + if cls._is_protocol: raise TypeError('Protocols cannot be instantiated') - class _ProtocolMeta(abc.ABCMeta): # noqa: B024 - # This metaclass is a bit unfortunate and exists only because of the lack - # of __instancehook__. + # Already using a custom `__init__`. No need to calculate correct + # `__init__` to call. This can lead to RecursionError. See bpo-45121. + if cls.__init__ is not _no_init_or_replace_init: + return + + # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`. + # The first instantiation of the subclass will call `_no_init_or_replace_init` which + # searches for a proper new `__init__` in the MRO. The new `__init__` + # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent + # instantiation of the protocol subclass will thus use the new + # `__init__` and no longer call `_no_init_or_replace_init`. + for base in cls.__mro__: + init = base.__dict__.get('__init__', _no_init_or_replace_init) + if init is not _no_init_or_replace_init: + cls.__init__ = init + break + else: + # should not happen + cls.__init__ = object.__init__ + + cls.__init__(self, *args, **kwargs) + + + def _allow_reckless_class_checks(depth=3): + """Allow instance and class checks for special stdlib modules. + The abc and functools modules indiscriminately call isinstance() and + issubclass() on the whole MRO of a user class, which may contain protocols. + """ + return _caller(depth) in {'abc', 'functools', None} + + + _PROTO_ALLOWLIST = { + 'collections.abc': [ + 'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', + ], + 'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'], + } + + + class _ProtocolMeta(abc.ABCMeta): + # This metaclass is really unfortunate and exists only because of + # the lack of __instancehook__. def __instancecheck__(cls, instance): # We need this method for situations where attributes are # assigned in __init__. + if ( + getattr(cls, '_is_protocol', False) and + not getattr(cls, '_is_runtime_protocol', False) and + not _allow_reckless_class_checks(depth=2) + ): + raise TypeError("Instance and class checks can only be used with" + " @runtime_checkable protocols") + if ((not getattr(cls, '_is_protocol', False) or - _is_callable_members_only(cls)) and + _is_callable_members_only(cls)) and issubclass(instance.__class__, cls)): return True if cls._is_protocol: if all(hasattr(instance, attr) and - (not callable(getattr(cls, attr, None)) or + # All *methods* can be blocked by setting them to None. + (not callable(getattr(cls, attr, None)) or getattr(instance, attr) is not None) - for attr in _get_protocol_attrs(cls)): + for attr in _get_protocol_attrs(cls)): return True return super().__instancecheck__(instance) - class Protocol(metaclass=_ProtocolMeta): - # There is quite a lot of overlapping code with typing.Generic. - # Unfortunately it is hard to avoid this while these live in two different - # modules. The duplicated code will be removed when Protocol is moved to typing. - """Base class for protocol classes. Protocol classes are defined as:: + class Protocol(typing.Generic, metaclass=_ProtocolMeta): + """Base class for protocol classes. + Protocol classes are defined as:: class Proto(Protocol): def meth(self) -> int: ... - Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example:: - class C: def meth(self) -> int: return 0 - def func(x: Proto) -> int: return x.meth() - func(C()) # Passes static type check - See PEP 544 for details. Protocol classes decorated with - @typing_extensions.runtime act as simple-minded runtime protocol that checks + @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. - Protocol classes can be generic, they are defined as:: - class GenProto(Protocol[T]): def meth(self) -> T: ... """ __slots__ = () _is_protocol = True - - def __new__(cls, *args, **kwds): - if cls is Protocol: - raise TypeError("Type Protocol cannot be instantiated; " - "it can only be used as a base class") - return super().__new__(cls) - - @typing._tp_cache - def __class_getitem__(cls, params): - if not isinstance(params, tuple): - params = (params,) - if not params and cls is not typing.Tuple: - raise TypeError( - f"Parameter list to {cls.__qualname__}[...] cannot be empty") - msg = "Parameters to generic types must be types." - params = tuple(typing._type_check(p, msg) for p in params) # noqa - if cls is Protocol: - # Generic can only be subscripted with unique type variables. - if not all(isinstance(p, typing.TypeVar) for p in params): - i = 0 - while isinstance(params[i], typing.TypeVar): - i += 1 - raise TypeError( - "Parameters to Protocol[...] must all be type variables." - f" Parameter {i + 1} is {params[i]}") - if len(set(params)) != len(params): - raise TypeError( - "Parameters to Protocol[...] must all be unique") - else: - # Subscripting a regular Generic subclass. - _check_generic(cls, params, len(cls.__parameters__)) - return typing._GenericAlias(cls, params) + _is_runtime_protocol = False def __init_subclass__(cls, *args, **kwargs): - if '__orig_bases__' in cls.__dict__: - error = typing.Generic in cls.__orig_bases__ - else: - error = typing.Generic in cls.__bases__ - if error: - raise TypeError("Cannot inherit from plain Generic") - _maybe_adjust_parameters(cls) + super().__init_subclass__(*args, **kwargs) # Determine if this is a protocol or a concrete subclass. - if not cls.__dict__.get('_is_protocol', None): + if not cls.__dict__.get('_is_protocol', False): cls._is_protocol = any(b is Protocol for b in cls.__bases__) # Set (or override) the protocol subclass hook. def _proto_hook(other): - if not cls.__dict__.get('_is_protocol', None): + if not cls.__dict__.get('_is_protocol', False): return NotImplemented + + # First, perform various sanity checks. if not getattr(cls, '_is_runtime_protocol', False): - if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: + if _allow_reckless_class_checks(): return NotImplemented raise TypeError("Instance and class checks can only be used with" - " @runtime protocols") + " @runtime_checkable protocols") if not _is_callable_members_only(cls): - if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: + if _allow_reckless_class_checks(): return NotImplemented raise TypeError("Protocols with non-method members" " don't support issubclass()") if not isinstance(other, type): - # Same error as for issubclass(1, int) + # Same error message as for issubclass(1, int). raise TypeError('issubclass() arg 1 must be a class') + + # Second, perform the actual structural compatibility check. for attr in _get_protocol_attrs(cls): for base in other.__mro__: + # Check if the members appears in the class dictionary... if attr in base.__dict__: if base.__dict__[attr] is None: return NotImplemented break + + # ...or in annotations, if it is a sub-protocol. annotations = getattr(base, '__annotations__', {}) - if (isinstance(annotations, typing.Mapping) and + if (isinstance(annotations, collections.abc.Mapping) and attr in annotations and - isinstance(other, _ProtocolMeta) and - other._is_protocol): + issubclass(other, typing.Generic) and other._is_protocol): break else: return NotImplemented return True + if '__subclasshook__' not in cls.__dict__: cls.__subclasshook__ = _proto_hook - # We have nothing more to do for non-protocols. + # We have nothing more to do for non-protocols... if not cls._is_protocol: return - # Check consistency of bases. + # ... otherwise check consistency of bases, and prohibit instantiation. for base in cls.__bases__: if not (base in (object, typing.Generic) or - base.__module__ == 'collections.abc' and - base.__name__ in _PROTO_WHITELIST or - isinstance(base, _ProtocolMeta) and base._is_protocol): + base.__module__ in _PROTO_ALLOWLIST and + base.__name__ in _PROTO_ALLOWLIST[base.__module__] or + issubclass(base, typing.Generic) and base._is_protocol): raise TypeError('Protocols can only inherit from other' - f' protocols, got {repr(base)}') - cls.__init__ = _no_init + ' protocols, got %r' % base) + if cls.__init__ is Protocol.__init__: + cls.__init__ = _no_init_or_replace_init # 3.8+ @@ -2229,12 +2280,6 @@ def wrapper(*args, **kwargs): if sys.version_info >= (3, 11): NamedTuple = typing.NamedTuple else: - def _caller(): - try: - return sys._getframe(2).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): # For platforms without _getframe() - return None - def _make_nmtuple(name, types, module, defaults=()): fields = [n for n, t in types] annotations = {n: typing._type_check(t, f"field {n} annotation must be a type")