diff --git a/compiler/rustc_const_eval/src/transform/promote_consts.rs b/compiler/rustc_const_eval/src/transform/promote_consts.rs index 8b2ea2dc21dd1..155cf4ff9e2c7 100644 --- a/compiler/rustc_const_eval/src/transform/promote_consts.rs +++ b/compiler/rustc_const_eval/src/transform/promote_consts.rs @@ -969,7 +969,7 @@ pub fn promote_candidates<'tcx>( 0, vec![], body.span, - body.coroutine_kind(), + None, body.tainted_by_errors, ); promoted.phase = MirPhase::Analysis(AnalysisPhase::Initial); diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs index 45dbfe6b8a7d7..d426f6d8969ac 100644 --- a/compiler/rustc_middle/src/mir/mod.rs +++ b/compiler/rustc_middle/src/mir/mod.rs @@ -263,6 +263,23 @@ pub struct CoroutineInfo<'tcx> { pub coroutine_kind: CoroutineKind, } +impl<'tcx> CoroutineInfo<'tcx> { + // Sets up `CoroutineInfo` for a pre-coroutine-transform MIR body. + pub fn initial( + coroutine_kind: CoroutineKind, + yield_ty: Ty<'tcx>, + resume_ty: Ty<'tcx>, + ) -> CoroutineInfo<'tcx> { + CoroutineInfo { + coroutine_kind, + yield_ty: Some(yield_ty), + resume_ty: Some(resume_ty), + coroutine_drop: None, + coroutine_layout: None, + } + } +} + /// The lowered representation of a single function. #[derive(Clone, TyEncodable, TyDecodable, Debug, HashStable, TypeFoldable, TypeVisitable)] pub struct Body<'tcx> { @@ -367,7 +384,7 @@ impl<'tcx> Body<'tcx> { arg_count: usize, var_debug_info: Vec>, span: Span, - coroutine_kind: Option, + coroutine: Option>>, tainted_by_errors: Option, ) -> Self { // We need `arg_count` locals, and one for the return place. @@ -384,15 +401,7 @@ impl<'tcx> Body<'tcx> { source, basic_blocks: BasicBlocks::new(basic_blocks), source_scopes, - coroutine: coroutine_kind.map(|coroutine_kind| { - Box::new(CoroutineInfo { - yield_ty: None, - resume_ty: None, - coroutine_drop: None, - coroutine_layout: None, - coroutine_kind, - }) - }), + coroutine, local_decls, user_type_annotations, arg_count, diff --git a/compiler/rustc_middle/src/thir.rs b/compiler/rustc_middle/src/thir.rs index 2b5983314eecf..b4b8387c262b9 100644 --- a/compiler/rustc_middle/src/thir.rs +++ b/compiler/rustc_middle/src/thir.rs @@ -86,8 +86,6 @@ macro_rules! thir_with_elements { } } -pub const UPVAR_ENV_PARAM: ParamId = ParamId::from_u32(0); - thir_with_elements! { body_type: BodyTy<'tcx>, diff --git a/compiler/rustc_mir_build/src/build/mod.rs b/compiler/rustc_mir_build/src/build/mod.rs index c4cade839478c..b8d08319422d4 100644 --- a/compiler/rustc_mir_build/src/build/mod.rs +++ b/compiler/rustc_mir_build/src/build/mod.rs @@ -9,7 +9,7 @@ use rustc_errors::ErrorGuaranteed; use rustc_hir as hir; use rustc_hir::def::DefKind; use rustc_hir::def_id::{DefId, LocalDefId}; -use rustc_hir::{CoroutineKind, Node}; +use rustc_hir::Node; use rustc_index::bit_set::GrowableBitSet; use rustc_index::{Idx, IndexSlice, IndexVec}; use rustc_infer::infer::{InferCtxt, TyCtxtInferExt}; @@ -177,7 +177,7 @@ struct Builder<'a, 'tcx> { check_overflow: bool, fn_span: Span, arg_count: usize, - coroutine_kind: Option, + coroutine: Option>>, /// The current set of scopes, updated as we traverse; /// see the `scope` module for more details. @@ -458,7 +458,6 @@ fn construct_fn<'tcx>( ) -> Body<'tcx> { let span = tcx.def_span(fn_def); let fn_id = tcx.local_def_id_to_hir_id(fn_def); - let coroutine_kind = tcx.coroutine_kind(fn_def); // The representation of thir for `-Zunpretty=thir-tree` relies on // the entry expression being the last element of `thir.exprs`. @@ -488,17 +487,15 @@ fn construct_fn<'tcx>( let arguments = &thir.params; - let (resume_ty, yield_ty, return_ty) = if coroutine_kind.is_some() { - let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty; - let coroutine_sig = match coroutine_ty.kind() { - ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(), - _ => { - span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty) - } - }; - (Some(coroutine_sig.resume_ty), Some(coroutine_sig.yield_ty), coroutine_sig.return_ty) - } else { - (None, None, fn_sig.output()) + let return_ty = fn_sig.output(); + let coroutine = match tcx.type_of(fn_def).instantiate_identity().kind() { + ty::Coroutine(_, args) => Some(Box::new(CoroutineInfo::initial( + tcx.coroutine_kind(fn_def).unwrap(), + args.as_coroutine().yield_ty(), + args.as_coroutine().resume_ty(), + ))), + ty::Closure(..) | ty::FnDef(..) => None, + ty => span_bug!(span_with_body, "unexpected type of body: {ty:?}"), }; if let Some(custom_mir_attr) = @@ -529,7 +526,7 @@ fn construct_fn<'tcx>( safety, return_ty, return_ty_span, - coroutine_kind, + coroutine, ); let call_site_scope = @@ -563,11 +560,6 @@ fn construct_fn<'tcx>( None }; - if coroutine_kind.is_some() { - body.coroutine.as_mut().unwrap().yield_ty = yield_ty; - body.coroutine.as_mut().unwrap().resume_ty = resume_ty; - } - body } @@ -632,47 +624,62 @@ fn construct_const<'a, 'tcx>( fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -> Body<'_> { let span = tcx.def_span(def_id); let hir_id = tcx.local_def_id_to_hir_id(def_id); - let coroutine_kind = tcx.coroutine_kind(def_id); - let (inputs, output, resume_ty, yield_ty) = match tcx.def_kind(def_id) { + let (inputs, output, coroutine) = match tcx.def_kind(def_id) { DefKind::Const | DefKind::AssocConst | DefKind::AnonConst | DefKind::InlineConst - | DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None, None), + | DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None), DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => { let sig = tcx.liberate_late_bound_regions( def_id.to_def_id(), tcx.fn_sig(def_id).instantiate_identity(), ); - (sig.inputs().to_vec(), sig.output(), None, None) - } - DefKind::Closure if coroutine_kind.is_some() => { - let coroutine_ty = tcx.type_of(def_id).instantiate_identity(); - let ty::Coroutine(_, args) = coroutine_ty.kind() else { - bug!("expected type of coroutine-like closure to be a coroutine") - }; - let args = args.as_coroutine(); - let resume_ty = args.resume_ty(); - let yield_ty = args.yield_ty(); - let return_ty = args.return_ty(); - (vec![coroutine_ty, args.resume_ty()], return_ty, Some(resume_ty), Some(yield_ty)) + (sig.inputs().to_vec(), sig.output(), None) } DefKind::Closure => { let closure_ty = tcx.type_of(def_id).instantiate_identity(); - let ty::Closure(_, args) = closure_ty.kind() else { - bug!("expected type of closure to be a closure") - }; - let args = args.as_closure(); - let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig()); - let self_ty = match args.kind() { - ty::ClosureKind::Fn => Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty), - ty::ClosureKind::FnMut => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty), - ty::ClosureKind::FnOnce => closure_ty, - }; - ([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None, None) + match closure_ty.kind() { + ty::Closure(_, args) => { + let args = args.as_closure(); + let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig()); + let self_ty = match args.kind() { + ty::ClosureKind::Fn => { + Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty) + } + ty::ClosureKind::FnMut => { + Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty) + } + ty::ClosureKind::FnOnce => closure_ty, + }; + ( + [self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), + sig.output(), + None, + ) + } + ty::Coroutine(_, args) => { + let args = args.as_coroutine(); + let resume_ty = args.resume_ty(); + let yield_ty = args.yield_ty(); + let return_ty = args.return_ty(); + ( + vec![closure_ty, args.resume_ty()], + return_ty, + Some(Box::new(CoroutineInfo::initial( + tcx.coroutine_kind(def_id).unwrap(), + yield_ty, + resume_ty, + ))), + ) + } + _ => { + span_bug!(span, "expected type of closure body to be a closure or coroutine"); + } + } } - dk => bug!("{:?} is not a body: {:?}", def_id, dk), + dk => span_bug!(span, "{:?} is not a body: {:?}", def_id, dk), }; let source_info = SourceInfo { span, scope: OUTERMOST_SOURCE_SCOPE }; @@ -696,7 +703,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) - cfg.terminate(START_BLOCK, source_info, TerminatorKind::Unreachable); - let mut body = Body::new( + Body::new( MirSource::item(def_id.to_def_id()), cfg.basic_blocks, source_scopes, @@ -705,16 +712,9 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) - inputs.len(), vec![], span, - coroutine_kind, + coroutine, Some(guar), - ); - - body.coroutine.as_mut().map(|gen| { - gen.yield_ty = yield_ty; - gen.resume_ty = resume_ty; - }); - - body + ) } impl<'a, 'tcx> Builder<'a, 'tcx> { @@ -728,7 +728,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { safety: Safety, return_ty: Ty<'tcx>, return_span: Span, - coroutine_kind: Option, + coroutine: Option>>, ) -> Builder<'a, 'tcx> { let tcx = infcx.tcx; let attrs = tcx.hir().attrs(hir_id); @@ -759,7 +759,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { cfg: CFG { basic_blocks: IndexVec::new() }, fn_span: span, arg_count, - coroutine_kind, + coroutine, scopes: scope::Scopes::new(), block_context: BlockContext::new(), source_scopes: IndexVec::new(), @@ -803,7 +803,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { self.arg_count, self.var_debug_info, self.fn_span, - self.coroutine_kind, + self.coroutine, None, ) } diff --git a/compiler/rustc_mir_build/src/build/scope.rs b/compiler/rustc_mir_build/src/build/scope.rs index 1a700ac7342ee..48b237f3ae62a 100644 --- a/compiler/rustc_mir_build/src/build/scope.rs +++ b/compiler/rustc_mir_build/src/build/scope.rs @@ -706,7 +706,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // If we are emitting a `drop` statement, we need to have the cached // diverge cleanup pads ready in case that drop panics. let needs_cleanup = self.scopes.scopes.last().is_some_and(|scope| scope.needs_cleanup()); - let is_coroutine = self.coroutine_kind.is_some(); + let is_coroutine = self.coroutine.is_some(); let unwind_to = if needs_cleanup { self.diverge_cleanup() } else { DropIdx::MAX }; let scope = self.scopes.scopes.last().expect("leave_top_scope called with no scopes"); @@ -960,7 +960,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // path, we only need to invalidate the cache for drops that happen on // the unwind or coroutine drop paths. This means that for // non-coroutines we don't need to invalidate caches for `DropKind::Storage`. - let invalidate_caches = needs_drop || self.coroutine_kind.is_some(); + let invalidate_caches = needs_drop || self.coroutine.is_some(); for scope in self.scopes.scopes.iter_mut().rev() { if invalidate_caches { scope.invalidate_cache(); @@ -1073,7 +1073,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { return cached_drop; } - let is_coroutine = self.coroutine_kind.is_some(); + let is_coroutine = self.coroutine.is_some(); for scope in &mut self.scopes.scopes[uncached_scope..=target] { for drop in &scope.drops { if is_coroutine || drop.kind == DropKind::Value { @@ -1318,7 +1318,7 @@ impl<'a, 'tcx: 'a> Builder<'a, 'tcx> { blocks[ROOT_NODE] = continue_block; drops.build_mir::(&mut self.cfg, &mut blocks); - let is_coroutine = self.coroutine_kind.is_some(); + let is_coroutine = self.coroutine.is_some(); // Link the exit drop tree to unwind drop tree. if drops.drops.iter().any(|(drop, _)| drop.kind == DropKind::Value) { @@ -1355,7 +1355,7 @@ impl<'a, 'tcx: 'a> Builder<'a, 'tcx> { /// Build the unwind and coroutine drop trees. pub(crate) fn build_drop_trees(&mut self) { - if self.coroutine_kind.is_some() { + if self.coroutine.is_some() { self.build_coroutine_drop_trees(); } else { Self::build_unwind_tree(