From 83695e0e6b1c997f2dd4e6fe4d1287c8b8da4c3e Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Tue, 20 Aug 2024 11:35:05 -0400 Subject: [PATCH 1/3] Use equality when relating formal and expected type in arg checking --- .../rustc_hir_typeck/src/fn_ctxt/checks.rs | 9 ++++----- .../coercion/constrain-expectation-in-arg.rs | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) create mode 100644 tests/ui/coercion/constrain-expectation-in-arg.rs diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs index aca29d4758708..3e2cec7260061 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs @@ -291,21 +291,20 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let coerce_error = self.coerce(provided_arg, checked_ty, coerced_ty, AllowTwoPhase::Yes, None).err(); - if coerce_error.is_some() { return Compatibility::Incompatible(coerce_error); } - // 3. Check if the formal type is a supertype of the checked one - // and register any such obligations for future type checks - let supertype_error = self.at(&self.misc(provided_arg.span), self.param_env).sup( + // 3. Check if the formal type is actually equal to the checked one + // and register any such obligations for future type checks. + let formal_ty_error = self.at(&self.misc(provided_arg.span), self.param_env).eq( DefineOpaqueTypes::Yes, formal_input_ty, coerced_ty, ); // If neither check failed, the types are compatible - match supertype_error { + match formal_ty_error { Ok(InferOk { obligations, value: () }) => { self.register_predicates(obligations); Compatibility::Compatible diff --git a/tests/ui/coercion/constrain-expectation-in-arg.rs b/tests/ui/coercion/constrain-expectation-in-arg.rs new file mode 100644 index 0000000000000..858c3a0bdb572 --- /dev/null +++ b/tests/ui/coercion/constrain-expectation-in-arg.rs @@ -0,0 +1,19 @@ +//@ check-pass + +trait Trait { + type Item; +} + +struct Struct, B> { + pub field: A, +} + +fn identity(x: T) -> T { + x +} + +fn test, B>(x: &Struct) { + let x: &Struct<_, _> = identity(x); +} + +fn main() {} From 7246d31127412461b8e73fee60ed5d466ff36a1c Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sun, 25 Aug 2024 11:40:47 -0400 Subject: [PATCH 2/3] Inline expected_inputs_for_expected_output into check_argument_types --- compiler/rustc_hir_typeck/src/callee.rs | 21 ++----- .../rustc_hir_typeck/src/fn_ctxt/_impl.rs | 36 ------------ .../rustc_hir_typeck/src/fn_ctxt/checks.rs | 57 ++++++++++++++----- 3 files changed, 48 insertions(+), 66 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/callee.rs b/compiler/rustc_hir_typeck/src/callee.rs index 44cb08e44eb81..a073c1d97cf5d 100644 --- a/compiler/rustc_hir_typeck/src/callee.rs +++ b/compiler/rustc_hir_typeck/src/callee.rs @@ -502,18 +502,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let fn_sig = self.instantiate_binder_with_fresh_vars(call_expr.span, infer::FnCall, fn_sig); let fn_sig = self.normalize(call_expr.span, fn_sig); - // Call the generic checker. - let expected_arg_tys = self.expected_inputs_for_expected_output( - call_expr.span, - expected, - fn_sig.output(), - fn_sig.inputs(), - ); self.check_argument_types( call_expr.span, call_expr, fn_sig.inputs(), - expected_arg_tys, + fn_sig.output(), + expected, arg_exprs, fn_sig.c_variadic, TupleArgumentsFlag::DontTupleArguments, @@ -865,19 +859,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // don't know the full details yet (`Fn` vs `FnMut` etc), but we // do know the types expected for each argument and the return // type. - - let expected_arg_tys = self.expected_inputs_for_expected_output( - call_expr.span, - expected, - fn_sig.output(), - fn_sig.inputs(), - ); - self.check_argument_types( call_expr.span, call_expr, fn_sig.inputs(), - expected_arg_tys, + fn_sig.output(), + expected, arg_exprs, fn_sig.c_variadic, TupleArgumentsFlag::TupleArguments, diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs index b169f75796b3a..c0bb3434982fd 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs @@ -688,42 +688,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { vec![ty_error; len] } - /// Unifies the output type with the expected type early, for more coercions - /// and forward type information on the input expressions. - #[instrument(skip(self, call_span), level = "debug")] - pub(crate) fn expected_inputs_for_expected_output( - &self, - call_span: Span, - expected_ret: Expectation<'tcx>, - formal_ret: Ty<'tcx>, - formal_args: &[Ty<'tcx>], - ) -> Option>> { - let formal_ret = self.resolve_vars_with_obligations(formal_ret); - let ret_ty = expected_ret.only_has_type(self)?; - - let expect_args = self - .fudge_inference_if_ok(|| { - let ocx = ObligationCtxt::new(self); - - // Attempt to apply a subtyping relationship between the formal - // return type (likely containing type variables if the function - // is polymorphic) and the expected return type. - // No argument expectations are produced if unification fails. - let origin = self.misc(call_span); - ocx.sup(&origin, self.param_env, ret_ty, formal_ret)?; - if !ocx.select_where_possible().is_empty() { - return Err(TypeError::Mismatch); - } - - // Record all the argument types, with the args - // produced from the above subtyping unification. - Ok(Some(formal_args.iter().map(|&ty| self.resolve_vars_if_possible(ty)).collect())) - }) - .unwrap_or_default(); - debug!(?formal_args, ?formal_ret, ?expect_args, ?expected_ret); - expect_args - } - pub(crate) fn resolve_lang_item_path( &self, lang_item: hir::LangItem, diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs index 3e2cec7260061..b9a07e2ff08ca 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs @@ -17,6 +17,7 @@ use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer; use rustc_index::IndexVec; use rustc_infer::infer::{DefineOpaqueTypes, InferOk, TypeTrace}; use rustc_middle::ty::adjustment::AllowTwoPhase; +use rustc_middle::ty::error::TypeError; use rustc_middle::ty::visit::TypeVisitableExt; use rustc_middle::ty::{self, IsSuggestable, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; @@ -25,7 +26,7 @@ use rustc_span::symbol::{kw, Ident}; use rustc_span::{sym, Span, DUMMY_SP}; use rustc_trait_selection::error_reporting::infer::{FailureCode, ObligationCauseExt}; use rustc_trait_selection::infer::InferCtxtExt; -use rustc_trait_selection::traits::{self, ObligationCauseCode, SelectionContext}; +use rustc_trait_selection::traits::{self, ObligationCauseCode, ObligationCtxt, SelectionContext}; use {rustc_ast as ast, rustc_hir as hir}; use crate::coercion::CoerceMany; @@ -123,6 +124,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { }; if let Err(guar) = has_error { let err_inputs = self.err_args(args_no_rcvr.len(), guar); + let err_output = Ty::new_error(self.tcx, guar); let err_inputs = match tuple_arguments { DontTupleArguments => err_inputs, @@ -133,28 +135,23 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { sp, expr, &err_inputs, - None, + err_output, + NoExpectation, args_no_rcvr, false, tuple_arguments, method.ok().map(|method| method.def_id), ); - return Ty::new_error(self.tcx, guar); + return err_output; } let method = method.unwrap(); - // HACK(eddyb) ignore self in the definition (see above). - let expected_input_tys = self.expected_inputs_for_expected_output( - sp, - expected, - method.sig.output(), - &method.sig.inputs()[1..], - ); self.check_argument_types( sp, expr, &method.sig.inputs()[1..], - expected_input_tys, + method.sig.output(), + expected, args_no_rcvr, method.sig.c_variadic, tuple_arguments, @@ -174,8 +171,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { call_expr: &'tcx hir::Expr<'tcx>, // Types (as defined in the *signature* of the target function) formal_input_tys: &[Ty<'tcx>], - // More specific expected types, after unifying with caller output types - expected_input_tys: Option>>, + formal_output: Ty<'tcx>, + // Expected output from the parent expression or statement + expectation: Expectation<'tcx>, // The expressions for each provided argument provided_args: &'tcx [hir::Expr<'tcx>], // Whether the function is variadic, for example when imported from C @@ -209,6 +207,39 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { ); } + // First, let's unify the formal method signature with the expectation eagerly. + // We use this to guide coercion inference; it's output is "fudged" which means + // any remaining type variables are assigned to new, unrelated variables. This + // is because the inference guidance here is only speculative. + let expected_input_tys: Option> = expectation + .only_has_type(self) + .and_then(|expected_output| { + self.fudge_inference_if_ok(|| { + let ocx = ObligationCtxt::new(self); + + // Attempt to apply a subtyping relationship between the formal + // return type (likely containing type variables if the function + // is polymorphic) and the expected return type. + // No argument expectations are produced if unification fails. + let origin = self.misc(call_span); + ocx.sup(&origin, self.param_env, expected_output, formal_output)?; + if !ocx.select_where_possible().is_empty() { + return Err(TypeError::Mismatch); + } + + // Record all the argument types, with the args + // produced from the above subtyping unification. + Ok(Some( + formal_input_tys + .iter() + .map(|&ty| self.resolve_vars_if_possible(ty)) + .collect(), + )) + }) + .ok() + }) + .unwrap_or_default(); + let mut err_code = E0061; // If the arguments should be wrapped in a tuple (ex: closures), unwrap them here From c2863260b31011c67dfebc6acdb11b7b18591e02 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sun, 25 Aug 2024 11:43:43 -0400 Subject: [PATCH 3/3] Don't do expected type fudging if the inputs don't need inference guidance --- .../rustc_hir_typeck/src/fn_ctxt/checks.rs | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs index b9a07e2ff08ca..e99eb8d289b5b 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs @@ -211,34 +211,41 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // We use this to guide coercion inference; it's output is "fudged" which means // any remaining type variables are assigned to new, unrelated variables. This // is because the inference guidance here is only speculative. - let expected_input_tys: Option> = expectation - .only_has_type(self) - .and_then(|expected_output| { - self.fudge_inference_if_ok(|| { - let ocx = ObligationCtxt::new(self); - - // Attempt to apply a subtyping relationship between the formal - // return type (likely containing type variables if the function - // is polymorphic) and the expected return type. - // No argument expectations are produced if unification fails. - let origin = self.misc(call_span); - ocx.sup(&origin, self.param_env, expected_output, formal_output)?; - if !ocx.select_where_possible().is_empty() { - return Err(TypeError::Mismatch); - } + // + // We only do this if the formals have non-region infer vars, since this is only + // meant to guide inference. + let expected_input_tys: Option> = if formal_input_tys.has_non_region_infer() { + expectation + .only_has_type(self) + .and_then(|expected_output| { + self.fudge_inference_if_ok(|| { + let ocx = ObligationCtxt::new(self); + + // Attempt to apply a subtyping relationship between the formal + // return type (likely containing type variables if the function + // is polymorphic) and the expected return type. + // No argument expectations are produced if unification fails. + let origin = self.misc(call_span); + ocx.sup(&origin, self.param_env, expected_output, formal_output)?; + if !ocx.select_where_possible().is_empty() { + return Err(TypeError::Mismatch); + } - // Record all the argument types, with the args - // produced from the above subtyping unification. - Ok(Some( - formal_input_tys - .iter() - .map(|&ty| self.resolve_vars_if_possible(ty)) - .collect(), - )) + // Record all the argument types, with the args + // produced from the above subtyping unification. + Ok(Some( + formal_input_tys + .iter() + .map(|&ty| self.resolve_vars_if_possible(ty)) + .collect(), + )) + }) + .ok() }) - .ok() - }) - .unwrap_or_default(); + .unwrap_or_default() + } else { + None + }; let mut err_code = E0061;