Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Literal type math #11992

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
ParamSpecExpr,
ArgKind, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE,
)
from mypy.literals import literal
from mypy.literals import literal, try_literal_math
from mypy import nodes
from mypy import operators
import mypy.checker
Expand Down Expand Up @@ -2570,6 +2570,27 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
any_type = AnyType(TypeOfAny.from_another_any, source_any=right_type)
return any_type, any_type

# STEP 0:
# We support `Literal` type math. For example, we want to reveal `1 + 3`
# as `Literal[4]`. So, we check if we have literal expressions first.
# We consider this to be the fast path, we move on if it is not a literal.
# But, operations on literal types are not processed further.

if isinstance(context, OpExpr) and isinstance(left_type, (LiteralType, Instance)):
fallback_left_type = (
left_type.fallback
if isinstance(left_type, LiteralType)
else left_type
)
literal_result = try_literal_math(
context.op,
left_expr, left_type,
right_expr, right_type,
fallback=fallback_left_type,
)
if literal_result is not None:
return literal_result, literal_result

# STEP 1:
# We start by getting the __op__ and __rop__ methods, if they exist.

Expand Down
65 changes: 65 additions & 0 deletions mypy/literals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import operator
from typing import Optional, Union, Any, Tuple, Iterable
from typing_extensions import Final

