Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1c207eb

Browse files
committedMar 27, 2024
Auto merge of rust-lang#122387 - DianQK:re-enable-early-otherwise-branch, r=<try>
Re-enable the early otherwise branch optimization Closes rust-lang#95162. Fixes rust-lang#119014. This is the first part of rust-lang#121397. An invalid enum discriminant can come from anywhere. We have to check to see if all successors contain the discriminant statement. This should have a pass to hoist instructions. r? cjgillot
2 parents 435b525 + 57d566e commit 1c207eb

16 files changed

+615
-278
lines changed
 

‎compiler/rustc_mir_transform/src/early_otherwise_branch.rs

+53-119
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use rustc_middle::mir::patch::MirPatch;
22
use rustc_middle::mir::*;
3-
use rustc_middle::ty::{self, Ty, TyCtxt};
3+
use rustc_middle::ty::{Ty, TyCtxt};
44
use std::fmt::Debug;
55

66
use super::simplify::simplify_cfg;
@@ -11,6 +11,7 @@ use super::simplify::simplify_cfg;
1111
/// let y: Option<()>;
1212
/// match (x,y) {
1313
/// (Some(_), Some(_)) => {0},
14+
/// (None, None) => {2},
1415
/// _ => {1}
1516
/// }
1617
/// ```
@@ -23,10 +24,10 @@ use super::simplify::simplify_cfg;
2324
/// if discriminant_x == discriminant_y {
2425
/// match x {
2526
/// Some(_) => 0,
26-
/// _ => 1, // <----
27-
/// } // | Actually the same bb
28-
/// } else { // |
29-
/// 1 // <--------------
27+
/// None => 2,
28+
/// }
29+
/// } else {
30+
/// 1
3031
/// }
3132
/// ```
3233
///
@@ -47,18 +48,18 @@ use super::simplify::simplify_cfg;
4748
/// | | |
4849
/// ================= | | |
4950
/// | BBU | <-| | | ============================
50-
/// |---------------| | \-------> | BBD |
51-
/// |---------------| | | |--------------------------|
52-
/// | unreachable | | | | _dl = discriminant(P) |
53-
/// ================= | | |--------------------------|
54-
/// | | | switchInt(_dl) |
55-
/// ================= | | | d | ---> BBD.2
51+
/// |---------------| \-------> | BBD |
52+
/// |---------------| | |--------------------------|
53+
/// | unreachable | | | _dl = discriminant(P) |
54+
/// ================= | |--------------------------|
55+
/// | | switchInt(_dl) |
56+
/// ================= | | d | ---> BBD.2
5657
/// | BB9 | <--------------- | otherwise |
5758
/// |---------------| ============================
5859
/// | ... |
5960
/// =================
6061
/// ```
61-
/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
62+
/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. In the
6263
/// code:
6364
/// - `BB1` is `parent` and `BBC, BBD` are children
6465
/// - `P` is `child_place`
@@ -78,7 +79,7 @@ use super::simplify::simplify_cfg;
7879
/// |---------------------| | | switchInt(Q) |
7980
/// | switchInt(_t) | | | c | ---> BBC.2
8081
/// | false | --------/ | d | ---> BBD.2
81-
/// | otherwise | ---------------- | otherwise |
82+
/// | otherwise | /--------- | otherwise |
8283
/// ======================= | ============================
8384
/// |
8485
/// ================= |
@@ -87,16 +88,11 @@ use super::simplify::simplify_cfg;
8788
/// | ... |
8889
/// =================
8990
/// ```
90-
///
91-
/// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`.
92-
/// The filter on which `P` are allowed (together with discussion of its correctness) is found in
93-
/// `may_hoist`.
9491
pub struct EarlyOtherwiseBranch;
9592

9693
impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
9794
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
98-
// unsound: https://github.com/rust-lang/rust/issues/95162
99-
sess.mir_opt_level() >= 3 && sess.opts.unstable_opts.unsound_mir_opts
95+
sess.mir_opt_level() >= 2
10096
}
10197

10298
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
@@ -172,7 +168,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
172168
};
173169
(value, targets.target_for_value(value))
174170
});
175-
let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination);
171+
// The otherwise either is the same target branch or an unreachable.
172+
let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());
176173

