Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge these copy statements that simplified the canonical enum clone method by GVN #129931

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,6 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// Now, we need to shrink the generated MIR.
&ref_prop::ReferencePropagation,
&sroa::ScalarReplacementOfAggregates,
&match_branches::MatchBranchSimplification,
// inst combine is after MatchBranchSimplification to clean up Ne(_1, false)
&multiple_return_terminators::MultipleReturnTerminators,
// After simplifycfg, it allows us to discover new opportunities for peephole
// optimizations.
Expand All @@ -604,6 +602,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&dead_store_elimination::DeadStoreElimination::Initial,
&gvn::GVN,
&simplify::SimplifyLocals::AfterGVN,
&match_branches::MatchBranchSimplification,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this run on clone shims too? Or do we want a trait-based solution to detect trivial clone impls?

&dataflow_const_prop::DataflowConstProp,
&single_use_consts::SingleUseConsts,
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
Expand Down
221 changes: 220 additions & 1 deletion compiler/rustc_mir_transform/src/match_branches.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use std::iter;
use std::{iter, usize};

use rustc_const_eval::const_eval::mk_eval_cx_for_const_val;
use rustc_index::bit_set::BitSet;
use rustc_index::IndexSlice;
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
use rustc_middle::ty;
use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
use rustc_middle::ty::util::Discr;
use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
use rustc_mir_dataflow::impls::{borrowed_locals, MaybeTransitiveLiveLocals};
use rustc_mir_dataflow::Analysis;
use rustc_target::abi::Integer;
use rustc_type_ir::TyKind::*;

Expand Down Expand Up @@ -48,6 +54,10 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
should_cleanup = true;
continue;
}
if simplify_to_copy(tcx, body, bb_idx, param_env).is_some() {
should_cleanup = true;
continue;
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be made a standalone MirPass? This will be easier to have some analyses computed only once.


if should_cleanup {
Expand Down Expand Up @@ -519,3 +529,212 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
}
}
}

/// Merges the per-variant copy statements that GVN leaves behind after simplifying the
/// canonical enum `clone` method into a single copy statement.
///
/// GVN simplifies
/// ```ignore (syntax-highlighting-only)
/// match a {
/// Foo::A(x) => Foo::A(*x),
/// Foo::B => Foo::B
/// }
/// ```
/// to
/// ```ignore (syntax-highlighting-only)
/// match a {
/// Foo::A(_x) => a, // copy a
/// Foo::B => Foo::B
/// }
/// ```
/// This function then replaces the `switchInt` at the end of `switch_bb_idx` with a single copy
/// of the matched place into the return place, followed by the terminator shared by all
/// target blocks.
///
/// Returns `Some(())` if the body was patched. Returns `None` (leaving `body` unmodified) if
/// the pattern does not match.
fn simplify_to_copy<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
switch_bb_idx: BasicBlock,
param_env: ParamEnv<'tcx>,
) -> Option<()> {
// To save compile time, only consider the case where the switch terminator is in the first BB.
if switch_bb_idx != START_BLOCK {
return None;
}
let bbs = &body.basic_blocks;
// Check if the copy source matches the following pattern.
// _2 = discriminant(*_1); // "*_1" is the expected copy source.
// switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
let &Statement {
kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(expected_src_place))),
..
} = bbs[switch_bb_idx].statements.last()?
else {
return None;
};
// The source must be a whole enum value, not a place downcast to a specific variant.
let expected_src_ty = expected_src_place.ty(body.local_decls(), tcx);
if !expected_src_ty.ty.is_enum() || expected_src_ty.variant_index.is_some() {
return None;
}
// To save compile time, only consider the case where the copy source is assigned to the
// return place.
let expected_dest_place = Place::return_place();
let expected_dest_ty = expected_dest_place.ty(body.local_decls(), tcx);
if expected_dest_ty.ty != expected_src_ty.ty || expected_dest_ty.variant_index.is_some() {
return None;
}
// The switch must operate on the discriminant that was just read from the copy source.
let targets = match bbs[switch_bb_idx].terminator().kind {
TerminatorKind::SwitchInt { ref discr, ref targets, .. }
if discr.place() == Some(discr_place) =>
{
targets
}
_ => return None,
};
// We require that the possible target blocks all be distinct.
if !targets.is_distinct() {
return None;
}
// Bail out unless the `otherwise` block passes the `is_empty_unreachable` check.
if !bbs[targets.otherwise()].is_empty_unreachable() {
return None;
}
// Check that all target blocks end with the same terminator, and if not, then don't
// optimize this block. That common terminator later replaces the `switchInt`.
let mut target_iter = targets.iter();
let first_terminator_kind = &bbs[target_iter.next().unwrap().1].terminator().kind;
if !target_iter
.all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind)
{
return None;
}

