Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e929c7f

Browse files
committedJun 2, 2024
Change f32::midpoint to upcast to f64
This has been verified by kani as a correct optimization see: rust-lang#110840 (comment) The new implementation is branchless, and only differs in which NaN values are produced (if any are produced at all). Which is fine to change. Aside from NaN handling, this implementation produces bitwise identical results to the original implementation. The new implementation is gated on targets that have a fast 64-bit floating point implementation in hardware, and on WASM.
1 parent 0b5ada2 commit e929c7f

File tree

2 files changed

+62
-22
lines changed

2 files changed

+62
-22
lines changed
 

‎core/src/num/f32.rs

+36-19
Original file line numberDiff line numberDiff line change
@@ -1016,25 +1016,42 @@ impl f32 {
10161016
/// ```
10171017
#[unstable(feature = "num_midpoint", issue = "110840")]
10181018
pub fn midpoint(self, other: f32) -> f32 {
1019-
const LO: f32 = f32::MIN_POSITIVE * 2.;
1020-
const HI: f32 = f32::MAX / 2.;
1021-
1022-
let (a, b) = (self, other);
1023-
let abs_a = a.abs_private();
1024-
let abs_b = b.abs_private();
1025-
1026-
if abs_a <= HI && abs_b <= HI {
1027-
// Overflow is impossible
1028-
(a + b) / 2.
1029-
} else if abs_a < LO {
1030-
// Not safe to halve a
1031-
a + (b / 2.)
1032-
} else if abs_b < LO {
1033-
// Not safe to halve b
1034-
(a / 2.) + b
1035-
} else {
1036-
// Not safe to halve a and b
1037-
(a / 2.) + (b / 2.)
1019+
cfg_if! {
1020+
if #[cfg(any(
1021+
target_arch = "x86_64",
1022+
target_arch = "aarch64",
1023+
all(any(target_arch="riscv32", target_arch= "riscv64"), target_feature="d"),
1024+
all(target_arch = "arm", target_feature="vfp2"),
1025+
target_arch = "wasm32",
1026+
target_arch = "wasm64",
1027+
))] {
1028+
// whitelist the faster implementation to targets that have known good 64-bit float
1029+
// implementations. Falling back to the branchy code on targets that don't have
1030+
// 64-bit hardware floats or buggy implementations.
1031+
// see: https://github.com/rust-lang/rust/pull/121062#issuecomment-2123408114
1032+
((f64::from(self) + f64::from(other)) / 2.0) as f32
1033+
} else {
1034+
const LO: f32 = f32::MIN_POSITIVE * 2.;
1035+
const HI: f32 = f32::MAX / 2.;
1036+
1037+
let (a, b) = (self, other);
1038+
let abs_a = a.abs_private();
1039+
let abs_b = b.abs_private();
1040+
1041+
if abs_a <= HI && abs_b <= HI {
1042+
// Overflow is impossible
1043+
(a + b) / 2.
1044+
} else if abs_a < LO {
1045+
// Not safe to halve a
1046+
a + (b / 2.)
1047+
} else if abs_b < LO {
1048+
// Not safe to halve b
1049+
(a / 2.) + b
1050+
} else {
1051+
// Not safe to halve a and b
1052+
(a / 2.) + (b / 2.)
1053+
}
1054+
}
10381055
}
10391056
}
10401057

‎core/tests/num/mod.rs

+26-3
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ assume_usize_width! {
719719
}
720720

721721
macro_rules! test_float {
722-
($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr) => {
722+
($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr, $max_exp:expr) => {
723723
mod $modname {
724724
#[test]
725725
fn min() {
@@ -870,6 +870,27 @@ macro_rules! test_float {
870870
assert!(($nan as $fty).midpoint(1.0).is_nan());
871871
assert!((1.0 as $fty).midpoint($nan).is_nan());
872872
assert!(($nan as $fty).midpoint($nan).is_nan());
873+
874+
// test if large differences in magnitude are still correctly computed.
875+
// NOTE: that because of how small x and y are, x + y can never overflow
876+
// so (x + y) / 2.0 is always correct
877+
// in particular, `2.pow(i)` will never be at the max exponent, so it could
878+
// be safely doubled, while j is significantly smaller.
879+
for i in $max_exp.saturating_sub(64)..$max_exp {
880+
for j in 0..64u8 {
881+
let large = <$fty>::from(2.0f32).powi(i);
882+
// a much smaller number, such that there is no chance of overflow to test
883+
// potential double rounding in midpoint's implementation.
884+
let small = <$fty>::from(2.0f32).powi($max_exp - 1)
885+
* <$fty>::EPSILON
886+
* <$fty>::from(j);
887+
888+
let naive = (large + small) / 2.0;
889+
let midpoint = large.midpoint(small);
890+
891+
assert_eq!(naive, midpoint);
892+
}
893+
}
873894
}
874895
#[test]
875896
fn rem_euclid() {
@@ -902,7 +923,8 @@ test_float!(
902923
f32::NAN,
903924
f32::MIN,
904925
f32::MAX,
905-
f32::MIN_POSITIVE
926+
f32::MIN_POSITIVE,
927+
f32::MAX_EXP
906928
);
907929
test_float!(
908930
f64,
@@ -912,5 +934,6 @@ test_float!(
912934
f64::NAN,
913935
f64::MIN,
914936
f64::MAX,
915-
f64::MIN_POSITIVE
937+
f64::MIN_POSITIVE,
938+
f64::MAX_EXP
916939
);

0 commit comments

Comments
 (0)
Failed to load comments.