177174
// Create `bbEq` in example above
178175
let eq_switch = BasicBlockData::new(Some(Terminator {
@@ -217,85 +214,6 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
217214
}
218215
}
219216

220-
/// Returns true if computing the discriminant of `place` may be hoisted out of the branch
221-
fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool {
222-
// FIXME(JakobDegen): This is unsound. Someone could write code like this:
223-
// ```rust
224-
// let Q = val;
225-
// if discriminant(P) == otherwise {
226-
// let ptr = &mut Q as *mut _ as *mut u8;
227-
// unsafe { *ptr = 10; } // Any invalid value for the type
228-
// }
229-
//
230-
// match P {
231-
// A => match Q {
232-
// A => {
233-
// // code
234-
// }
235-
// _ => {
236-
// // don't use Q
237-
// }
238-
// }
239-
// _ => {
240-
// // don't use Q
241-
// }
242-
// };
243-
// ```
244-
//
245-
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
246-
// invalid value, which is UB.
247-
//
248-
// In order to fix this, we would either need to show that the discriminant computation of
249-
// `place` is computed in all branches, including the `otherwise` branch, or we would need
250-
// another analysis pass to determine that the place is fully initialized. It might even be best
251-
// to have the hoisting be performed in a different pass and just do the CFG changing in this
252-
// pass.
253-
for (place, proj) in place.iter_projections() {
254-
match proj {
255-
// Dereferencing in the computation of `place` might cause issues from one of two
256-
// categories. First, the referent might be invalid. We protect against this by
257-
// dereferencing references only (not pointers). Second, the use of a reference may
258-
// invalidate other references that are used later (for aliasing reasons). Consider
259-
// where such an invalidated reference may appear:
260-
// - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so
261-
// cannot contain referenced data.
262-
// - In `BBU`: Not possible since that block contains only the `unreachable` terminator
263-
// - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to
264-
// reaching that block in the input to our transformation, and so any data
265-
// invalidated by that computation could not have been used there.
266-
// - In `BB9`: Not possible since control flow might have reached `BB9` via the
267-
// `otherwise` branch in `BBC, BBD` in the input to our transformation, which would
268-
// have invalidated the data when computing `discriminant(P)`
269-
// So dereferencing here is correct.
270-
ProjectionElem::Deref => match place.ty(body.local_decls(), tcx).ty.kind() {
271-
ty::Ref(..) => {}
272-
_ => return false,
273-
},
274-
// Field projections are always valid
275-
ProjectionElem::Field(..) => {}
276-
// We cannot allow
277-
// downcasts either, since the correctness of the downcast may depend on the parent
278-
// branch being taken. An easy example of this is
279-
// ```
280-
// Q = discriminant(_3)
281-
// P = (_3 as Variant)
282-
// ```
283-
// However, checking if the child and parent place are the same and only erroring then
284-
// is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may
285-
// be replaced by another optimization pass with any other condition that can be proven
286-
// equivalent.
287-
ProjectionElem::Downcast(..) => {
288-
return false;
289-
}
290-
// We cannot allow indexing since the index may be out of bounds.
291-
_ => {
292-
return false;
293-
}
294-
}
295-
}
296-
true
297-
}
298-
299217
#[derive(Debug)]
300218
struct OptimizationData<'tcx> {
301219
destination: BasicBlock,
@@ -315,18 +233,40 @@ fn evaluate_candidate<'tcx>(
315233
return None;
316234
};
317235
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
318-
let parent_dest = {
319-
let poss = targets.otherwise();
320-
// If the fallthrough on the parent is trivially unreachable, we can let the
321-
// children choose the destination
322-
if bbs[poss].statements.len() == 0
323-
&& bbs[poss].terminator().kind == TerminatorKind::Unreachable
324-
{
325-
None
326-
} else {
327-
Some(poss)
328-
}
329-
};
236+
if !bbs[targets.otherwise()].is_empty_unreachable() {
237+
// Someone could write code like this:
238+
// ```rust
239+
// let Q = val;
240+
// if discriminant(P) == otherwise {
241+
// let ptr = &mut Q as *mut _ as *mut u8;
242+
// // It may be difficult for us to effectively determine whether values are valid.
243+
// // Invalid values can come from all sorts of corners.
244+
// unsafe { *ptr = 10; }
245+
// }
246+
//
247+
// match P {
248+
// A => match Q {
249+
// A => {
250+
// // code
251+
// }
252+
// _ => {
253+
// // don't use Q
254+
// }
255+
// }
256+
// _ => {
257+
// // don't use Q
258+
// }
259+
// };
260+
// ```
261+
//
262+
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
263+
// invalid value, which is UB.
264+
// In order to fix this, **we would either need to show that the discriminant computation of
265+
// `place` is computed in all branches**.
266+
// FIXME(#95162) For the moment, we adopt a conservative approach and
267+
// consider only the `otherwise` branch has no statements and an unreachable terminator.
268+
return None;
269+
}
330270
let (_, child) = targets.iter().next()?;
331271
let child_terminator = &bbs[child].terminator();
332272
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
@@ -344,13 +284,7 @@ fn evaluate_candidate<'tcx>(
344284
let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
345285
return None;
346286
};
347-
let destination = parent_dest.unwrap_or(child_targets.otherwise());
348-
349-
// Verify that the optimization is legal in general
350-
// We can hoist evaluating the child discriminant out of the branch
351-
if !may_hoist(tcx, body, *child_place) {
352-
return None;
353-
}
287+
let destination = child_targets.otherwise();
354288