let borrowed_locals = borrowed_locals(body);
// The liveness analysis is computed lazily: it is only needed when some target block is
// not a single assignment statement.
let mut live = None;

// Every target block must be equivalent to `_0 = copy <expected source>`.
// `index` is the switch value paired with `target_bb`.
for (index, target_bb) in targets.iter() {
let stmts = &bbs[target_bb].statements;
if stmts.is_empty() {
return None;
}
// Simple case: the block consists of exactly one assignment. It must write a whole enum
// value of the source type to the return place, and its right-hand side must be either a
// copy of the expected source, or a field-less variant (as a constant or an aggregate)
// whose discriminant equals the switch value for this block.
if let [Statement { kind: StatementKind::Assign(box (place, rvalue)), .. }] =
bbs[target_bb].statements.as_slice()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment for the high-level pattern you are trying to match?

{
let dest_ty = place.ty(body.local_decls(), tcx);
if dest_ty.ty != expected_src_ty.ty || dest_ty.variant_index.is_some() {
return None;
}
let ty::Adt(def, _) = dest_ty.ty.kind() else {
return None;
};
if expected_dest_place != *place {
return None;
}
match rvalue {
// Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`.
// This requires the constant's variant to have no fields and its discriminant
// value to equal the switch value of this target.
Rvalue::Use(Operand::Constant(box constant))
if let Const::Val(const_, ty) = constant.const_ =>
{
let (ecx, op) =
mk_eval_cx_for_const_val(tcx.at(constant.span), param_env, const_, ty)?;
let variant = ecx.read_discriminant(&op).ok()?;
if !def.variants()[variant].fields.is_empty() {
return None;
}
let Discr { val, .. } = ty.discriminant_for_variant(tcx, variant)?;
if val != index {
return None;
}
}
// The statement is already a copy of the expected source.
Rvalue::Use(Operand::Copy(src_place)) if *src_place == expected_src_place => {}
// Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`.
Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
if fields.is_empty()
&& let Some(Discr { val, .. }) =
expected_src_ty.ty.discriminant_for_variant(tcx, *variant_index)
&& val == index => {}
_ => return None,
}
} else {
// Otherwise (the BB is not a single assignment), we have to check if all statements
// other than the expected copy can be ignored.
// `lived_stmts` starts with every statement marked; dead stores are removed below.
let mut lived_stmts: BitSet<usize> =
BitSet::new_filled(bbs[target_bb].statements.len());
let mut expected_copy_stmt = None;
// Walk the statements from last to first.
for (statement_index, statement) in bbs[target_bb].statements.iter().enumerate().rev() {
let loc = Location { block: target_bb, statement_index };
// Any assignment whose rvalue is not safe to remove blocks the optimization.
if let StatementKind::Assign(assign) = &statement.kind {
if !assign.1.is_safe_to_remove() {
return None;
}
}
match &statement.kind {
StatementKind::Assign(box (place, _))
| StatementKind::SetDiscriminant { place: box place, .. }
| StatementKind::Deinit(box place) => {
// Writes through a pointer or to a borrowed local are not reasoned about.
if place.is_indirect() || borrowed_locals.contains(place.local) {
return None;
}
let live = live.get_or_insert_with(|| {
MaybeTransitiveLiveLocals::new(&borrowed_locals)
.into_engine(tcx, body)
.iterate_to_fixpoint()
.into_results_cursor(body)
});
live.seek_before_primary_effect(loc);
if !live.get().contains(place.local) {
// The written local is not live here, so this statement can be ignored.
lived_stmts.remove(statement_index);
} else if let StatementKind::Assign(box (
_,
Rvalue::Use(Operand::Copy(src_place)),
)) = statement.kind
&& expected_copy_stmt.is_none()
&& expected_src_place == src_place
&& expected_dest_place == *place
{
// At most one statement that cannot be ignored may remain, and it must be the
// expected copy statement (`_0 = copy <expected source>`).
expected_copy_stmt = Some(statement_index);
} else {
return None;
}
}
// Storage markers and no-ops are validated in the second pass below.
StatementKind::StorageLive(_)
| StatementKind::StorageDead(_)
| StatementKind::Nop => (),

// Any other statement kind blocks the optimization.
StatementKind::Retag(_, _)
| StatementKind::Coverage(_)
| StatementKind::Intrinsic(_)
| StatementKind::ConstEvalCounter
| StatementKind::PlaceMention(_)
| StatementKind::FakeRead(_)
| StatementKind::AscribeUserType(_, _) => {
return None;
}
}
}
// The block must contain the expected copy statement.
let expected_copy_stmt = expected_copy_stmt?;
// We can ignore the paired StorageLive and StorageDead.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we perform the same analysis on this statement that we do in the single statement branch above?

// Tracks locals with a `StorageLive` that has not yet been matched by a `StorageDead`.
let mut storage_live_locals: BitSet<Local> = BitSet::new_empty(body.local_decls.len());
for stmt_index in lived_stmts.iter() {
let statement = &bbs[target_bb].statements[stmt_index];
match &statement.kind {
StatementKind::Assign(_) if expected_copy_stmt == stmt_index => {}
StatementKind::StorageLive(local)
if *local != expected_dest_place.local
&& storage_live_locals.insert(*local) => {}
StatementKind::StorageDead(local)
if *local != expected_dest_place.local
&& storage_live_locals.remove(*local) => {}
StatementKind::Nop => {}
_ => return None,
}
}
// Every `StorageLive` must have been closed by a `StorageDead` within this block.
if !storage_live_locals.is_empty() {
return None;
}
}
}
// All target blocks are equivalent to the copy: append `_0 = copy <expected source>` at the
// end of the switch block and replace the `switchInt` with the targets' common terminator.
let statement_index = bbs[switch_bb_idx].statements.len();
let parent_end = Location { block: switch_bb_idx, statement_index };
let mut patch = MirPatch::new(body);
patch.add_assign(
parent_end,
expected_dest_place,
Rvalue::Use(Operand::Copy(expected_src_place)),
);
patch.patch_terminator(switch_bb_idx, first_terminator_kind.clone());
patch.apply(body);
Some(())
}
14 changes: 8 additions & 6 deletions tests/codegen/match-optimizes-away.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
//@ compile-flags: -O
//@ compile-flags: -O -Cno-prepopulate-passes

