1
1
use rustc_middle:: mir:: patch:: MirPatch ;
2
2
use rustc_middle:: mir:: * ;
3
- use rustc_middle:: ty:: { self , Ty , TyCtxt } ;
3
+ use rustc_middle:: ty:: { Ty , TyCtxt } ;
4
4
use std:: fmt:: Debug ;
5
5
6
6
use super :: simplify:: simplify_cfg;
@@ -11,6 +11,7 @@ use super::simplify::simplify_cfg;
11
11
/// let y: Option<()>;
12
12
/// match (x,y) {
13
13
/// (Some(_), Some(_)) => {0},
14
+ /// (None, None) => {2},
14
15
/// _ => {1}
15
16
/// }
16
17
/// ```
@@ -23,10 +24,10 @@ use super::simplify::simplify_cfg;
23
24
/// if discriminant_x == discriminant_y {
24
25
/// match x {
25
26
/// Some(_) => 0,
26
- /// _ => 1, // <----
27
- /// } // | Actually the same bb
28
- /// } else { // |
29
- /// 1 // <--------------
27
+ /// None => 2,
28
+ /// }
29
+ /// } else {
30
+ /// 1
30
31
/// }
31
32
/// ```
32
33
///
@@ -47,18 +48,18 @@ use super::simplify::simplify_cfg;
47
48
/// | | |
48
49
/// ================= | | |
49
50
/// | BBU | <-| | | ============================
50
- /// |---------------| | \-------> | BBD |
51
- /// |---------------| | | |--------------------------|
52
- /// | unreachable | | | | _dl = discriminant(P) |
53
- /// ================= | | |--------------------------|
54
- /// | | | switchInt(_dl) |
55
- /// ================= | | | d | ---> BBD.2
51
+ /// |---------------| \-------> | BBD |
52
+ /// |---------------| | |--------------------------|
53
+ /// | unreachable | | | _dl = discriminant(P) |
54
+ /// ================= | |--------------------------|
55
+ /// | | switchInt(_dl) |
56
+ /// ================= | | d | ---> BBD.2
56
57
/// | BB9 | <--------------- | otherwise |
57
58
/// |---------------| ============================
58
59
/// | ... |
59
60
/// =================
60
61
/// ```
61
- /// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9` . In the
62
+ /// Where the `otherwise` branch on `BB1` is only permitted to go to `BBU`. In the
62
63
/// code:
63
64
/// - `BB1` is `parent` and `BBC, BBD` are children
64
65
/// - `P` is `child_place`
@@ -78,7 +79,7 @@ use super::simplify::simplify_cfg;
78
79
/// |---------------------| | | switchInt(Q) |
79
80
/// | switchInt(_t) | | | c | ---> BBC.2
80
81
/// | false | --------/ | d | ---> BBD.2
81
- /// | otherwise | ------- --------- | otherwise |
82
+ /// | otherwise | / --------- | otherwise |
82
83
/// ======================= | ============================
83
84
/// |
84
85
/// ================= |
@@ -87,16 +88,11 @@ use super::simplify::simplify_cfg;
87
88
/// | ... |
88
89
/// =================
89
90
/// ```
90
- ///
91
- /// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`.
92
- /// The filter on which `P` are allowed (together with discussion of its correctness) is found in
93
- /// `may_hoist`.
94
91
pub struct EarlyOtherwiseBranch ;
95
92
96
93
impl < ' tcx > MirPass < ' tcx > for EarlyOtherwiseBranch {
97
94
fn is_enabled ( & self , sess : & rustc_session:: Session ) -> bool {
98
- // unsound: https://github.com/rust-lang/rust/issues/95162
99
- sess. mir_opt_level ( ) >= 3 && sess. opts . unstable_opts . unsound_mir_opts
95
+ sess. mir_opt_level ( ) >= 2
100
96
}
101
97
102
98
fn run_pass ( & self , tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
@@ -172,7 +168,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
172
168
} ;
173
169
( value, targets. target_for_value ( value) )
174
170
} ) ;
175
- let eq_targets = SwitchTargets :: new ( eq_new_targets, opt_data. destination ) ;
171
+ // The `otherwise` target is either the same target branch or an unreachable block.
172
+ let eq_targets = SwitchTargets :: new ( eq_new_targets, parent_targets. otherwise ( ) ) ;
176
173
177
174
// Create `bbEq` in example above
178
175
let eq_switch = BasicBlockData :: new ( Some ( Terminator {
@@ -217,85 +214,6 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
217
214
}
218
215
}
219
216
220
- /// Returns true if computing the discriminant of `place` may be hoisted out of the branch
221
- fn may_hoist < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & Body < ' tcx > , place : Place < ' tcx > ) -> bool {
222
- // FIXME(JakobDegen): This is unsound. Someone could write code like this:
223
- // ```rust
224
- // let Q = val;
225
- // if discriminant(P) == otherwise {
226
- // let ptr = &mut Q as *mut _ as *mut u8;
227
- // unsafe { *ptr = 10; } // Any invalid value for the type
228
- // }
229
- //
230
- // match P {
231
- // A => match Q {
232
- // A => {
233
- // // code
234
- // }
235
- // _ => {
236
- // // don't use Q
237
- // }
238
- // }
239
- // _ => {
240
- // // don't use Q
241
- // }
242
- // };
243
- // ```
244
- //
245
- // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
246
- // invalid value, which is UB.
247
- //
248
- // In order to fix this, we would either need to show that the discriminant computation of
249
- // `place` is computed in all branches, including the `otherwise` branch, or we would need
250
- // another analysis pass to determine that the place is fully initialized. It might even be best
251
- // to have the hoisting be performed in a different pass and just do the CFG changing in this
252
- // pass.
253
- for ( place, proj) in place. iter_projections ( ) {
254
- match proj {
255
- // Dereferencing in the computation of `place` might cause issues from one of two
256
- // categories. First, the referent might be invalid. We protect against this by
257
- // dereferencing references only (not pointers). Second, the use of a reference may
258
- // invalidate other references that are used later (for aliasing reasons). Consider
259
- // where such an invalidated reference may appear:
260
- // - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so
261
- // cannot contain referenced data.
262
- // - In `BBU`: Not possible since that block contains only the `unreachable` terminator
263
- // - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to
264
- // reaching that block in the input to our transformation, and so any data
265
- // invalidated by that computation could not have been used there.
266
- // - In `BB9`: Not possible since control flow might have reached `BB9` via the
267
- // `otherwise` branch in `BBC, BBD` in the input to our transformation, which would
268
- // have invalidated the data when computing `discriminant(P)`
269
- // So dereferencing here is correct.
270
- ProjectionElem :: Deref => match place. ty ( body. local_decls ( ) , tcx) . ty . kind ( ) {
271
- ty:: Ref ( ..) => { }
272
- _ => return false ,
273
- } ,
274
- // Field projections are always valid
275
- ProjectionElem :: Field ( ..) => { }
276
- // We cannot allow
277
- // downcasts either, since the correctness of the downcast may depend on the parent
278
- // branch being taken. An easy example of this is
279
- // ```
280
- // Q = discriminant(_3)
281
- // P = (_3 as Variant)
282
- // ```
283
- // However, checking if the child and parent place are the same and only erroring then
284
- // is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may
285
- // be replaced by another optimization pass with any other condition that can be proven
286
- // equivalent.
287
- ProjectionElem :: Downcast ( ..) => {
288
- return false ;
289
- }
290
- // We cannot allow indexing since the index may be out of bounds.
291
- _ => {
292
- return false ;
293
- }
294
- }
295
- }
296
- true
297
- }
298
-
299
217
#[ derive( Debug ) ]
300
218
struct OptimizationData < ' tcx > {
301
219
destination : BasicBlock ,
@@ -315,18 +233,40 @@ fn evaluate_candidate<'tcx>(
315
233
return None ;
316
234
} ;
317
235
let parent_ty = parent_discr. ty ( body. local_decls ( ) , tcx) ;
318
- let parent_dest = {
319
- let poss = targets. otherwise ( ) ;
320
- // If the fallthrough on the parent is trivially unreachable, we can let the
321
- // children choose the destination
322
- if bbs[ poss] . statements . len ( ) == 0
323
- && bbs[ poss] . terminator ( ) . kind == TerminatorKind :: Unreachable
324
- {
325
- None
326
- } else {
327
- Some ( poss)
328
- }
329
- } ;
236
+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
237
+ // Someone could write code like this:
238
+ // ```rust
239
+ // let Q = val;
240
+ // if discriminant(P) == otherwise {
241
+ // let ptr = &mut Q as *mut _ as *mut u8;
242
+ // // It may be difficult for us to effectively determine whether values are valid.
243
+ // // Invalid values can come from all sorts of corners.
244
+ // unsafe { *ptr = 10; }
245
+ // }
246
+ //
247
+ // match P {
248
+ // A => match Q {
249
+ // A => {
250
+ // // code
251
+ // }
252
+ // _ => {
253
+ // // don't use Q
254
+ // }
255
+ // }
256
+ // _ => {
257
+ // // don't use Q
258
+ // }
259
+ // };
260
+ // ```
261
+ //
262
+ // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
263
+ // invalid value, which is UB.
264
+ // In order to fix this, **we would need to show that the discriminant computation of
265
+ // `place` is computed in all branches**.
266
+ // FIXME(#95162) For the moment, we adopt a conservative approach and
267
+ // only consider the case where the `otherwise` branch has no statements and an unreachable terminator.
268
+ return None ;
269
+ }
330
270
let ( _, child) = targets. iter ( ) . next ( ) ?;
331
271
let child_terminator = & bbs[ child] . terminator ( ) ;
332
272
let TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } =
@@ -344,13 +284,7 @@ fn evaluate_candidate<'tcx>(
344
284
let ( _, Rvalue :: Discriminant ( child_place) ) = & * * boxed else {
345
285
return None ;
346
286
} ;
347
- let destination = parent_dest. unwrap_or ( child_targets. otherwise ( ) ) ;
348
-
349
- // Verify that the optimization is legal in general
350
- // We can hoist evaluating the child discriminant out of the branch
351
- if !may_hoist ( tcx, body, * child_place) {
352
- return None ;
353
- }
287
+ let destination = child_targets. otherwise ( ) ;
354
288
355
289
// Verify that the optimization is legal for each branch
356
290
for ( value, child) in targets. iter ( ) {
@@ -411,5 +345,5 @@ fn verify_candidate_branch<'tcx>(
411
345
if let Some ( _) = iter. next ( ) {
412
346
return false ;
413
347
}
414
- return true ;
348
+ true
415
349
}
0 commit comments