from mypy.types import Type, LiteralType, Instance, LiteralValue, get_proper_type
from mypy.nodes import (
Expression, ComparisonExpr, OpExpr, MemberExpr, UnaryExpr, StarExpr, IndexExpr, LITERAL_YES,
LITERAL_NO, NameExpr, LITERAL_TYPE, IntExpr, FloatExpr, ComplexExpr, StrExpr, BytesExpr,
Expand Down Expand Up @@ -246,3 +248,66 @@ def visit_temp_node(self, e: TempNode) -> None:


_hasher: Final = _Hasher()


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name literal of this literals.py file is different from the Literal in typing, which is an old and annoying conflict. I am not sure whether they should put in same file.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the old and new functionality are quite distinct, so having in the same file could be confusing. One option would be to move the old code into a new module with a different name.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JukkaL Any suggestions on the new name for the old part?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps simple_expr.py? Or even literal_expr.py, though I don't think that helps much.

_SUPPORTED_LITERAL_OPERATIONS: Final = {
int: ('+', '-', '*', '//'), # `/` returns `float`
str: ('+',),
bool: ('and', 'or'),
}
_OP_FUNCTIONS: Final = {
'+': operator.add,
'-': operator.sub,
'*': operator.mul,
'//': operator.floordiv,
'and': operator.and_,
'or': operator.or_,
}


def try_literal_math(
op: str,
left_expr: Expression, left_type: Type,
right_expr: Expression, right_type: Type,
*,
fallback: Instance,
) -> Optional[Instance]:
left_literal = _get_literal_value(left_expr, left_type)
if left_literal is None:
return None
right_literal = _get_literal_value(right_expr, right_type)
if right_literal is None:
return None

lit_type = type(left_literal)
if (lit_type != type(right_literal)
or lit_type not in _SUPPORTED_LITERAL_OPERATIONS
or op not in _SUPPORTED_LITERAL_OPERATIONS[lit_type]):
return None

op_method = _OP_FUNCTIONS[op]
try:
new_value = op_method(left_literal, right_literal)
except Exception: # We catch any possible problem: overflow, type error, etc.
return None
else:
return fallback.copy_modified(last_known_value=LiteralType(
new_value,
fallback=fallback,
))


def _get_literal_value(expr: Expression, typ: Type) -> Optional[LiteralValue]:
# We can work with a literal type:
typ = get_proper_type(typ)
if isinstance(typ, LiteralType):
return typ.value
elif isinstance(typ, Instance) and typ.last_known_value:
return typ.last_known_value.value

# Or a literal node (`True` / `False` are already `Literal[True] | [False]`):
if isinstance(expr, (IntExpr, StrExpr)):
return expr.value

# It is not a literal:
return None
2 changes: 1 addition & 1 deletion test-data/unit/check-inference-context.test
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ class B: pass

[case testInferLambdaTypeUsingContext]
x : str = (lambda x: x + 1)(1) # E: Incompatible types in assignment (expression has type "int", variable has type "str")
reveal_type((lambda x, y: x + y)(1, 2)) # N: Revealed type is "builtins.int"
reveal_type((lambda x, y: x + y)(1, 2)) # N: Revealed type is "Literal[3]?"
(lambda x, y: x + y)(1, "") # E: Unsupported operand types for + ("int" and "str")
(lambda *, x, y: x + y)(x=1, y="") # E: Unsupported operand types for + ("int" and "str")
reveal_type((lambda s, i: s)(i=0, s='x')) # N: Revealed type is "Literal['x']?"
Expand Down
180 changes: 172 additions & 8 deletions test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -1745,19 +1745,19 @@ c: Literal[4]
d: Literal['foo']
e: str

reveal_type(a + a) # N: Revealed type is "builtins.int"
reveal_type(a + a) # N: Revealed type is "Literal[6]?"
reveal_type(a + b) # N: Revealed type is "builtins.int"
reveal_type(b + a) # N: Revealed type is "builtins.int"
reveal_type(a + 1) # N: Revealed type is "builtins.int"
reveal_type(1 + a) # N: Revealed type is "builtins.int"
reveal_type(a + c) # N: Revealed type is "builtins.int"
reveal_type(c + a) # N: Revealed type is "builtins.int"
reveal_type(a + 1) # N: Revealed type is "Literal[4]?"
reveal_type(1 + a) # N: Revealed type is "Literal[4]?"
reveal_type(a + c) # N: Revealed type is "Literal[7]?"
reveal_type(c + a) # N: Revealed type is "Literal[7]?"

reveal_type(d + d) # N: Revealed type is "builtins.str"
reveal_type(d + d) # N: Revealed type is "Literal['foofoo']?"
reveal_type(d + e) # N: Revealed type is "builtins.str"
reveal_type(e + d) # N: Revealed type is "builtins.str"
reveal_type(d + 'foo') # N: Revealed type is "builtins.str"
reveal_type('foo' + d) # N: Revealed type is "builtins.str"
reveal_type(d + 'foo') # N: Revealed type is "Literal['foofoo']?"
reveal_type('foo' + d) # N: Revealed type is "Literal['foofoo']?"

reveal_type(a.__add__(b)) # N: Revealed type is "builtins.int"
reveal_type(b.__add__(a)) # N: Revealed type is "builtins.int"
Expand Down Expand Up @@ -3346,3 +3346,167 @@ def incorrect_return2() -> Union[Tuple[Literal[True], int], Tuple[Literal[False]
else:
return (bool(), 'oops') # E: Incompatible return value type (got "Tuple[bool, str]", expected "Union[Tuple[Literal[True], int], Tuple[Literal[False], str]]")
[builtins fixtures/bool.pyi]


# Literal math
# ============

[case testLiteralIntMath]
from typing_extensions import Literal, Final

reveal_type(1 + 2) # N: Revealed type is "Literal[3]?"
reveal_type(2 + 1) # N: Revealed type is "Literal[3]?"

reveal_type(2 // 2) # N: Revealed type is "Literal[1]?"
reveal_type(5 // 2) # N: Revealed type is "Literal[2]?"

reveal_type(1 + 2 + 3) # N: Revealed type is "Literal[6]?"
reveal_type(2 + 2 * 2) # N: Revealed type is "Literal[6]?"
reveal_type(2 * 2 + 2) # N: Revealed type is "Literal[6]?"
reveal_type(100 - 2 + 1) # N: Revealed type is "Literal[99]?"

a: Literal[3]
b: Literal[4]
c: Final = 5

reveal_type(a + b) # N: Revealed type is "Literal[7]?"
reveal_type(a + 1 + b + 1) # N: Revealed type is "Literal[9]?"
reveal_type(1 + a + 1 + b) # N: Revealed type is "Literal[9]?"
reveal_type(a + c) # N: Revealed type is "Literal[8]?"
reveal_type(c + a) # N: Revealed type is "Literal[8]?"
reveal_type(c + c) # N: Revealed type is "Literal[10]?"

i: int

reveal_type(a + i) # N: Revealed type is "builtins.int"
reveal_type(i + a) # N: Revealed type is "builtins.int"
reveal_type(i * 2) # N: Revealed type is "builtins.int"
reveal_type(2 * i) # N: Revealed type is "builtins.int"
reveal_type(i // 2) # N: Revealed type is "builtins.int"
reveal_type(2 // i) # N: Revealed type is "builtins.int"
reveal_type(i - 2) # N: Revealed type is "builtins.int"
reveal_type(2 - i) # N: Revealed type is "builtins.int"
reveal_type(i - c) # N: Revealed type is "builtins.int"

# Corner cases:

reveal_type(9223372036854775807 + 9223372036854775807) # N: Revealed type is "Literal[18446744073709551614]?"
reveal_type(9223372036854775807 * 9223372036854775807) # N: Revealed type is "Literal[85070591730234615847396907784232501249]?"

reveal_type(1 // 0) # N: Revealed type is "builtins.int"
reveal_type(1 + 0) # N: Revealed type is "Literal[1]?"
[builtins fixtures/primitives.pyi]


[case testLiteralStrMath]
from typing_extensions import Literal, Final

reveal_type('a' + 'b') # N: Revealed type is "Literal['ab']?"
reveal_type('b' + 'a') # N: Revealed type is "Literal['ba']?"

a: Literal['a']
b: Literal['b']
c: Final = 'c'

reveal_type(a + '!') # N: Revealed type is "Literal['a!']?"
reveal_type('!' + a) # N: Revealed type is "Literal['!a']?"
reveal_type(a + b + c) # N: Revealed type is "Literal['abc']?"
reveal_type(c + b + a) # N: Revealed type is "Literal['cba']?"
reveal_type(a + '!' + b + '?' + c) # N: Revealed type is "Literal['a!b?c']?"
reveal_type(c + '1' + a + '2' + b) # N: Revealed type is "Literal['c1a2b']?"

s: str

reveal_type(s + 'a') # N: Revealed type is "builtins.str"
reveal_type('a' + s) # N: Revealed type is "builtins.str"
reveal_type(s + a) # N: Revealed type is "builtins.str"
reveal_type(a + s) # N: Revealed type is "builtins.str"
reveal_type(s + c) # N: Revealed type is "builtins.str"
reveal_type(c + s) # N: Revealed type is "builtins.str"

# Corner cases:

reveal_type('a' + '') # N: Revealed type is "Literal['a']?"
reveal_type(a + '') # N: Revealed type is "Literal['a']?"
reveal_type('' + '') # N: Revealed type is "Literal['']?"
[builtins fixtures/primitives.pyi]


[case testLiteralBytesMath]
from typing_extensions import Literal, Final

reveal_type(b'a' + b'b') # N: Revealed type is "Literal[b'ab']?"
reveal_type(b'b' + b'a') # N: Revealed type is "Literal[b'ba']?"

a: Literal[b'a']
b: Literal[b'b']
c: Final = b'c'

reveal_type(a + b'!') # N: Revealed type is "Literal[b'a!']?"
reveal_type(b'!' + a) # N: Revealed type is "Literal[b'!a']?"
reveal_type(a + b + c) # N: Revealed type is "Literal[b'abc']?"
reveal_type(c + b + a) # N: Revealed type is "Literal[b'cba']?"
reveal_type(a + b'!' + b + b'?' + c) # N: Revealed type is "Literal[b'a!b?c']?"
reveal_type(c + b'1' + a + b'2' + b) # N: Revealed type is "Literal[b'c1a2b']?"

s: bytes

reveal_type(s + b'a') # N: Revealed type is "builtins.bytes"
reveal_type(b'a' + s) # N: Revealed type is "builtins.bytes"
reveal_type(s + a) # N: Revealed type is "builtins.bytes"
reveal_type(a + s) # N: Revealed type is "builtins.bytes"
reveal_type(s + c) # N: Revealed type is "builtins.bytes"
reveal_type(c + s) # N: Revealed type is "builtins.bytes"

# Corner cases:

reveal_type(b'a' + b'') # N: Revealed type is "Literal[b'a']?"
reveal_type(a + b'') # N: Revealed type is "Literal[b'a']?"
reveal_type(b'' + b'') # N: Revealed type is "Literal[b'']?"
[builtins fixtures/primitives.pyi]


[case testLiteralBoolMath]
from typing_extensions import Literal, Final

reveal_type(True or False) # N: Revealed type is "Literal[True]?"
reveal_type(False or True) # N: Revealed type is "Literal[True]"
reveal_type(True or True) # N: Revealed type is "Literal[True]?"
reveal_type(False or False) # N: Revealed type is "Literal[False]"

reveal_type(True and False) # N: Revealed type is "Literal[False]"
reveal_type(False and True) # N: Revealed type is "Literal[False]?"
reveal_type(True and True) # N: Revealed type is "Literal[True]"
reveal_type(False and False) # N: Revealed type is "Literal[False]?"

reveal_type(True and False and True) # N: Revealed type is "Literal[False]"
reveal_type(True or False or False) # N: Revealed type is "Literal[True]?"

t: Literal[True]
f: Literal[False]
c: Final = True

reveal_type(t or False) # N: Revealed type is "Literal[True]"
reveal_type(False or t) # N: Revealed type is "Literal[True]"
reveal_type(t or f) # N: Revealed type is "Literal[True]"
reveal_type(t and c) # N: Revealed type is "Literal[True]"

b: bool

reveal_type(True or b) # N: Revealed type is "Literal[True]?"
reveal_type(b and True) # N: Revealed type is "builtins.bool"
reveal_type(b and False) # N: Revealed type is "Literal[False]"
[builtins fixtures/primitives.pyi]


[case testLiteralMathLoopContext]
def func1(loop_count: int):
x = 1
reveal_type(x) # N: Revealed type is "builtins.int"
x = x + 1
reveal_type(x) # N: Revealed type is "builtins.int"

for _ in [1, 2, 3]:
x = x + 1
reveal_type(x) # N: Revealed type is "builtins.int"
[builtins fixtures/primitives.pyi]
3 changes: 3 additions & 0 deletions test-data/unit/fixtures/primitives.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class int:
def __init__(self, x: object = ..., base: int = ...) -> None: pass
def __add__(self, i: int) -> int: pass
def __rmul__(self, x: int) -> int: pass
def __sub__(self, x: int) -> int: pass
def __floordiv__(self, x: int) -> int: pass
class float:
def __float__(self) -> float: pass
class complex: pass
Expand All @@ -28,6 +30,7 @@ class str(Sequence[str]):
def __getitem__(self, item: int) -> str: pass
def format(self, *args, **kwargs) -> str: pass
class bytes(Sequence[int]):
def __add__(self, s: bytes) -> bytes: pass
def __iter__(self) -> Iterator[int]: pass
def __contains__(self, other: object) -> bool: pass
def __getitem__(self, item: int) -> int: pass
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/typexport-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class float:
class type: pass
class str: pass
[out]
OpExpr(3) : builtins.int
OpExpr(3) : Literal[3]?
OpExpr(4) : builtins.float
OpExpr(5) : builtins.float
OpExpr(6) : builtins.float
Expand Down