Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Gobot1234 committed Feb 20, 2023
1 parent 8dcd899 commit 77cfdc2
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 86 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Unreleased

- Backport `Protocol.__init__` behaviour from python 3.11. (see
python/cpython#31628, by Adrian Garcia Badaracco). Patch by James
Hilton-Balfe (@Gobot1234).

# Release 4.5.0 (February 14, 2023)

- Runtime support for PEP 702, adding `typing_extensions.deprecated`. Patch
Expand Down
28 changes: 27 additions & 1 deletion src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,32 @@ class PG(Protocol[T]): pass
class CG(PG[T]): pass
self.assertIsInstance(CG[int](), CG)

def test_protocol_defining_init_does_not_get_overridden(self):
# check that P.__init__ doesn't get clobbered
# see https://bugs.python.org/issue44807

class P(Protocol):
x: int
def __init__(self, x: int) -> None:
self.x = x
class C: pass

c = C()
P.__init__(c, 1)
self.assertEqual(c.x, 1)

def test_concrete_class_inheriting_init_from_protocol(self):
class P(Protocol):
x: int
def __init__(self, x: int) -> None:
self.x = x

class C(P): pass

c = C(1)
self.assertIsInstance(c, C)
self.assertEqual(c.x, 1)

def test_cannot_instantiate_abstract(self):
@runtime
class P(Protocol):
Expand Down Expand Up @@ -3302,7 +3328,7 @@ def test_typing_extensions_defers_when_possible(self):
if sys.version_info < (3, 10):
exclude |= {'get_args', 'get_origin'}
if sys.version_info < (3, 11):
exclude |= {'final', 'NamedTuple', 'Any'}
exclude |= {'final', 'NamedTuple', 'Any', 'Protocol'}
for item in typing_extensions.__all__:
if item not in exclude and hasattr(typing, item):
self.assertIs(
Expand Down
215 changes: 130 additions & 85 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def _check_generic(cls, parameters, elen=_marker):
"""Check correct count for parameters of a generic cls (internal helper).
This gives a nice error message in case of count mismatch.
"""
if cls is Protocol and not elen:
return
if not elen:
raise TypeError(f"{cls} is not a generic class")
if elen is _marker:
Expand Down Expand Up @@ -143,6 +145,11 @@ def _collect_type_vars(types, typevar_types=None):
tvars.extend([t for t in t.__parameters__ if t not in tvars])
return tuple(tvars)

def _caller(depth=1, default='__main__'):
try:
return sys._getframe(depth + 1).f_globals.get('__name__', default)
except (AttributeError, ValueError): # For platforms without _getframe()
return None

NoReturn = typing.NoReturn

Expand Down Expand Up @@ -457,162 +464,206 @@ def _maybe_adjust_parameters(cls):
cls.__parameters__ = tuple(tvars)


# 3.8+
if hasattr(typing, 'Protocol'):
# 3.11+
if sys.version_info >= (3, 11): # 3.8 has Protocol but it doesn't preserve __init__
Protocol = typing.Protocol
# 3.7

else:
_TYPING_INTERNALS = ['__parameters__', '__orig_bases__', '__orig_class__',
'_is_protocol', '_is_runtime_protocol']

_SPECIAL_NAMES = ['__abstractmethods__', '__annotations__', '__dict__', '__doc__',
'__init__', '__module__', '__new__', '__slots__',
'__subclasshook__', '__weakref__', '__class_getitem__']

# These special attributes will be not collected as protocol members.
_EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS + _SPECIAL_NAMES + ['_MutableMapping__marker']


def _get_protocol_attrs(cls):
"""Collect protocol members from a protocol class objects.
This includes names actually defined in the class dictionary, as well
as names that appear in annotations. Special names (above) are skipped.
"""
attrs = set()
for base in cls.__mro__[:-1]: # without object
if base.__name__ in ('Protocol', 'Generic'):
continue
annotations = getattr(base, '__annotations__', {})
for attr in list(base.__dict__.keys()) + list(annotations.keys()):
if not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRIBUTES:
attrs.add(attr)
return attrs


def _is_callable_members_only(cls):
# PEP 544 prohibits using issubclass() with protocols that have non-method members.
return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))