355289
// Verify that the optimization is legal for each branch
356290
for (value, child) in targets.iter() {
@@ -411,5 +345,5 @@ fn verify_candidate_branch<'tcx>(
411345
if let Some(_) = iter.next() {
412346
return false;
413347
}
414-
return true;
348+
true
415349
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//@ compile-flags: -O
2+
3+
#![crate_type = "lib"]
4+
5+
pub enum Enum {
6+
A(u32),
7+
B(u32),
8+
C(u32),
9+
}
10+
11+
#[no_mangle]
12+
pub fn foo(lhs: &Enum, rhs: &Enum) -> bool {
13+
// CHECK-LABEL: define{{.*}}i1 @foo(
14+
// CHECK-NOT: switch
15+
// CHECK-NOT: br
16+
// CHECK: [[SELECT:%.*]] = select
17+
// CHECK-NEXT: ret i1 [[SELECT]]
18+
// CHECK-NEXT: }
19+
match (lhs, rhs) {
20+
(Enum::A(lhs), Enum::A(rhs)) => lhs == rhs,
21+
(Enum::B(lhs), Enum::B(rhs)) => lhs == rhs,
22+
(Enum::C(lhs), Enum::C(rhs)) => lhs == rhs,
23+
_ => false,
24+
}
25+
}

‎tests/mir-opt/early_otherwise_branch.opt1.EarlyOtherwiseBranch.diff

+13-26
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
let mut _7: isize;
1313
let _8: u32;
1414
let _9: u32;
15-
+ let mut _10: isize;
16-
+ let mut _11: bool;
1715
scope 1 {
1816
debug a => _8;
1917
debug b => _9;
@@ -29,48 +27,37 @@
2927
StorageDead(_5);
3028
StorageDead(_4);
3129
_7 = discriminant((_3.0: std::option::Option<u32>));
32-
- switchInt(move _7) -> [1: bb2, otherwise: bb1];
33-
+ StorageLive(_10);
34-
+ _10 = discriminant((_3.1: std::option::Option<u32>));
35-
+ StorageLive(_11);
36-
+ _11 = Ne(_7, move _10);
37-
+ StorageDead(_10);
38-
+ switchInt(move _11) -> [0: bb4, otherwise: bb1];
30+
switchInt(move _7) -> [1: bb2, 0: bb1, otherwise: bb5];
3931
}
4032

4133
bb1: {
42-
+ StorageDead(_11);
4334
_0 = const 1_u32;
44-
- goto -> bb4;
45-
+ goto -> bb3;
35+
goto -> bb4;
4636
}
4737

4838
bb2: {
49-
- _6 = discriminant((_3.1: std::option::Option<u32>));
50-
- switchInt(move _6) -> [1: bb3, otherwise: bb1];
51-
- }
52-
-
53-
- bb3: {
39+
_6 = discriminant((_3.1: std::option::Option<u32>));
40+
switchInt(move _6) -> [1: bb3, 0: bb1, otherwise: bb5];
41+
}
42+
43+
bb3: {
5444
StorageLive(_8);
5545
_8 = (((_3.0: std::option::Option<u32>) as Some).0: u32);
5646
StorageLive(_9);
5747
_9 = (((_3.1: std::option::Option<u32>) as Some).0: u32);
5848
_0 = const 0_u32;
5949
StorageDead(_9);
6050
StorageDead(_8);
61-
- goto -> bb4;
62-
+ goto -> bb3;
51+
goto -> bb4;
6352
}
6453

65-
- bb4: {
66-
+ bb3: {
54+
bb4: {
6755
StorageDead(_3);
6856
return;
69-
+ }
70-
+
71-
+ bb4: {
72-
+ StorageDead(_11);
73-
+ switchInt(_7) -> [1: bb2, otherwise: bb1];
57+
}
58+
59+
bb5: {
60+
unreachable;
7461
}
7562
}
7663

There was a problem loading the remainder of the diff.

0 commit comments

Comments
 (0)
Failed to load comments.