diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs index 06e63b0f3d995..c2aa015f4b7e7 100644 --- a/compiler/rustc_middle/src/mir/terminator.rs +++ b/compiler/rustc_middle/src/mir/terminator.rs @@ -243,8 +243,9 @@ impl AssertKind { DivisionByZero(_) => middle_assert_divide_by_zero, RemainderByZero(_) => middle_assert_remainder_by_zero, ResumedAfterReturn(CoroutineKind::Async(_)) => middle_assert_async_resume_after_return, - // FIXME(gen_blocks): custom error message for `gen` blocks - ResumedAfterReturn(CoroutineKind::Gen(_)) => middle_assert_async_resume_after_return, + ResumedAfterReturn(CoroutineKind::Gen(_)) => { + bug!("gen blocks can be resumed after they return and will keep returning `None`") + } ResumedAfterReturn(CoroutineKind::Coroutine) => { middle_assert_coroutine_resume_after_return } diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index 8fecff16a9198..50d244d2831d4 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -249,18 +249,34 @@ struct TransformVisitor<'tcx> { } impl<'tcx> TransformVisitor<'tcx> { - // Make a `CoroutineState` or `Poll` variant assignment. - // - // `core::ops::CoroutineState` only has single element tuple variants, - // so we can just write to the downcasted first field and then set the - // discriminant to the appropriate variant. - fn make_state( + fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock { + let block = BasicBlock::new(body.basic_blocks.len()); + + let source_info = SourceInfo::outermost(body.span); + + let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true); + assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); + let statements = vec![Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), IndexVec::new()), + ))), + source_info, + }]; + + body.basic_blocks_mut().push(BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }); + + block + } + + fn coroutine_state_adt_and_variant_idx( &self, - val: Operand<'tcx>, - source_info: SourceInfo, is_return: bool, - statements: &mut Vec>, - ) { + ) -> (AggregateKind<'tcx>, VariantIdx) { let idx = VariantIdx::new(match (is_return, self.coroutine_kind) { (true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete (false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded @@ -271,6 +287,22 @@ impl<'tcx> TransformVisitor<'tcx> { }); let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None); + (kind, idx) + } + + // Make a `CoroutineState` or `Poll` variant assignment. + // + // `core::ops::CoroutineState` only has single element tuple variants, + // so we can just write to the downcasted first field and then set the + // discriminant to the appropriate variant. + fn make_state( + &self, + val: Operand<'tcx>, + source_info: SourceInfo, + is_return: bool, + statements: &mut Vec>, + ) { + let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return); match self.coroutine_kind { // `Poll::Pending` @@ -1285,10 +1317,13 @@ fn create_coroutine_resume_function<'tcx>( } if can_return { - cases.insert( - 1, - (RETURNED, insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))), - ); + let block = match coroutine_kind { + CoroutineKind::Async(_) | CoroutineKind::Coroutine => { + insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind)) + } + CoroutineKind::Gen(_) => transform.insert_none_ret_block(body), + }; + cases.insert(1, (RETURNED, block)); } insert_switch(body, cases, &transform, TerminatorKind::Unreachable); diff --git a/tests/ui/coroutine/gen_block_iterate.rs b/tests/ui/coroutine/gen_block_iterate.rs index 131dd6879360e..18e1bb8877233 100644 --- a/tests/ui/coroutine/gen_block_iterate.rs +++ b/tests/ui/coroutine/gen_block_iterate.rs @@ -25,6 +25,8 @@ fn main() { assert_eq!(iter.next(), Some(4)); assert_eq!(iter.next(), Some(5)); assert_eq!(iter.next(), None); + // `gen` blocks are fused + assert_eq!(iter.next(), None); let mut iter = moved(); assert_eq!(iter.next(), Some(42)); diff --git a/tests/ui/coroutine/gen_block_iterate_fused.rs b/tests/ui/coroutine/gen_block_iterate_fused.rs deleted file mode 100644 index 8ee6baf10602f..0000000000000 --- a/tests/ui/coroutine/gen_block_iterate_fused.rs +++ /dev/null @@ -1,19 +0,0 @@ -// revisions: next old -//compile-flags: --edition 2024 -Zunstable-options -//[next] compile-flags: -Ztrait-solver=next -// run-fail -#![feature(gen_blocks)] - -fn foo() -> impl Iterator { - gen { yield 42; for x in 3..6 { yield x } } -} - -fn main() { - let mut iter = foo(); - assert_eq!(iter.next(), Some(42)); - assert_eq!(iter.next(), Some(3)); - assert_eq!(iter.next(), Some(4)); - assert_eq!(iter.next(), Some(5)); - assert_eq!(iter.next(), None); - assert_eq!(iter.next(), None); -}