Skip to content

Commit

Permalink
Speed up make_simplified_union, fix recursive tuple crash (#15128)
Browse files Browse the repository at this point in the history
Fixes #15192

The following code optimises make_simplified_union in the common case
that there are exact duplicates in the union. In this regard, this is
similar to #15104

There's a behaviour change in one unit test. I think it's good? We'll
see what mypy_primer has to say.

To get this to work, I needed to use partial tuple fallbacks in a couple
places. These could cause crashes anyway.
There were some interesting things going on with recursive type aliases
and type state assumptions

This is about a 25% speedup on the pydantic codebase and about a 2%
speedup on self check (measured with uncompiled mypy)
  • Loading branch information
hauntsaninja committed May 6, 2023
1 parent bfc1a76 commit 7832e1f
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 95 deletions.
6 changes: 4 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def visit_instance(self, left: Instance) -> bool:
# dynamic base classes correctly, see #5456.
return not isinstance(self.right, NoneType)
right = self.right
if isinstance(right, TupleType) and mypy.typeops.tuple_fallback(right).type.is_enum:
if isinstance(right, TupleType) and right.partial_fallback.type.is_enum:
return self._is_subtype(left, mypy.typeops.tuple_fallback(right))
if isinstance(right, Instance):
if type_state.is_cached_subtype_check(self._subtype_kind, left, right):
Expand Down Expand Up @@ -749,7 +749,9 @@ def visit_tuple_type(self, left: TupleType) -> bool:
# for isinstance(x, tuple), though it's unclear why.
return True
return all(self._is_subtype(li, iter_type) for li in left.items)
elif self._is_subtype(mypy.typeops.tuple_fallback(left), right):
elif self._is_subtype(left.partial_fallback, right) and self._is_subtype(
mypy.typeops.tuple_fallback(left), right
):
return True
return False
elif isinstance(right, TupleType):
Expand Down
5 changes: 1 addition & 4 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,10 +611,7 @@ def test_simplified_union_with_mixed_str_literals(self) -> None:
[fx.lit_str1, fx.lit_str2, fx.lit_str3_inst],
UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]),
)
self.assert_simplified_union(
[fx.lit_str1, fx.lit_str1, fx.lit_str1_inst],
UnionType([fx.lit_str1, fx.lit_str1_inst]),
)
self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], fx.lit_str1)

