|
| 1 | +use rustc_index::IndexVec; |
| 2 | +use rustc_middle::bug; |
| 3 | +use rustc_middle::mir::*; |
| 4 | +use rustc_middle::ty::TyCtxt; |
| 5 | +use tracing::{debug, instrument, trace}; |
| 6 | + |
| 7 | +use std::mem; |
| 8 | + |
| 9 | +pub(super) struct BranchDuplicator; |
| 10 | + |
| 11 | +impl<'tcx> crate::MirPass<'tcx> for BranchDuplicator { |
| 12 | + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { |
| 13 | + sess.mir_opt_level() >= 2 |
| 14 | + } |
| 15 | + |
| 16 | + #[instrument(skip_all level = "debug")] |
| 17 | + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
| 18 | + let def_id = body.source.def_id(); |
| 19 | + debug!(?def_id); |
| 20 | + |
| 21 | + // Optimizing coroutines creates query cycles. |
| 22 | + if tcx.is_coroutine(def_id) { |
| 23 | + trace!("Skipped for coroutine {:?}", def_id); |
| 24 | + return; |
| 25 | + } |
| 26 | + |
| 27 | + let is_branch = |targets: &SwitchTargets| { |
| 28 | + targets.all_targets().len() == 2 |
| 29 | + || (targets.all_values().len() == 2 && body.basic_blocks[targets.otherwise()].is_empty_unreachable()) |
| 30 | + }; |
| 31 | + |
| 32 | + let mut candidates = Vec::new(); |
| 33 | + for (bb, bbdata) in body.basic_blocks.iter_enumerated() { |
| 34 | + if let TerminatorKind::SwitchInt { targets, .. } = &bbdata.terminator().kind |
| 35 | + && is_branch(targets) |
| 36 | + && let Ok(preds) = <[BasicBlock; 2]>::try_from(body.basic_blocks.predecessors()[bb].as_slice()) |
| 37 | + && preds.iter().copied().all(|p| matches!(body.basic_blocks[p].terminator().kind, TerminatorKind::Goto { .. })) |
| 38 | + && bbdata.statements.iter().all(|x| is_negligible(&x.kind)) |
| 39 | + { |
| 40 | + candidates.push((bb, preds)); |
| 41 | + } |
| 42 | + } |
| 43 | + |
| 44 | + if candidates.is_empty() { |
| 45 | + return; |
| 46 | + } |
| 47 | + |
| 48 | + let basic_blocks = body.basic_blocks.as_mut(); |
| 49 | + for (bb, [p0, p1]) in candidates { |
| 50 | + let bbdata = &mut basic_blocks[bb]; |
| 51 | + let statements = mem::take(&mut bbdata.statements); |
| 52 | + let unreachable = Terminator { |
| 53 | + source_info: bbdata.terminator().source_info, |
| 54 | + kind: TerminatorKind::Unreachable, |
| 55 | + }; |
| 56 | + let terminator = mem::replace(bbdata.terminator_mut(), unreachable); |
| 57 | + |
| 58 | + let pred0data = &mut basic_blocks[p0]; |
| 59 | + pred0data.statements.extend(statements.iter().cloned()); |
| 60 | + *pred0data.terminator_mut() = terminator.clone(); |
| 61 | + |
| 62 | + let pred1data = &mut basic_blocks[p1]; |
| 63 | + pred1data.statements.extend(statements); |
| 64 | + *pred1data.terminator_mut() = terminator; |
| 65 | + } |
| 66 | + } |
| 67 | + |
| 68 | + fn is_required(&self) -> bool { |
| 69 | + false |
| 70 | + } |
| 71 | +} |
| 72 | + |
| 73 | +fn is_negligible<'tcx>(stmt: &StatementKind<'tcx>) -> bool { |
| 74 | + use StatementKind::*; |
| 75 | + use Rvalue::*; |
| 76 | + match stmt { |
| 77 | + StorageLive(..) | StorageDead(..) => true, |
| 78 | + Assign(place_and_rvalue) => match &place_and_rvalue.1 { |
| 79 | + Ref(..) | RawPtr(..) | Discriminant(..) | NullaryOp(..) => true, |
| 80 | + _ => false, |
| 81 | + } |
| 82 | + _ => false, |
| 83 | + } |
| 84 | +} |
0 commit comments