#![crate_type = "lib"]

pub enum Three {
Expand All @@ -19,8 +19,9 @@ pub enum Four {
#[no_mangle]
pub fn three_valued(x: Three) -> Three {
// CHECK-LABEL: @three_valued
// CHECK-NEXT: {{^.*:$}}
// CHECK-NEXT: ret i8 %0
// CHECK-SAME: (i8{{.*}} [[X:%x]])
// CHECK-NEXT: start:
// CHECK-NEXT: ret i8 [[X]]
match x {
Three::A => Three::A,
Three::B => Three::B,
Expand All @@ -31,8 +32,9 @@ pub fn three_valued(x: Three) -> Three {
#[no_mangle]
pub fn four_valued(x: Four) -> Four {
// CHECK-LABEL: @four_valued
// CHECK-NEXT: {{^.*:$}}
// CHECK-NEXT: ret i16 %0
// CHECK-SAME: (i16{{.*}} [[X:%x]])
// CHECK-NEXT: start:
// CHECK-NEXT: ret i16 [[X]]
match x {
Four::A => Four::A,
Four::B => Four::B,
Expand Down
7 changes: 1 addition & 6 deletions tests/codegen/try_question_mark_nop.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
//@ compile-flags: -O -Z merge-functions=disabled --edition=2021
//@ only-x86_64
// FIXME: Remove the `min-llvm-version`.
//@ min-llvm-version: 19

#![crate_type = "lib"]
#![feature(try_blocks)]

use std::ops::ControlFlow::{self, Break, Continue};
use std::ptr::NonNull;

// FIXME: The `trunc` and `select` instructions can be eliminated.
// CHECK-LABEL: @option_nop_match_32
#[no_mangle]
pub fn option_nop_match_32(x: Option<u32>) -> Option<u32> {
// CHECK: start:
// CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i32 %0 to i1
// CHECK-NEXT: [[FIRST:%.*]] = select i1 [[TRUNC]], i32 %0
// CHECK-NEXT: insertvalue { i32, i32 } poison, i32 [[FIRST]]
// CHECK-NEXT: insertvalue { i32, i32 }
// CHECK-NEXT: insertvalue { i32, i32 }
// CHECK-NEXT: ret { i32, i32 }
match x {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
let _6: *mut [bool; 0];
scope 6 {
scope 10 (inlined NonNull::<[bool; 0]>::new_unchecked) {
let mut _8: bool;
let _9: ();
let mut _10: *mut ();
let mut _11: *const [bool; 0];
let _8: ();
let mut _9: *mut ();
let mut _10: *const [bool; 0];
scope 11 (inlined core::ub_checks::check_language_ub) {
let mut _11: bool;
scope 12 (inlined core::ub_checks::check_language_ub::runtime) {
}
}
Expand All @@ -44,18 +44,18 @@
StorageLive(_1);
StorageLive(_2);
StorageLive(_3);
StorageLive(_9);
StorageLive(_8);
StorageLive(_4);
StorageLive(_5);
StorageLive(_6);
StorageLive(_7);
_7 = const 1_usize;
_6 = const {0x1 as *mut [bool; 0]};
StorageDead(_7);
StorageLive(_10);
StorageLive(_11);
StorageLive(_8);
_8 = UbChecks();
switchInt(move _8) -> [0: bb4, otherwise: bb2];
_11 = UbChecks();
switchInt(copy _11) -> [0: bb4, otherwise: bb2];
}

bb1: {
Expand All @@ -64,28 +64,28 @@
}

bb2: {
StorageLive(_10);
_10 = const {0x1 as *mut ()};
_9 = NonNull::<T>::new_unchecked::precondition_check(const {0x1 as *mut ()}) -> [return: bb3, unwind unreachable];
StorageLive(_9);
_9 = const {0x1 as *mut ()};
_8 = NonNull::<T>::new_unchecked::precondition_check(const {0x1 as *mut ()}) -> [return: bb3, unwind unreachable];
}

bb3: {
StorageDead(_10);
StorageDead(_9);
goto -> bb4;
}

bb4: {
StorageDead(_8);
_11 = const {0x1 as *const [bool; 0]};
_10 = const {0x1 as *const [bool; 0]};
_5 = const NonNull::<[bool; 0]> {{ pointer: {0x1 as *const [bool; 0]} }};
StorageDead(_11);
StorageDead(_10);
StorageDead(_6);
_4 = const Unique::<[bool; 0]> {{ pointer: NonNull::<[bool; 0]> {{ pointer: {0x1 as *const [bool; 0]} }}, _marker: PhantomData::<[bool; 0]> }};
StorageDead(_5);
_3 = const Unique::<[bool]> {{ pointer: NonNull::<[bool]> {{ pointer: Indirect { alloc_id: ALLOC0, offset: Size(0 bytes) }: *const [bool] }}, _marker: PhantomData::<[bool]> }};
StorageDead(_4);
_2 = const Box::<[bool]>(Unique::<[bool]> {{ pointer: NonNull::<[bool]> {{ pointer: Indirect { alloc_id: ALLOC1, offset: Size(0 bytes) }: *const [bool] }}, _marker: PhantomData::<[bool]> }}, std::alloc::Global);
StorageDead(_9);
StorageDead(_8);
StorageDead(_3);
_1 = const A {{ foo: Box::<[bool]>(Unique::<[bool]> {{ pointer: NonNull::<[bool]> {{ pointer: Indirect { alloc_id: ALLOC2, offset: Size(0 bytes) }: *const [bool] }}, _marker: PhantomData::<[bool]> }}, std::alloc::Global) }};
StorageDead(_2);
Expand Down
Loading
Loading