def assert_simplified_union(self, original: list[Type], union: Type) -> None:
assert_equal(make_simplified_union(original), union)
Expand Down
159 changes: 70 additions & 89 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,25 +385,6 @@ def callable_corresponding_argument(
return by_name if by_name is not None else by_pos


def simple_literal_value_key(t: ProperType) -> tuple[str, ...] | None:
"""Return a hashable description of simple literal type.
Return None if not a simple literal type.
The return value can be used to simplify away duplicate types in
unions by comparing keys for equality. For now enum, string or
Instance with string last_known_value are supported.
"""
if isinstance(t, LiteralType):
if t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str":
assert isinstance(t.value, str)
return "literal", t.value, t.fallback.type.fullname
if isinstance(t, Instance):
if t.last_known_value is not None and isinstance(t.last_known_value.value, str):
return "instance", t.last_known_value.value, t.type.fullname
return None


def simple_literal_type(t: ProperType | None) -> Instance | None:
"""Extract the underlying fallback Instance type for a simple Literal"""
if isinstance(t, Instance) and t.last_known_value is not None:
Expand All @@ -414,7 +395,6 @@ def simple_literal_type(t: ProperType | None) -> Instance | None:


def is_simple_literal(t: ProperType) -> bool:
"""Fast way to check if simple_literal_value_key() would return a non-None value."""
if isinstance(t, LiteralType):
return t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str"
if isinstance(t, Instance):
Expand Down Expand Up @@ -500,68 +480,80 @@ def make_simplified_union(
def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]:
from mypy.subtypes import is_proper_subtype

removed: set[int] = set()
seen: set[tuple[str, ...]] = set()

# NB: having a separate fast path for Union of Literal and slow path for other things
# would arguably be cleaner, however it breaks down when simplifying the Union of two
# different enum types as try_expanding_sum_type_to_union works recursively and will
# trigger intermediate simplifications that would render the fast path useless
for i, item in enumerate(items):
proper_item = get_proper_type(item)
if i in removed:
continue
# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
k = simple_literal_value_key(proper_item)
if k is not None:
if k in seen:
removed.add(i)
# The first pass through this loop, we check if later items are subtypes of earlier items.
# The second pass through this loop, we check if earlier items are subtypes of later items
# (by reversing the remaining items)
for _direction in range(2):
new_items: list[Type] = []
# seen is a map from a type to its index in new_items
seen: dict[ProperType, int] = {}
unduplicated_literal_fallbacks: set[Instance] | None = None
for ti in items:
proper_ti = get_proper_type(ti)

# UninhabitedType is always redundant
if isinstance(proper_ti, UninhabitedType):
continue

# NB: one would naively expect that it would be safe to skip the slow path
# always for literals. One would be sorely mistaken. Indeed, some simplifications
# such as that of None/Optional when strict optional is false, do require that we
# proceed with the slow path. Thankfully, all literals will have the same subtype
# relationship to non-literal types, so we only need to do that walk for the first
# literal, which keeps the fast path fast even in the presence of a mixture of
# literals and other types.
safe_skip = len(seen) > 0
seen.add(k)
if safe_skip:
continue

# Keep track of the truthiness info for deleted subtypes which can be relevant
cbt = cbf = False
for j, tj in enumerate(items):
proper_tj = get_proper_type(tj)
if (
i == j
# avoid further checks if this item was already marked redundant.
or j in removed
# if the current item is a simple literal then this simplification loop can
# safely skip all other simple literals as two literals will only ever be
# subtypes of each other if they are equal, which is already handled above.
# However, if the current item is not a literal, it might plausibly be a
# supertype of other literals in the union, so we must check them again.
# This is an important optimization as is_proper_subtype is pretty expensive.
or (k is not None and is_simple_literal(proper_tj))
):
continue
# actual redundancy checks (XXX?)
if is_redundant_literal_instance(proper_item, proper_tj) and is_proper_subtype(
tj, item, keep_erased_types=keep_erased, ignore_promotions=True
duplicate_index = -1
# Quickly check if we've seen this type
if proper_ti in seen:
duplicate_index = seen[proper_ti]
elif (
isinstance(proper_ti, LiteralType)
and unduplicated_literal_fallbacks is not None
and proper_ti.fallback in unduplicated_literal_fallbacks
):
# We found a redundant item in the union.
removed.add(j)
cbt = cbt or tj.can_be_true
cbf = cbf or tj.can_be_false
# if deleted subtypes had more general truthiness, use that
if not item.can_be_true and cbt:
items[i] = true_or_false(item)
elif not item.can_be_false and cbf:
items[i] = true_or_false(item)
# This is an optimisation for unions with many LiteralType
# We've already checked for exact duplicates. This means that any super type of
# the LiteralType must be a super type of its fallback. If we've gone through
# the expensive loop below and found no super type for a previous LiteralType
# with the same fallback, we can skip doing that work again and just add the type
# to new_items
pass
else:
# If not, check if we've seen a supertype of this type
for j, tj in enumerate(new_items):
tj = get_proper_type(tj)
# If tj is an Instance with a last_known_value, do not remove proper_ti
# (unless it's an instance with the same last_known_value)
if (
isinstance(tj, Instance)
and tj.last_known_value is not None
and not (
isinstance(proper_ti, Instance)
and tj.last_known_value == proper_ti.last_known_value
)
):
continue

if is_proper_subtype(
proper_ti, tj, keep_erased_types=keep_erased, ignore_promotions=True
):
duplicate_index = j
break
if duplicate_index != -1:
# If deleted subtypes had more general truthiness, use that
orig_item = new_items[duplicate_index]
if not orig_item.can_be_true and ti.can_be_true:
new_items[duplicate_index] = true_or_false(orig_item)
elif not orig_item.can_be_false and ti.can_be_false:
new_items[duplicate_index] = true_or_false(orig_item)
else:
# We have a non-duplicate item, add it to new_items
seen[proper_ti] = len(new_items)
new_items.append(ti)
if isinstance(proper_ti, LiteralType):
if unduplicated_literal_fallbacks is None:
unduplicated_literal_fallbacks = set()
unduplicated_literal_fallbacks.add(proper_ti.fallback)

return [items[i] for i in range(len(items)) if i not in removed]
items = new_items
if len(items) <= 1:
break
items.reverse()

return items


def _get_type_special_method_bool_ret_type(t: Type) -> Type | None:
Expand Down Expand Up @@ -992,17 +984,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool
return False


def is_redundant_literal_instance(general: ProperType, specific: ProperType) -> bool:
if not isinstance(general, Instance) or general.last_known_value is None:
return True
if isinstance(specific, Instance) and specific.last_known_value == general.last_known_value:
return True
if isinstance(specific, UninhabitedType):
return True

return False


def separate_union_literals(t: UnionType) -> tuple[Sequence[LiteralType], Sequence[Type]]:
"""Separate literals from other members in a union type."""
literal_items = []
Expand Down
16 changes: 16 additions & 0 deletions test-data/unit/check-type-aliases.test
Original file line number Diff line number Diff line change
Expand Up @@ -1043,3 +1043,19 @@ class C(Generic[T]):
def test(cls) -> None:
cls.attr
[builtins fixtures/classmethod.pyi]

[case testRecursiveAliasTuple]
from typing_extensions import Literal, TypeAlias
from typing import Tuple, Union

Expr: TypeAlias = Union[
Tuple[Literal[123], int],
Tuple[Literal[456], "Expr"],
]

def eval(e: Expr) -> int:
if e[0] == 123:
return e[1]
elif e[0] == 456:
return -eval(e[1])
[builtins fixtures/dict.pyi]

0 comments on commit 7832e1f

Please sign in to comment.