From dc964065593a1adbe3368a6c146dc154dcfd5121 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 6 Jul 2021 14:04:04 +0100 Subject: [PATCH] Support new union syntax in stubs always in runtime context Previously it only worked when the target Python version was 3.10. Work on #9880. --- mypy/checker.py | 10 +++------ mypy/exprtotype.py | 27 ++++++++++++++--------- mypy/plugin.py | 5 +++++ mypy/plugins/attrs.py | 2 +- mypy/semanal.py | 23 ++++++++++--------- mypy/semanal_namedtuple.py | 2 +- mypy/semanal_newtype.py | 2 +- mypy/semanal_typeddict.py | 3 ++- mypy/typeanal.py | 14 ++++++------ test-data/unit/check-union-or-syntax.test | 11 +++++++++ 10 files changed, 60 insertions(+), 39 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 81ca24900aa4..7fe6860b4274 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2023,13 +2023,9 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: Handle all kinds of assignment statements (simple, indexed, multiple). """ - with self.enter_final_context(s.is_final_def): - self.check_assignment(s.lvalues[-1], s.rvalue, s.type is None, s.new_syntax) - - if s.is_alias_def: - # We do this mostly for compatibility with old semantic analyzer. - # TODO: should we get rid of this? - self.store_type(s.lvalues[-1], self.expr_checker.accept(s.rvalue)) + if not (s.is_alias_def and self.is_stub): + with self.enter_final_context(s.is_final_def): + self.check_assignment(s.lvalues[-1], s.rvalue, s.type is None, s.new_syntax) if (s.type is not None and self.options.disallow_any_unimported and diff --git a/mypy/exprtotype.py b/mypy/exprtotype.py index 685e26e35b70..8f6f6c11f346 100644 --- a/mypy/exprtotype.py +++ b/mypy/exprtotype.py @@ -32,11 +32,15 @@ def _extract_argument_name(expr: Expression) -> Optional[str]: def expr_to_unanalyzed_type(expr: Expression, options: Optional[Options] = None, + allow_new_syntax: bool = False, _parent: Optional[Expression] = None) -> ProperType: """Translate an expression to the corresponding type. The result is not semantically analyzed. It can be UnboundType or TypeList. Raise TypeTranslationError if the expression cannot represent a type. + + If allow_new_syntax is True, allow all type syntax independent of the target + Python version (used in stubs). """ # The `parent` parameter is used in recursive calls to provide context for # understanding whether an CallableArgument is ok. @@ -56,7 +60,7 @@ def expr_to_unanalyzed_type(expr: Expression, else: raise TypeTranslationError() elif isinstance(expr, IndexExpr): - base = expr_to_unanalyzed_type(expr.base, options, expr) + base = expr_to_unanalyzed_type(expr.base, options, allow_new_syntax, expr) if isinstance(base, UnboundType): if base.args: raise TypeTranslationError() @@ -72,9 +76,10 @@ def expr_to_unanalyzed_type(expr: Expression, # of the Annotation definition and only returning the type information, # losing all the annotations. - return expr_to_unanalyzed_type(args[0], options, expr) + return expr_to_unanalyzed_type(args[0], options, allow_new_syntax, expr) else: - base.args = tuple(expr_to_unanalyzed_type(arg, options, expr) for arg in args) + base.args = tuple(expr_to_unanalyzed_type(arg, options, allow_new_syntax, expr) + for arg in args) if not base.args: base.empty_tuple_index = True return base @@ -82,10 +87,9 @@ def expr_to_unanalyzed_type(expr: Expression, raise TypeTranslationError() elif (isinstance(expr, OpExpr) and expr.op == '|' - and options - and options.python_version >= (3, 10)): - return UnionType([expr_to_unanalyzed_type(expr.left, options), - expr_to_unanalyzed_type(expr.right, options)]) + and ((options and options.python_version >= (3, 10)) or allow_new_syntax)): + return UnionType([expr_to_unanalyzed_type(expr.left, options, allow_new_syntax), + expr_to_unanalyzed_type(expr.right, options, allow_new_syntax)]) elif isinstance(expr, CallExpr) and isinstance(_parent, ListExpr): c = expr.callee names = [] @@ -118,19 +122,20 @@ def expr_to_unanalyzed_type(expr: Expression, if typ is not default_type: # Two types raise TypeTranslationError() - typ = expr_to_unanalyzed_type(arg, options, expr) + typ = expr_to_unanalyzed_type(arg, options, allow_new_syntax, expr) continue else: raise TypeTranslationError() elif i == 0: - typ = expr_to_unanalyzed_type(arg, options, expr) + typ = expr_to_unanalyzed_type(arg, options, allow_new_syntax, expr) elif i == 1: name = _extract_argument_name(arg) else: raise TypeTranslationError() return CallableArgument(typ, name, arg_const, expr.line, expr.column) elif isinstance(expr, ListExpr): - return TypeList([expr_to_unanalyzed_type(t, options, expr) for t in expr.items], + return TypeList([expr_to_unanalyzed_type(t, options, allow_new_syntax, expr) + for t in expr.items], line=expr.line, column=expr.column) elif isinstance(expr, StrExpr): return parse_type_string(expr.value, 'builtins.str', expr.line, expr.column, @@ -142,7 +147,7 @@ def expr_to_unanalyzed_type(expr: Expression, return parse_type_string(expr.value, 'builtins.unicode', expr.line, expr.column, assume_str_is_unicode=True) elif isinstance(expr, UnaryExpr): - typ = expr_to_unanalyzed_type(expr.expr, options) + typ = expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax) if isinstance(typ, RawExpressionType): if isinstance(typ.literal_value, int) and expr.op == '-': typ.literal_value *= -1 diff --git a/mypy/plugin.py b/mypy/plugin.py index 0f38bb32eeea..4efa350cdcba 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -357,6 +357,11 @@ def final_iteration(self) -> bool: """Is this the final iteration of semantic analysis?""" raise NotImplementedError + @property + @abstractmethod + def is_stub_file(self) -> bool: + raise NotImplementedError + # A context for querying for configuration data about a module for # cache invalidation purposes. diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index c786ca7a8ce2..3187f13aeafa 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -552,7 +552,7 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext', type_arg = _get_argument(rvalue, 'type') if type_arg and not init_type: try: - un_type = expr_to_unanalyzed_type(type_arg, ctx.api.options) + un_type = expr_to_unanalyzed_type(type_arg, ctx.api.options, ctx.api.is_stub_file) except TypeTranslationError: ctx.api.fail('Invalid argument to type', type_arg) else: diff --git a/mypy/semanal.py b/mypy/semanal.py index 8f479ec93304..3cfb14ea447e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1267,7 +1267,7 @@ class Foo(Bar, Generic[T]): ... self.analyze_type_expr(base_expr) try: - base = expr_to_unanalyzed_type(base_expr, self.options) + base = self.expr_to_unanalyzed_type(base_expr) except TypeTranslationError: # This error will be caught later. continue @@ -1373,7 +1373,7 @@ def get_all_bases_tvars(self, for i, base_expr in enumerate(base_type_exprs): if i not in removed: try: - base = expr_to_unanalyzed_type(base_expr, self.options) + base = self.expr_to_unanalyzed_type(base_expr) except TypeTranslationError: # This error will be caught later. continue @@ -2507,7 +2507,7 @@ def analyze_alias(self, rvalue: Expression, self.plugin, self.options, self.is_typeshed_stub_file, - allow_unnormalized=self.is_stub_file, + allow_new_syntax=self.is_stub_file, allow_placeholder=allow_placeholder, in_dynamic_func=dynamic, global_scope=global_scope) @@ -3202,7 +3202,7 @@ def analyze_value_types(self, items: List[Expression]) -> List[Type]: result: List[Type] = [] for node in items: try: - analyzed = self.anal_type(expr_to_unanalyzed_type(node, self.options), + analyzed = self.anal_type(self.expr_to_unanalyzed_type(node), allow_placeholder=True) if analyzed is None: # Type variables are special: we need to place them in the symbol table @@ -3645,7 +3645,7 @@ def visit_call_expr(self, expr: CallExpr) -> None: return # Translate first argument to an unanalyzed type. try: - target = expr_to_unanalyzed_type(expr.args[0], self.options) + target = self.expr_to_unanalyzed_type(expr.args[0]) except TypeTranslationError: self.fail('Cast target is not a type', expr) return @@ -3703,7 +3703,7 @@ def visit_call_expr(self, expr: CallExpr) -> None: return # Translate first argument to an unanalyzed type. try: - target = expr_to_unanalyzed_type(expr.args[0], self.options) + target = self.expr_to_unanalyzed_type(expr.args[0]) except TypeTranslationError: self.fail('Argument 1 to _promote is not a type', expr) return @@ -3899,7 +3899,7 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]] items = [index] for item in items: try: - typearg = expr_to_unanalyzed_type(item, self.options) + typearg = self.expr_to_unanalyzed_type(item) except TypeTranslationError: self.fail('Type expected within [...]', expr) return None @@ -4206,7 +4206,7 @@ def lookup_qualified(self, name: str, ctx: Context, def lookup_type_node(self, expr: Expression) -> Optional[SymbolTableNode]: try: - t = expr_to_unanalyzed_type(expr, self.options) + t = self.expr_to_unanalyzed_type(expr) except TypeTranslationError: return None if isinstance(t, UnboundType): @@ -4926,7 +4926,7 @@ def expr_to_analyzed_type(self, assert info.tuple_type, "NamedTuple without tuple type" fallback = Instance(info, []) return TupleType(info.tuple_type.items, fallback=fallback) - typ = expr_to_unanalyzed_type(expr, self.options) + typ = self.expr_to_unanalyzed_type(expr) return self.anal_type(typ, report_invalid_types=report_invalid_types, allow_placeholder=allow_placeholder) @@ -4956,12 +4956,15 @@ def type_analyzer(self, *, allow_unbound_tvars=allow_unbound_tvars, allow_tuple_literal=allow_tuple_literal, report_invalid_types=report_invalid_types, - allow_unnormalized=self.is_stub_file, + allow_new_syntax=self.is_stub_file, allow_placeholder=allow_placeholder) tpan.in_dynamic_func = bool(self.function_stack and self.function_stack[-1].is_dynamic()) tpan.global_scope = not self.type and not self.function_stack return tpan + def expr_to_unanalyzed_type(self, node: Expression) -> ProperType: + return expr_to_unanalyzed_type(node, self.options, self.is_stub_file) + def anal_type(self, typ: Type, *, tvar_scope: Optional[TypeVarLikeScope] = None, diff --git a/mypy/semanal_namedtuple.py b/mypy/semanal_namedtuple.py index 2382ae633d93..ca2e31d627e5 100644 --- a/mypy/semanal_namedtuple.py +++ b/mypy/semanal_namedtuple.py @@ -356,7 +356,7 @@ def parse_namedtuple_fields_with_types(self, nodes: List[Expression], context: C self.fail("Invalid NamedTuple() field name", item) return None try: - type = expr_to_unanalyzed_type(type_node, self.options) + type = expr_to_unanalyzed_type(type_node, self.options, self.api.is_stub_file) except TypeTranslationError: self.fail('Invalid field type', type_node) return None diff --git a/mypy/semanal_newtype.py b/mypy/semanal_newtype.py index 0360cbb86dab..4d5077dbfe43 100644 --- a/mypy/semanal_newtype.py +++ b/mypy/semanal_newtype.py @@ -160,7 +160,7 @@ def check_newtype_args(self, name: str, call: CallExpr, # Check second argument msg = "Argument 2 to NewType(...) must be a valid type" try: - unanalyzed_type = expr_to_unanalyzed_type(args[1], self.options) + unanalyzed_type = expr_to_unanalyzed_type(args[1], self.options, self.api.is_stub_file) except TypeTranslationError: self.fail(msg, context) return None, False diff --git a/mypy/semanal_typeddict.py b/mypy/semanal_typeddict.py index 3ab4e6d698d5..f70bbe427124 100644 --- a/mypy/semanal_typeddict.py +++ b/mypy/semanal_typeddict.py @@ -290,7 +290,8 @@ def parse_typeddict_fields_with_types( self.fail_typeddict_arg("Invalid TypedDict() field name", name_context) return [], [], False try: - type = expr_to_unanalyzed_type(field_type_expr, self.options) + type = expr_to_unanalyzed_type(field_type_expr, self.options, + self.api.is_stub_file) except TypeTranslationError: self.fail_typeddict_arg('Invalid field type', field_type_expr) return [], [], False diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 45a9960f6dc2..5e38607539fb 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -69,7 +69,7 @@ def analyze_type_alias(node: Expression, plugin: Plugin, options: Options, is_typeshed_stub: bool, - allow_unnormalized: bool = False, + allow_new_syntax: bool = False, allow_placeholder: bool = False, in_dynamic_func: bool = False, global_scope: bool = True) -> Optional[Tuple[Type, Set[str]]]: @@ -80,12 +80,12 @@ def analyze_type_alias(node: Expression, Return None otherwise. 'node' must have been semantically analyzed. """ try: - type = expr_to_unanalyzed_type(node, options) + type = expr_to_unanalyzed_type(node, options, allow_new_syntax) except TypeTranslationError: api.fail('Invalid type alias: expression is not a valid type', node) return None analyzer = TypeAnalyser(api, tvar_scope, plugin, options, is_typeshed_stub, - allow_unnormalized=allow_unnormalized, defining_alias=True, + allow_new_syntax=allow_new_syntax, defining_alias=True, allow_placeholder=allow_placeholder) analyzer.in_dynamic_func = in_dynamic_func analyzer.global_scope = global_scope @@ -126,7 +126,7 @@ def __init__(self, is_typeshed_stub: bool, *, defining_alias: bool = False, allow_tuple_literal: bool = False, - allow_unnormalized: bool = False, + allow_new_syntax: bool = False, allow_unbound_tvars: bool = False, allow_placeholder: bool = False, report_invalid_types: bool = True) -> None: @@ -143,7 +143,7 @@ def __init__(self, self.nesting_level = 0 # Should we allow unnormalized types like `list[int]` # (currently allowed in stubs)? - self.allow_unnormalized = allow_unnormalized + self.allow_new_syntax = allow_new_syntax # Should we accept unbound type variables (always OK in aliases)? self.allow_unbound_tvars = allow_unbound_tvars or defining_alias # If false, record incomplete ref if we generate PlaceholderType. @@ -199,7 +199,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) return hook(AnalyzeTypeContext(t, t, self)) if (fullname in get_nongen_builtins(self.options.python_version) and t.args and - not self.allow_unnormalized and + not self.allow_new_syntax and not self.api.is_future_flag_set("annotations")): self.fail(no_subscript_builtin_alias(fullname, propose_alt=not self.defining_alias), t) @@ -282,7 +282,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt elif (fullname == 'typing.Tuple' or (fullname == 'builtins.tuple' and (self.options.python_version >= (3, 9) or self.api.is_future_flag_set('annotations') or - self.allow_unnormalized))): + self.allow_new_syntax))): # Tuple is special because it is involved in builtin import cycle # and may be not ready when used. sym = self.api.lookup_fully_qualified_or_none('builtins.tuple') diff --git a/test-data/unit/check-union-or-syntax.test b/test-data/unit/check-union-or-syntax.test index 07edc6bd8d6c..b9c6d5ca6d32 100644 --- a/test-data/unit/check-union-or-syntax.test +++ b/test-data/unit/check-union-or-syntax.test @@ -155,3 +155,14 @@ def f() -> object: pass reveal_type(cast(str | None, f())) # N: Revealed type is "Union[builtins.str, None]" reveal_type(list[str | None]()) # N: Revealed type is "builtins.list[Union[builtins.str, None]]" [builtins fixtures/type.pyi] + +[case testUnionOrSyntaxRuntimeContextInStubFile] +import lib +reveal_type(lib.x) # N: Revealed type is "Union[builtins.int, builtins.list[builtins.str], None]" + +[file lib.pyi] +A = int | list[str] | None +x: A +class C(list[int | None]): + pass +[builtins fixtures/list.pyi]