def _no_init(self, *args, **kwargs):
if type(self)._is_protocol:

def _no_init_or_replace_init(self, *args, **kwargs):
cls = type(self)

if cls._is_protocol:
raise TypeError('Protocols cannot be instantiated')

class _ProtocolMeta(abc.ABCMeta): # noqa: B024
# This metaclass is a bit unfortunate and exists only because of the lack
# of __instancehook__.
# Already using a custom `__init__`. No need to calculate correct
# `__init__` to call. This can lead to RecursionError. See bpo-45121.
if cls.__init__ is not _no_init_or_replace_init:
return

# Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
# The first instantiation of the subclass will call `_no_init_or_replace_init` which
# searches for a proper new `__init__` in the MRO. The new `__init__`
# replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
# instantiation of the protocol subclass will thus use the new
# `__init__` and no longer call `_no_init_or_replace_init`.
for base in cls.__mro__:
init = base.__dict__.get('__init__', _no_init_or_replace_init)
if init is not _no_init_or_replace_init:
cls.__init__ = init
break
else:
# should not happen
cls.__init__ = object.__init__

cls.__init__(self, *args, **kwargs)


def _allow_reckless_class_checks(depth=3):
"""Allow instance and class checks for special stdlib modules.
The abc and functools modules indiscriminately call isinstance() and
issubclass() on the whole MRO of a user class, which may contain protocols.
"""
return _caller(depth) in {'abc', 'functools', None}


_PROTO_ALLOWLIST = {
'collections.abc': [
'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable',
'Hashable', 'Sized', 'Container', 'Collection', 'Reversible',
],
'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'],
}


class _ProtocolMeta(abc.ABCMeta):
# This metaclass is really unfortunate and exists only because of
# the lack of __instancehook__.
def __instancecheck__(cls, instance):
# We need this method for situations where attributes are
# assigned in __init__.
if (
getattr(cls, '_is_protocol', False) and
not getattr(cls, '_is_runtime_protocol', False) and
not _allow_reckless_class_checks(depth=2)
):
raise TypeError("Instance and class checks can only be used with"
" @runtime_checkable protocols")

if ((not getattr(cls, '_is_protocol', False) or
_is_callable_members_only(cls)) and
_is_callable_members_only(cls)) and
issubclass(instance.__class__, cls)):
return True
if cls._is_protocol:
if all(hasattr(instance, attr) and
(not callable(getattr(cls, attr, None)) or
# All *methods* can be blocked by setting them to None.
(not callable(getattr(cls, attr, None)) or
getattr(instance, attr) is not None)
for attr in _get_protocol_attrs(cls)):
for attr in _get_protocol_attrs(cls)):
return True
return super().__instancecheck__(instance)

