Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New mir-opt pass to simplify gotos with const values #77486

Closed
wants to merge 12 commits into from
9 changes: 8 additions & 1 deletion compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1498,7 +1498,14 @@ pub enum StatementKind<'tcx> {
}

impl<'tcx> StatementKind<'tcx> {
pub fn as_assign_mut(&mut self) -> Option<&mut Box<(Place<'tcx>, Rvalue<'tcx>)>> {
pub fn as_assign_mut(&mut self) -> Option<&mut (Place<'tcx>, Rvalue<'tcx>)> {
match self {
StatementKind::Assign(x) => Some(x),
_ => None,
}
}

pub fn as_assign(&self) -> Option<&(Place<'tcx>, Rvalue<'tcx>)> {
match self {
StatementKind::Assign(x) => Some(x),
_ => None,
Expand Down
16 changes: 16 additions & 0 deletions compiler/rustc_middle/src/mir/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,22 @@ impl<'tcx> TerminatorKind<'tcx> {
| TerminatorKind::FalseUnwind { ref mut unwind, .. } => Some(unwind),
}
}

pub fn as_switch(&self) -> Option<(&Operand<'tcx>, Ty<'tcx>, &SwitchTargets)> {
match self {
TerminatorKind::SwitchInt { discr, switch_ty, targets } => {
Some((discr, switch_ty, targets))
}
_ => None,
}
}

pub fn as_goto(&self) -> Option<BasicBlock> {
match self {
TerminatorKind::Goto { target } => Some(*target),
_ => None,
}
}
}

impl<'tcx> Debug for TerminatorKind<'tcx> {
Expand Down
119 changes: 119 additions & 0 deletions compiler/rustc_mir/src/transform/const_goto.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
//! This pass optimizes the following sequence
//! ```rust
//! bb2: {
//! _2 = const true;
//! goto -> bb3;
//! }
//!
//! bb3: {
//! switchInt(_2) -> [false: bb4, otherwise: bb5];
//! }
//! ```
//! into
//! ```rust
//! bb2: {
//! _2 = const true;
//! goto -> bb5;
//! }
//! ```

use crate::transform::MirPass;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;
use rustc_middle::{mir::visit::Visitor, ty::ParamEnv};

use super::simplify::{simplify_cfg, simplify_locals};

pub struct ConstGoto;

impl<'tcx> MirPass<'tcx> for ConstGoto {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("Running ConstGoto on {:?}", body.source);
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
let mut opt_finder =
ConstGotoOptimizationFinder { tcx, body, optimizations: vec![], param_env };
opt_finder.visit_body(body);
let should_simplify = !opt_finder.optimizations.is_empty();
for opt in opt_finder.optimizations {
let terminator = body.basic_blocks_mut()[opt.bb_with_goto].terminator_mut();
let new_goto = TerminatorKind::Goto { target: opt.target_to_use_in_goto };
debug!("SUCCESS: replacing `{:?}` with `{:?}`", terminator.kind, new_goto);
terminator.kind = new_goto;
}

// if we applied optimizations, we potentially have some cfg to cleanup to
// make it easier for further passes
if should_simplify {
simplify_cfg(body);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @rust-lang/wg-mir-opt we need to figure out a way to dump the mir of an optimizations both before its internal cleanup and after. The test diffs are very noisy if we do the cleanup together with the optimization.

Maybe we could not do these immediate cleanups, and instead leave a flag in the mir body that states that it needs a simplification and the following cleanup optimizations just skip if the flag is not set?

simplify_locals(body, tcx);
}
}
}

impl<'a, 'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'a, 'tcx> {
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
let _: Option<_> = try {
let target = terminator.kind.as_goto()?;
// We only apply this optimization if the last statement is a const assignment
let last_statement = self.body.basic_blocks()[location.block].statements.last()?;

if let (place, Rvalue::Use(Operand::Constant(_const))) =
last_statement.kind.as_assign()?
{
// We found a constant being assigned to `place`.
// Now check that the target of this Goto switches on this place.
let target_bb = &self.body.basic_blocks()[target];

// FIXME(simonvandel): We are conservative here when we don't allow
// any statements in the target basic block.
// This could probably be relaxed to allow `StorageDead`s which could be
// copied to the predecessor of this block.
if !target_bb.statements.is_empty() {
None?
}

let target_bb_terminator = target_bb.terminator();
let (discr, switch_ty, targets) = target_bb_terminator.kind.as_switch()?;
if discr.place() == Some(*place) {
// We now know that the Switch matches on the const place, and it is statementless
// Now find which value in the Switch matches the const value.
let const_value =
_const.literal.try_eval_bits(self.tcx, self.param_env, switch_ty)?;
let found_value_idx_option = targets
.iter()
.enumerate()
.find(|(_, (value, _))| const_value == *value)
.map(|(idx, _)| idx);

let target_to_use_in_goto =
if let Some(found_value_idx) = found_value_idx_option {
targets.iter().nth(found_value_idx).unwrap().1
} else {
// If we did not find the const value in values, it must be the otherwise case
targets.otherwise()
};

self.optimizations.push(OptimizationToApply {
bb_with_goto: location.block,
target_to_use_in_goto,
});
}
}
Some(())
};

self.super_terminator(terminator, location);
}
}

struct OptimizationToApply {
bb_with_goto: BasicBlock,
target_to_use_in_goto: BasicBlock,
}

pub struct ConstGotoOptimizationFinder<'a, 'tcx> {
tcx: TyCtxt<'tcx>,
body: &'a Body<'tcx>,
param_env: ParamEnv<'tcx>,
optimizations: Vec<OptimizationToApply>,
}
2 changes: 2 additions & 0 deletions compiler/rustc_mir/src/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub mod check_consts;
pub mod check_packed_ref;
pub mod check_unsafety;
pub mod cleanup_post_borrowck;
pub mod const_goto;
pub mod const_prop;
pub mod deaggregator;
pub mod dest_prop;
Expand Down Expand Up @@ -388,6 +389,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {

// The main optimizations that we do on MIR.
let optimizations: &[&dyn MirPass<'tcx>] = &[
&const_goto::ConstGoto,
&remove_unneeded_drops::RemoveUnneededDrops,
&match_branches::MatchBranchSimplification,
// inst combine is after MatchBranchSimplification to clean up Ne(_1, false)
Expand Down
77 changes: 40 additions & 37 deletions compiler/rustc_mir/src/transform/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,48 +315,51 @@ pub fn remove_dead_blocks(body: &mut Body<'_>) {
}
}

pub struct SimplifyLocals;

impl<'tcx> MirPass<'tcx> for SimplifyLocals {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("running SimplifyLocals on {:?}", body.source);

// First, we're going to get a count of *actual* uses for every `Local`.
// Take a look at `DeclMarker::visit_local()` to see exactly what is ignored.
let mut used_locals = {
let mut marker = DeclMarker::new(body);
marker.visit_body(&body);

marker.local_counts
};

let arg_count = body.arg_count;
pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) {
// First, we're going to get a count of *actual* uses for every `Local`.
// Take a look at `DeclMarker::visit_local()` to see exactly what is ignored.
let mut used_locals = {
let mut marker = DeclMarker::new(body);
marker.visit_body(&body);

marker.local_counts
};

let arg_count = body.arg_count;

// Next, we're going to remove any `Local` with zero actual uses. When we remove those
// `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
// count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
// `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
// fixedpoint where there are no more unused locals.
loop {
let mut remove_statements = RemoveStatements::new(&mut used_locals, arg_count, tcx);
remove_statements.visit_body(body);

if !remove_statements.modified {
break;
}
}

// Next, we're going to remove any `Local` with zero actual uses. When we remove those
// `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
// count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
// `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
// fixedpoint where there are no more unused locals.
loop {
let mut remove_statements = RemoveStatements::new(&mut used_locals, arg_count, tcx);
remove_statements.visit_body(body);
// Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s.
let map = make_local_map(&mut body.local_decls, used_locals, arg_count);

if !remove_statements.modified {
break;
}
}
// Only bother running the `LocalUpdater` if we actually found locals to remove.
if map.iter().any(Option::is_none) {
// Update references to all vars and tmps now
let mut updater = LocalUpdater { map, tcx };
updater.visit_body(body);

// Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s.
let map = make_local_map(&mut body.local_decls, used_locals, arg_count);
body.local_decls.shrink_to_fit();
}
}

// Only bother running the `LocalUpdater` if we actually found locals to remove.
if map.iter().any(Option::is_none) {
// Update references to all vars and tmps now
let mut updater = LocalUpdater { map, tcx };
updater.visit_body(body);
pub struct SimplifyLocals;

body.local_decls.shrink_to_fit();
}
impl<'tcx> MirPass<'tcx> for SimplifyLocals {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("running SimplifyLocals on {:?}", body.source);
simplify_locals(body, tcx);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl<'tcx> MirPass<'tcx> for SimplifyComparisonIntegral {
// we convert the move in the comparison statement to a copy.

// unwrap is safe as we know this statement is an assign
let box (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();
let (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();

use Operand::*;
match rhs {
Expand Down
52 changes: 52 additions & 0 deletions src/test/mir-opt/const_goto.issue_77355_opt.ConstGoto.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
- // MIR for `issue_77355_opt` before ConstGoto
+ // MIR for `issue_77355_opt` after ConstGoto

fn issue_77355_opt(_1: Foo) -> u64 {
debug num => _1; // in scope 0 at $DIR/const_goto.rs:11:20: 11:23
let mut _0: u64; // return place in scope 0 at $DIR/const_goto.rs:11:33: 11:36
- let mut _2: bool; // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
- let mut _3: isize; // in scope 0 at $DIR/const_goto.rs:12:22: 12:28
+ let mut _2: isize; // in scope 0 at $DIR/const_goto.rs:12:22: 12:28

bb0: {
- StorageLive(_2); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
- _3 = discriminant(_1); // scope 0 at $DIR/const_goto.rs:12:22: 12:28
- switchInt(move _3) -> [1_isize: bb2, 2_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/const_goto.rs:12:22: 12:28
+ _2 = discriminant(_1); // scope 0 at $DIR/const_goto.rs:12:22: 12:28
+ switchInt(move _2) -> [1_isize: bb2, 2_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/const_goto.rs:12:22: 12:28
}

bb1: {
- _2 = const false; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
- goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
- }
-
- bb2: {
- _2 = const true; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
- goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
- }
-
- bb3: {
- switchInt(_2) -> [false: bb4, otherwise: bb5]; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
- }
-
- bb4: {
_0 = const 42_u64; // scope 0 at $DIR/const_goto.rs:12:53: 12:55
- goto -> bb6; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
+ goto -> bb3; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
}

- bb5: {
+ bb2: {
_0 = const 23_u64; // scope 0 at $DIR/const_goto.rs:12:41: 12:43
- goto -> bb6; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
+ goto -> bb3; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
}

- bb6: {
- StorageDead(_2); // scope 0 at $DIR/const_goto.rs:13:1: 13:2
+ bb3: {
return; // scope 0 at $DIR/const_goto.rs:13:2: 13:2
}
}

16 changes: 16 additions & 0 deletions src/test/mir-opt/const_goto.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
pub enum Foo {
A,
B,
C,
D,
E,
F,
}

// EMIT_MIR const_goto.issue_77355_opt.ConstGoto.diff
fn issue_77355_opt(num: Foo) -> u64 {
if matches!(num, Foo::B | Foo::C) { 23 } else { 42 }
}
fn main() {
issue_77355_opt(Foo::A);
}
51 changes: 51 additions & 0 deletions src/test/mir-opt/const_goto_const_eval_fail.f.ConstGoto.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
- // MIR for `f` before ConstGoto
+ // MIR for `f` after ConstGoto

fn f() -> u64 {
let mut _0: u64; // return place in scope 0 at $DIR/const_goto_const_eval_fail.rs:6:44: 6:47
let mut _1: bool; // in scope 0 at $DIR/const_goto_const_eval_fail.rs:7:11: 12:6
let mut _2: i32; // in scope 0 at $DIR/const_goto_const_eval_fail.rs:8:15: 8:16

bb0: {
StorageLive(_1); // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:11: 12:6
StorageLive(_2); // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:15: 8:16
_2 = const A; // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:15: 8:16
switchInt(_2) -> [1_i32: bb2, 2_i32: bb2, 3_i32: bb2, otherwise: bb1]; // scope 0 at $DIR/const_goto_const_eval_fail.rs:9:13: 9:14
}

bb1: {
_1 = const true; // scope 0 at $DIR/const_goto_const_eval_fail.rs:10:18: 10:22
goto -> bb3; // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:9: 11:10
}

bb2: {
_1 = const B; // scope 0 at $DIR/const_goto_const_eval_fail.rs:9:26: 9:27
- goto -> bb3; // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:9: 11:10
+ switchInt(_1) -> [false: bb4, otherwise: bb3]; // scope 0 at $DIR/const_goto_const_eval_fail.rs:13:9: 13:14
}

bb3: {
- switchInt(_1) -> [false: bb5, otherwise: bb4]; // scope 0 at $DIR/const_goto_const_eval_fail.rs:13:9: 13:14
- }
-
- bb4: {
_0 = const 2_u64; // scope 0 at $DIR/const_goto_const_eval_fail.rs:14:17: 14:18
- goto -> bb6; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
+ goto -> bb5; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
}

- bb5: {
+ bb4: {
_0 = const 1_u64; // scope 0 at $DIR/const_goto_const_eval_fail.rs:13:18: 13:19
- goto -> bb6; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
+ goto -> bb5; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
}

- bb6: {
+ bb5: {
StorageDead(_2); // scope 0 at $DIR/const_goto_const_eval_fail.rs:16:1: 16:2
StorageDead(_1); // scope 0 at $DIR/const_goto_const_eval_fail.rs:16:1: 16:2
return; // scope 0 at $DIR/const_goto_const_eval_fail.rs:16:2: 16:2
}
}

Loading