Skip to content

Commit

Permalink
implement simd_relaxed_fma
Browse files Browse the repository at this point in the history
  • Loading branch information
RalfJung committed Dec 4, 2024
1 parent efd1352 commit 16ee60a
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 33 deletions.
17 changes: 14 additions & 3 deletions src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use either::Either;
use rand::Rng;
use rustc_abi::{Endian, HasDataLayout};
use rustc_apfloat::{Float, Round};
use rustc_middle::ty::FloatTy;
Expand Down Expand Up @@ -286,7 +287,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
this.write_scalar(val, &dest)?;
}
}
"fma" => {
"fma" | "relaxed_fma" => {
let [a, b, c] = check_arg_count(args)?;
let (a, a_len) = this.project_to_simd(a)?;
let (b, b_len) = this.project_to_simd(b)?;
Expand All @@ -303,6 +304,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let c = this.read_scalar(&this.project_index(&c, i)?)?;
let dest = this.project_index(&dest, i)?;

let fuse: bool = intrinsic_name == "fma" || this.machine.rng.get_mut().gen();

// Works for f32 and f64.
// FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
let ty::Float(float_ty) = dest.layout.ty.kind() else {
Expand All @@ -314,15 +317,23 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let a = a.to_f32()?;
let b = b.to_f32()?;
let c = c.to_f32()?;
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
let res = if fuse {
a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
} else {
((a * b).value + c).value
};
let res = this.adjust_nan(res, &[a, b, c]);
Scalar::from(res)
}
FloatTy::F64 => {
let a = a.to_f64()?;
let b = b.to_f64()?;
let c = c.to_f64()?;
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
let res = if fuse {
a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
} else {
((a * b).value + c).value
};
let res = this.adjust_nan(res, &[a, b, c]);
Scalar::from(res)
}
Expand Down
91 changes: 61 additions & 30 deletions tests/pass/intrinsics/fmuladd_nondeterministic.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,75 @@
#![feature(core_intrinsics)]
#![feature(core_intrinsics, portable_simd)]
use std::intrinsics::simd::simd_relaxed_fma;
use std::intrinsics::{fmuladdf32, fmuladdf64};
use std::simd::prelude::*;

fn main() {
let mut saw_zero = false;
let mut saw_nonzero = false;
fn ensure_both_happen(f: impl Fn() -> bool) -> bool {
let mut saw_true = false;
let mut saw_false = false;
for _ in 0..50 {
let a = std::hint::black_box(0.1_f64);
let b = std::hint::black_box(0.2);
let c = std::hint::black_box(-a * b);
// It is unspecified whether the following operation is fused or not. The
// following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused.
let x = unsafe { fmuladdf64(a, b, c) };
if x == 0.0 {
saw_zero = true;
let b = f();
if b {
saw_true = true;
} else {
saw_nonzero = true;
saw_false = true;
}
if saw_true && saw_false {
return true;
}
}
false
}

fn main() {
assert!(
saw_zero && saw_nonzero,
ensure_both_happen(|| {
let a = std::hint::black_box(0.1_f64);
let b = std::hint::black_box(0.2);
let c = std::hint::black_box(-a * b);
// It is unspecified whether the following operation is fused or not. The
// following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused.
let x = unsafe { fmuladdf64(a, b, c) };
x == 0.0
}),
"`fmuladdf64` failed to be evaluated as both fused and unfused"
);

let mut saw_zero = false;
let mut saw_nonzero = false;
for _ in 0..50 {
let a = std::hint::black_box(0.1_f32);
let b = std::hint::black_box(0.2);
let c = std::hint::black_box(-a * b);
// It is unspecified whether the following operation is fused or not. The
// following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused.
let x = unsafe { fmuladdf32(a, b, c) };
if x == 0.0 {
saw_zero = true;
} else {
saw_nonzero = true;
}
}
assert!(
saw_zero && saw_nonzero,
ensure_both_happen(|| {
let a = std::hint::black_box(0.1_f32);
let b = std::hint::black_box(0.2);
let c = std::hint::black_box(-a * b);
// It is unspecified whether the following operation is fused or not. The
// following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused.
let x = unsafe { fmuladdf32(a, b, c) };
x == 0.0
}),
"`fmuladdf32` failed to be evaluated as both fused and unfused"
);

assert!(
ensure_both_happen(|| {
let a = f32x4::splat(std::hint::black_box(0.1));
let b = f32x4::splat(std::hint::black_box(0.2));
let c = std::hint::black_box(-a * b);
let x = unsafe { simd_relaxed_fma(a, b, c) };
// Whether to fuse is decided independently for each element, so across iterations
// lanes 0 and 1 should sometimes be equal and sometimes differ.
x[0] == x[1]
}),
"`simd_relaxed_fma` failed to be evaluated as both fused and unfused"
);

assert!(
ensure_both_happen(|| {
let a = f64x4::splat(std::hint::black_box(0.1));
let b = f64x4::splat(std::hint::black_box(0.2));
let c = std::hint::black_box(-a * b);
let x = unsafe { simd_relaxed_fma(a, b, c) };
// Whether to fuse is decided independently for each element, so across iterations
// lanes 0 and 1 should sometimes be equal and sometimes differ.
x[0] == x[1]
}),
"`simd_relaxed_fma` failed to be evaluated as both fused and unfused"
);
}
22 changes: 22 additions & 0 deletions tests/pass/intrinsics/portable-simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ fn simd_ops_f32() {
f32x4::splat(-3.2).mul_add(b, f32x4::splat(f32::NEG_INFINITY)),
f32x4::splat(f32::NEG_INFINITY)
);

unsafe {
assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a);
assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a);
assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b);
assert_eq!(
intrinsics::simd_relaxed_fma(f32x4::splat(-3.2), b, f32x4::splat(f32::NEG_INFINITY)),
f32x4::splat(f32::NEG_INFINITY)
);
}

assert_eq!((a * a).sqrt(), a);
assert_eq!((b * b).sqrt(), b.abs());

Expand Down Expand Up @@ -94,6 +105,17 @@ fn simd_ops_f64() {
f64x4::splat(-3.2).mul_add(b, f64x4::splat(f64::NEG_INFINITY)),
f64x4::splat(f64::NEG_INFINITY)
);

unsafe {
assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a);
assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a);
assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b);
assert_eq!(
intrinsics::simd_relaxed_fma(f64x4::splat(-3.2), b, f64x4::splat(f64::NEG_INFINITY)),
f64x4::splat(f64::NEG_INFINITY)
);
}

assert_eq!((a * a).sqrt(), a);
assert_eq!((b * b).sqrt(), b.abs());

Expand Down

0 comments on commit 16ee60a

Please sign in to comment.