class Protocol(metaclass=_ProtocolMeta):
# There is quite a lot of overlapping code with typing.Generic.
# Unfortunately it is hard to avoid this while these live in two different
# modules. The duplicated code will be removed when Protocol is moved to typing.
"""Base class for protocol classes. Protocol classes are defined as::

class Protocol(typing.Generic, metaclass=_ProtocolMeta):
"""Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize
structural subtyping (static duck-typing), for example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with
@typing_extensions.runtime act as simple-minded runtime protocol that checks
@typing.runtime_checkable act as simple-minded runtime protocols that check
only the presence of given attributes, ignoring their type signatures.
Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
"""
__slots__ = ()
_is_protocol = True

def __new__(cls, *args, **kwds):
if cls is Protocol:
raise TypeError("Type Protocol cannot be instantiated; "
"it can only be used as a base class")
return super().__new__(cls)

@typing._tp_cache
def __class_getitem__(cls, params):
if not isinstance(params, tuple):
params = (params,)
if not params and cls is not typing.Tuple:
raise TypeError(
f"Parameter list to {cls.__qualname__}[...] cannot be empty")
msg = "Parameters to generic types must be types."
params = tuple(typing._type_check(p, msg) for p in params) # noqa
if cls is Protocol:
# Generic can only be subscripted with unique type variables.
if not all(isinstance(p, typing.TypeVar) for p in params):
i = 0
while isinstance(params[i], typing.TypeVar):
i += 1
raise TypeError(
"Parameters to Protocol[...] must all be type variables."
f" Parameter {i + 1} is {params[i]}")
if len(set(params)) != len(params):
raise TypeError(
"Parameters to Protocol[...] must all be unique")
else:
# Subscripting a regular Generic subclass.
_check_generic(cls, params, len(cls.__parameters__))
return typing._GenericAlias(cls, params)
_is_runtime_protocol = False

def __init_subclass__(cls, *args, **kwargs):
if '__orig_bases__' in cls.__dict__:
error = typing.Generic in cls.__orig_bases__
else:
error = typing.Generic in cls.__bases__
if error:
raise TypeError("Cannot inherit from plain Generic")
_maybe_adjust_parameters(cls)
super().__init_subclass__(*args, **kwargs)

# Determine if this is a protocol or a concrete subclass.
if not cls.__dict__.get('_is_protocol', None):
if not cls.__dict__.get('_is_protocol', False):
cls._is_protocol = any(b is Protocol for b in cls.__bases__)

# Set (or override) the protocol subclass hook.
def _proto_hook(other):
if not cls.__dict__.get('_is_protocol', None):
if not cls.__dict__.get('_is_protocol', False):
return NotImplemented

# First, perform various sanity checks.
if not getattr(cls, '_is_runtime_protocol', False):
if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']:
if _allow_reckless_class_checks():
return NotImplemented
raise TypeError("Instance and class checks can only be used with"
" @runtime protocols")
" @runtime_checkable protocols")
if not _is_callable_members_only(cls):
if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']:
if _allow_reckless_class_checks():
return NotImplemented
raise TypeError("Protocols with non-method members"
" don't support issubclass()")
if not isinstance(other, type):
# Same error as for issubclass(1, int)
# Same error message as for issubclass(1, int).
raise TypeError('issubclass() arg 1 must be a class')

# Second, perform the actual structural compatibility check.
for attr in _get_protocol_attrs(cls):
for base in other.__mro__:
# Check if the members appears in the class dictionary...
if attr in base.__dict__:
if base.__dict__[attr] is None:
return NotImplemented
break

# ...or in annotations, if it is a sub-protocol.
annotations = getattr(base, '__annotations__', {})
if (isinstance(annotations, typing.Mapping) and
if (isinstance(annotations, collections.abc.Mapping) and
attr in annotations and
isinstance(other, _ProtocolMeta) and
other._is_protocol):
issubclass(other, typing.Generic) and other._is_protocol):
break
else:
return NotImplemented
return True

if '__subclasshook__' not in cls.__dict__:
cls.__subclasshook__ = _proto_hook

# We have nothing more to do for non-protocols.
# We have nothing more to do for non-protocols...
if not cls._is_protocol:
return

# Check consistency of bases.
# ... otherwise check consistency of bases, and prohibit instantiation.
for base in cls.__bases__:
if not (base in (object, typing.Generic) or
base.__module__ == 'collections.abc' and
base.__name__ in _PROTO_WHITELIST or
isinstance(base, _ProtocolMeta) and base._is_protocol):
base.__module__ in _PROTO_ALLOWLIST and
base.__name__ in _PROTO_ALLOWLIST[base.__module__] or
issubclass(base, typing.Generic) and base._is_protocol):
raise TypeError('Protocols can only inherit from other'
f' protocols, got {repr(base)}')
cls.__init__ = _no_init
' protocols, got %r' % base)
if cls.__init__ is Protocol.__init__:
cls.__init__ = _no_init_or_replace_init


# 3.8+
Expand Down Expand Up @@ -2229,12 +2280,6 @@ def wrapper(*args, **kwargs):
if sys.version_info >= (3, 11):
NamedTuple = typing.NamedTuple
else:
def _caller():
try:
return sys._getframe(2).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError): # For platforms without _getframe()
return None

def _make_nmtuple(name, types, module, defaults=()):
fields = [n for n, t in types]
annotations = {n: typing._type_check(t, f"field {n} annotation must be a type")
Expand Down

0 comments on commit 77cfdc2

Please sign in to comment.