Skip to content

Commit

Permalink
Support new union syntax in stubs always in runtime context
Browse files Browse the repository at this point in the history
Previously it only worked when the target Python version was 3.10.

Work on #9880.
  • Loading branch information
JukkaL committed Jul 6, 2021
1 parent 56618b9 commit dc96406
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 39 deletions.
10 changes: 3 additions & 7 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2023,13 +2023,9 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
Handle all kinds of assignment statements (simple, indexed, multiple).
"""
with self.enter_final_context(s.is_final_def):
self.check_assignment(s.lvalues[-1], s.rvalue, s.type is None, s.new_syntax)

if s.is_alias_def:
# We do this mostly for compatibility with old semantic analyzer.
# TODO: should we get rid of this?
self.store_type(s.lvalues[-1], self.expr_checker.accept(s.rvalue))
if not (s.is_alias_def and self.is_stub):
with self.enter_final_context(s.is_final_def):
self.check_assignment(s.lvalues[-1], s.rvalue, s.type is None, s.new_syntax)

if (s.type is not None and
self.options.disallow_any_unimported and
Expand Down
27 changes: 16 additions & 11 deletions mypy/exprtotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,15 @@ def _extract_argument_name(expr: Expression) -> Optional[str]:

def expr_to_unanalyzed_type(expr: Expression,
options: Optional[Options] = None,
allow_new_syntax: bool = False,
_parent: Optional[Expression] = None) -> ProperType:
"""Translate an expression to the corresponding type.
The result is not semantically analyzed. It can be UnboundType or TypeList.
Raise TypeTranslationError if the expression cannot represent a type.
If allow_new_syntax is True, allow all type syntax independent of the target
Python version (used in stubs).
"""
# The `parent` parameter is used in recursive calls to provide context for
# understanding whether an CallableArgument is ok.
Expand All @@ -56,7 +60,7 @@ def expr_to_unanalyzed_type(expr: Expression,
else:
raise TypeTranslationError()
elif isinstance(expr, IndexExpr):
base = expr_to_unanalyzed_type(expr.base, options, expr)
base = expr_to_unanalyzed_type(expr.base, options, allow_new_syntax, expr)
if isinstance(base, UnboundType):
if base.args:
raise TypeTranslationError()
Expand All @@ -72,20 +76,20 @@ def expr_to_unanalyzed_type(expr: Expression,
# of the Annotation definition and only returning the type information,
# losing all the annotations.

return expr_to_unanalyzed_type(args[0], options, expr)
return expr_to_unanalyzed_type(args[0], options, allow_new_syntax, expr)
else:
base.args = tuple(expr_to_unanalyzed_type(arg, options, expr) for arg in args)
base.args = tuple(expr_to_unanalyzed_type(arg, options, allow_new_syntax, expr)
for arg in args)
if not base.args:
base.empty_tuple_index = True
return base
else:
raise TypeTranslationError()
elif (isinstance(expr, OpExpr)
and expr.op == '|'
and options
and options.python_version >= (3, 10)):
return UnionType([expr_to_unanalyzed_type(expr.left, options),
expr_to_unanalyzed_type(expr.right, options)])
and ((options and options.python_version >= (3, 10)) or allow_new_syntax)):
return UnionType([expr_to_unanalyzed_type(expr.left, options, allow_new_syntax),
expr_to_unanalyzed_type(expr.right, options, allow_new_syntax)])
elif isinstance(expr, CallExpr) and isinstance(_parent, ListExpr):
c = expr.callee
names = []
Expand Down Expand Up @@ -118,19 +122,20 @@ def expr_to_unanalyzed_type(expr: Expression,
if typ is not default_type:
# Two types
raise TypeTranslationError()
typ = expr_to_unanalyzed_type(arg, options, expr)
typ = expr_to_unanalyzed_type(arg, options, allow_new_syntax, expr)
continue
else:
raise TypeTranslationError()
elif i == 0:
typ = expr_to_unanalyzed_type(arg, options, expr)
typ = expr_to_unanalyzed_type(arg, options, allow_new_syntax, expr)
elif i == 1:
name = _extract_argument_name(arg)
else:
raise TypeTranslationError()
return CallableArgument(typ, name, arg_const, expr.line, expr.column)
elif isinstance(expr, ListExpr):
return TypeList([expr_to_unanalyzed_type(t, options, expr) for t in expr.items],
return TypeList([expr_to_unanalyzed_type(t, options, allow_new_syntax, expr)
for t in expr.items],
line=expr.line, column=expr.column)
elif isinstance(expr, StrExpr):
return parse_type_string(expr.value, 'builtins.str', expr.line, expr.column,
Expand All @@ -142,7 +147,7 @@ def expr_to_unanalyzed_type(expr: Expression,
return parse_type_string(expr.value, 'builtins.unicode', expr.line, expr.column,
assume_str_is_unicode=True)
elif isinstance(expr, UnaryExpr):
typ = expr_to_unanalyzed_type(expr.expr, options)
typ = expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax)
if isinstance(typ, RawExpressionType):
if isinstance(typ.literal_value, int) and expr.op == '-':
typ.literal_value *= -1
Expand Down
5 changes: 5 additions & 0 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,11 @@ def final_iteration(self) -> bool:
"""Is this the final iteration of semantic analysis?"""
raise NotImplementedError

@property
@abstractmethod
def is_stub_file(self) -> bool:
raise NotImplementedError


# A context for querying for configuration data about a module for
# cache invalidation purposes.
Expand Down
2 changes: 1 addition & 1 deletion mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext',
type_arg = _get_argument(rvalue, 'type')
if type_arg and not init_type:
try:
un_type = expr_to_unanalyzed_type(type_arg, ctx.api.options)
un_type = expr_to_unanalyzed_type(type_arg, ctx.api.options, ctx.api.is_stub_file)
except TypeTranslationError:
ctx.api.fail('Invalid argument to type', type_arg)
else:
Expand Down
23 changes: 13 additions & 10 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,7 +1267,7 @@ class Foo(Bar, Generic[T]): ...
self.analyze_type_expr(base_expr)

try:
base = expr_to_unanalyzed_type(base_expr, self.options)
base = self.expr_to_unanalyzed_type(base_expr)
except TypeTranslationError:
# This error will be caught later.
continue
Expand Down Expand Up @@ -1373,7 +1373,7 @@ def get_all_bases_tvars(self,
for i, base_expr in enumerate(base_type_exprs):
if i not in removed:
try:
base = expr_to_unanalyzed_type(base_expr, self.options)
base = self.expr_to_unanalyzed_type(base_expr)
except TypeTranslationError:
# This error will be caught later.
continue
Expand Down Expand Up @@ -2507,7 +2507,7 @@ def analyze_alias(self, rvalue: Expression,
self.plugin,
self.options,
self.is_typeshed_stub_file,
allow_unnormalized=self.is_stub_file,
allow_new_syntax=self.is_stub_file,
allow_placeholder=allow_placeholder,
in_dynamic_func=dynamic,
global_scope=global_scope)
Expand Down Expand Up @@ -3202,7 +3202,7 @@ def analyze_value_types(self, items: List[Expression]) -> List[Type]:
result: List[Type] = []
for node in items:
try:
analyzed = self.anal_type(expr_to_unanalyzed_type(node, self.options),
analyzed = self.anal_type(self.expr_to_unanalyzed_type(node),
allow_placeholder=True)
if analyzed is None:
# Type variables are special: we need to place them in the symbol table
Expand Down Expand Up @@ -3645,7 +3645,7 @@ def visit_call_expr(self, expr: CallExpr) -> None:
return
# Translate first argument to an unanalyzed type.
try:
target = expr_to_unanalyzed_type(expr.args[0], self.options)
target = self.expr_to_unanalyzed_type(expr.args[0])
except TypeTranslationError:
self.fail('Cast target is not a type', expr)
return
Expand Down Expand Up @@ -3703,7 +3703,7 @@ def visit_call_expr(self, expr: CallExpr) -> None:
return
# Translate first argument to an unanalyzed type.
try:
target = expr_to_unanalyzed_type(expr.args[0], self.options)
target = self.expr_to_unanalyzed_type(expr.args[0])
except TypeTranslationError:
self.fail('Argument 1 to _promote is not a type', expr)
return
Expand Down Expand Up @@ -3899,7 +3899,7 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]]
items = [index]
for item in items:
try:
typearg = expr_to_unanalyzed_type(item, self.options)
typearg = self.expr_to_unanalyzed_type(item)
except TypeTranslationError:
self.fail('Type expected within [...]', expr)
return None
Expand Down Expand Up @@ -4206,7 +4206,7 @@ def lookup_qualified(self, name: str, ctx: Context,

def lookup_type_node(self, expr: Expression) -> Optional[SymbolTableNode]:
try:
t = expr_to_unanalyzed_type(expr, self.options)
t = self.expr_to_unanalyzed_type(expr)
except TypeTranslationError:
return None
if isinstance(t, UnboundType):
Expand Down Expand Up @@ -4926,7 +4926,7 @@ def expr_to_analyzed_type(self,
assert info.tuple_type, "NamedTuple without tuple type"
fallback = Instance(info, [])
return TupleType(info.tuple_type.items, fallback=fallback)
typ = expr_to_unanalyzed_type(expr, self.options)
typ = self.expr_to_unanalyzed_type(expr)
return self.anal_type(typ, report_invalid_types=report_invalid_types,
allow_placeholder=allow_placeholder)

Expand Down Expand Up @@ -4956,12 +4956,15 @@ def type_analyzer(self, *,
allow_unbound_tvars=allow_unbound_tvars,
allow_tuple_literal=allow_tuple_literal,
report_invalid_types=report_invalid_types,
allow_unnormalized=self.is_stub_file,
allow_new_syntax=self.is_stub_file,
allow_placeholder=allow_placeholder)
tpan.in_dynamic_func = bool(self.function_stack and self.function_stack[-1].is_dynamic())
tpan.global_scope = not self.type and not self.function_stack
return tpan

def expr_to_unanalyzed_type(self, node: Expression) -> ProperType:
return expr_to_unanalyzed_type(node, self.options, self.is_stub_file)

def anal_type(self,
typ: Type, *,
tvar_scope: Optional[TypeVarLikeScope] = None,
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal_namedtuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def parse_namedtuple_fields_with_types(self, nodes: List[Expression], context: C
self.fail("Invalid NamedTuple() field name", item)
return None
try:
type = expr_to_unanalyzed_type(type_node, self.options)
type = expr_to_unanalyzed_type(type_node, self.options, self.api.is_stub_file)
except TypeTranslationError:
self.fail('Invalid field type', type_node)
return None
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal_newtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def check_newtype_args(self, name: str, call: CallExpr,
# Check second argument
msg = "Argument 2 to NewType(...) must be a valid type"
try:
unanalyzed_type = expr_to_unanalyzed_type(args[1], self.options)
unanalyzed_type = expr_to_unanalyzed_type(args[1], self.options, self.api.is_stub_file)
except TypeTranslationError:
self.fail(msg, context)
return None, False
Expand Down
3 changes: 2 additions & 1 deletion mypy/semanal_typeddict.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ def parse_typeddict_fields_with_types(
self.fail_typeddict_arg("Invalid TypedDict() field name", name_context)
return [], [], False
try:
type = expr_to_unanalyzed_type(field_type_expr, self.options)
type = expr_to_unanalyzed_type(field_type_expr, self.options,
self.api.is_stub_file)
except TypeTranslationError:
self.fail_typeddict_arg('Invalid field type', field_type_expr)
return [], [], False
Expand Down
14 changes: 7 additions & 7 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def analyze_type_alias(node: Expression,
plugin: Plugin,
options: Options,
is_typeshed_stub: bool,
allow_unnormalized: bool = False,
allow_new_syntax: bool = False,
allow_placeholder: bool = False,
in_dynamic_func: bool = False,
global_scope: bool = True) -> Optional[Tuple[Type, Set[str]]]:
Expand All @@ -80,12 +80,12 @@ def analyze_type_alias(node: Expression,
Return None otherwise. 'node' must have been semantically analyzed.
"""
try:
type = expr_to_unanalyzed_type(node, options)
type = expr_to_unanalyzed_type(node, options, allow_new_syntax)
except TypeTranslationError:
api.fail('Invalid type alias: expression is not a valid type', node)
return None
analyzer = TypeAnalyser(api, tvar_scope, plugin, options, is_typeshed_stub,
allow_unnormalized=allow_unnormalized, defining_alias=True,
allow_new_syntax=allow_new_syntax, defining_alias=True,
allow_placeholder=allow_placeholder)
analyzer.in_dynamic_func = in_dynamic_func
analyzer.global_scope = global_scope
Expand Down Expand Up @@ -126,7 +126,7 @@ def __init__(self,
is_typeshed_stub: bool, *,
defining_alias: bool = False,
allow_tuple_literal: bool = False,
allow_unnormalized: bool = False,
allow_new_syntax: bool = False,
allow_unbound_tvars: bool = False,
allow_placeholder: bool = False,
report_invalid_types: bool = True) -> None:
Expand All @@ -143,7 +143,7 @@ def __init__(self,
self.nesting_level = 0
# Should we allow unnormalized types like `list[int]`
# (currently allowed in stubs)?
self.allow_unnormalized = allow_unnormalized
self.allow_new_syntax = allow_new_syntax
# Should we accept unbound type variables (always OK in aliases)?
self.allow_unbound_tvars = allow_unbound_tvars or defining_alias
# If false, record incomplete ref if we generate PlaceholderType.
Expand Down Expand Up @@ -199,7 +199,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool)
return hook(AnalyzeTypeContext(t, t, self))
if (fullname in get_nongen_builtins(self.options.python_version)
and t.args and
not self.allow_unnormalized and
not self.allow_new_syntax and
not self.api.is_future_flag_set("annotations")):
self.fail(no_subscript_builtin_alias(fullname,
propose_alt=not self.defining_alias), t)
Expand Down Expand Up @@ -282,7 +282,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt
elif (fullname == 'typing.Tuple' or
(fullname == 'builtins.tuple' and (self.options.python_version >= (3, 9) or
self.api.is_future_flag_set('annotations') or
self.allow_unnormalized))):
self.allow_new_syntax))):
# Tuple is special because it is involved in builtin import cycle
# and may be not ready when used.
sym = self.api.lookup_fully_qualified_or_none('builtins.tuple')
Expand Down
11 changes: 11 additions & 0 deletions test-data/unit/check-union-or-syntax.test
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,14 @@ def f() -> object: pass
reveal_type(cast(str | None, f())) # N: Revealed type is "Union[builtins.str, None]"
reveal_type(list[str | None]()) # N: Revealed type is "builtins.list[Union[builtins.str, None]]"
[builtins fixtures/type.pyi]

[case testUnionOrSyntaxRuntimeContextInStubFile]
import lib
reveal_type(lib.x) # N: Revealed type is "Union[builtins.int, builtins.list[builtins.str], None]"

[file lib.pyi]
A = int | list[str] | None
x: A
class C(list[int | None]):
pass
[builtins fixtures/list.pyi]

0 comments on commit dc96406

Please sign in to comment.