From bc96882ade18c83fcf3bc053cce411d6e8cfa603 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eduardo=20S=C3=A1nchez=20Mu=C3=B1oz?= Date: Sat, 18 Nov 2023 19:39:51 +0100 Subject: [PATCH] Implement all 16 AVX compare operators `_mm_cmp_{ss,ps,sd,pd}` functions are AVX functions that use `llvm.x86.sse{,2}` prefixed intrinsics, so they were "accidentally" partially implemented when SSE and SSE2 intrinsics were implemented. The 16 AVX compare operators are now implemented and tested. --- src/shims/x86/mod.rs | 124 ++++++++++++----------- src/shims/x86/sse.rs | 22 +++-- src/shims/x86/sse2.rs | 20 ++-- tests/pass/intrinsics-x86-avx.rs | 162 +++++++++++++++++++++++++++++++ 4 files changed, 261 insertions(+), 67 deletions(-) create mode 100644 tests/pass/intrinsics-x86-avx.rs diff --git a/src/shims/x86/mod.rs b/src/shims/x86/mod.rs index d88a3127ec..8992b8184d 100644 --- a/src/shims/x86/mod.rs +++ b/src/shims/x86/mod.rs @@ -119,53 +119,27 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: } } -/// Floating point comparison operation -/// -/// -/// -/// -/// -#[derive(Copy, Clone)] -enum FloatCmpOp { - Eq, - Lt, - Le, - Unord, - Neq, - /// Not less-than - Nlt, - /// Not less-or-equal - Nle, - /// Ordered, i.e. neither of them is NaN - Ord, -} - -impl FloatCmpOp { - /// Convert from the `imm` argument used to specify the comparison - /// operation in intrinsics such as `llvm.x86.sse.cmp.ss`. - fn from_intrinsic_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> { - match imm { - 0 => Ok(Self::Eq), - 1 => Ok(Self::Lt), - 2 => Ok(Self::Le), - 3 => Ok(Self::Unord), - 4 => Ok(Self::Neq), - 5 => Ok(Self::Nlt), - 6 => Ok(Self::Nle), - 7 => Ok(Self::Ord), - imm => { - throw_unsup_format!("invalid `imm` parameter of {intrinsic}: {imm}"); - } - } - } -} - #[derive(Copy, Clone)] enum FloatBinOp { /// Arithmetic operation Arith(mir::BinOp), /// Comparison - Cmp(FloatCmpOp), + /// AVX supports all 16 combinations, SSE only a subset + /// + /// + /// + /// + /// + Cmp { + /// Result when lhs < rhs + gt: bool, + /// Result when lhs > rhs + lt: bool, + /// Result when lhs == rhs + eq: bool, + /// Result when lhs is NaN or rhs is NaN + unord: bool, + }, /// Minimum value (with SSE semantics) /// /// @@ -182,6 +156,53 @@ enum FloatBinOp { Max, } +impl FloatBinOp { + /// Convert from the `imm` argument used to specify the comparison + /// operation in intrinsics such as `llvm.x86.sse.cmp.ss`. + fn cmp_from_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> { + // Bit 4 specifies whether the operation is quiet or signaling, which + // we do not care in Miri. + match imm & !0x10 { + // EQ_O + 0x0 => Ok(Self::Cmp { gt: false, lt: false, eq: true, unord: false }), + // LT_O + 0x1 => Ok(Self::Cmp { gt: false, lt: true, eq: false, unord: false }), + // LE_O + 0x2 => Ok(Self::Cmp { gt: false, lt: true, eq: true, unord: false }), + // UNORD + 0x3 => Ok(Self::Cmp { gt: false, lt: false, eq: false, unord: true }), + // NEQ_U + 0x4 => Ok(Self::Cmp { gt: true, lt: true, eq: false, unord: true }), + // NLT_U + 0x5 => Ok(Self::Cmp { gt: true, lt: false, eq: true, unord: true }), + // NLE_U + 0x6 => Ok(Self::Cmp { gt: true, lt: false, eq: false, unord: true }), + // ORD + 0x7 => Ok(Self::Cmp { gt: true, lt: true, eq: true, unord: false }), + // The following are only accessible through stdarch AVX functions. + // EQ_U + 0x8 => Ok(Self::Cmp { gt: false, lt: false, eq: true, unord: true }), + // NGE_U + 0x9 => Ok(Self::Cmp { gt: false, lt: true, eq: false, unord: true }), + // NGT_U + 0xA => Ok(Self::Cmp { gt: false, lt: true, eq: true, unord: true }), + // FALSE_O + 0xB => Ok(Self::Cmp { gt: false, lt: false, eq: false, unord: false }), + // NEQ_O + 0xC => Ok(Self::Cmp { gt: true, lt: true, eq: false, unord: false }), + // GE_O + 0xD => Ok(Self::Cmp { gt: true, lt: false, eq: true, unord: false }), + // GE_O + 0xE => Ok(Self::Cmp { gt: true, lt: false, eq: false, unord: false }), + // TRUE_U + 0xF => Ok(Self::Cmp { gt: true, lt: true, eq: true, unord: true }), + imm => { + throw_unsup_format!("invalid `imm` parameter of {intrinsic}: 0x{imm:x}"); + } + } + } +} + /// Performs `which` scalar operation on `left` and `right` and returns /// the result. fn bin_op_float<'tcx, F: rustc_apfloat::Float>( @@ -195,20 +216,15 @@ fn bin_op_float<'tcx, F: rustc_apfloat::Float>( let res = this.wrapping_binary_op(which, left, right)?; Ok(res.to_scalar()) } - FloatBinOp::Cmp(which) => { + FloatBinOp::Cmp { gt, lt, eq, unord } => { let left = left.to_scalar().to_float::()?; let right = right.to_scalar().to_float::()?; - // FIXME: Make sure that these operations match the semantics - // of cmpps/cmpss/cmppd/cmpsd - let res = match which { - FloatCmpOp::Eq => left == right, - FloatCmpOp::Lt => left < right, - FloatCmpOp::Le => left <= right, - FloatCmpOp::Unord => left.is_nan() || right.is_nan(), - FloatCmpOp::Neq => left != right, - FloatCmpOp::Nlt => !(left < right), - FloatCmpOp::Nle => !(left <= right), - FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(), + + let res = match left.partial_cmp(&right) { + None => unord, + Some(std::cmp::Ordering::Less) => lt, + Some(std::cmp::Ordering::Equal) => eq, + Some(std::cmp::Ordering::Greater) => gt, }; Ok(bool_to_simd_element(res, Size::from_bits(F::BITS))) } diff --git a/src/shims/x86/sse.rs b/src/shims/x86/sse.rs index 831228b7a2..e15023c3c2 100644 --- a/src/shims/x86/sse.rs +++ b/src/shims/x86/sse.rs @@ -5,7 +5,7 @@ use rustc_target::spec::abi::Abi; use rand::Rng as _; -use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp}; +use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp}; use crate::*; use shims::foreign_items::EmulateForeignItemResult; @@ -95,33 +95,41 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: unary_op_ps(this, which, op, dest)?; } - // Used to implement the _mm_cmp_ss function. + // Used to implement the _mm_cmp*_ss functions. // Performs a comparison operation on the first component of `left` // and `right`, returning 0 if false or `u32::MAX` if true. The remaining // components are copied from `left`. + // _mm_cmp_ss is actually an AVX function where the operation is specified + // by a const parameter. + // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ss are SSE functions + // with hard-coded operations. "cmp.ss" => { let [left, right, imm] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; - let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm( + let which = FloatBinOp::cmp_from_imm( this.read_scalar(imm)?.to_i8()?, "llvm.x86.sse.cmp.ss", - )?); + )?; bin_op_simd_float_first::(this, which, left, right, dest)?; } - // Used to implement the _mm_cmp_ps function. + // Used to implement the _mm_cmp*_ps functions. // Performs a comparison operation on each component of `left` // and `right`. For each component, returns 0 if false or u32::MAX // if true. + // _mm_cmp_ps is actually an AVX function where the operation is specified + // by a const parameter. + // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ps are SSE functions + // with hard-coded operations. "cmp.ps" => { let [left, right, imm] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; - let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm( + let which = FloatBinOp::cmp_from_imm( this.read_scalar(imm)?.to_i8()?, "llvm.x86.sse.cmp.ps", - )?); + )?; bin_op_simd_float_all::(this, which, left, right, dest)?; } diff --git a/src/shims/x86/sse2.rs b/src/shims/x86/sse2.rs index 3f2b9f5f0a..55520771cf 100644 --- a/src/shims/x86/sse2.rs +++ b/src/shims/x86/sse2.rs @@ -4,7 +4,7 @@ use rustc_middle::ty::Ty; use rustc_span::Symbol; use rustc_target::spec::abi::Abi; -use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp}; +use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp}; use crate::*; use shims::foreign_items::EmulateForeignItemResult; @@ -461,18 +461,22 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: this.write_scalar(res, &dest)?; } } - // Used to implement the _mm_cmp*_sd function. + // Used to implement the _mm_cmp*_sd functions. // Performs a comparison operation on the first component of `left` // and `right`, returning 0 if false or `u64::MAX` if true. The remaining // components are copied from `left`. + // _mm_cmp_sd is actually an AVX function where the operation is specified + // by a const parameter. + // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_sd are SSE2 functions + // with hard-coded operations. "cmp.sd" => { let [left, right, imm] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; - let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm( + let which = FloatBinOp::cmp_from_imm( this.read_scalar(imm)?.to_i8()?, "llvm.x86.sse2.cmp.sd", - )?); + )?; bin_op_simd_float_first::(this, which, left, right, dest)?; } @@ -480,14 +484,18 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: // Performs a comparison operation on each component of `left` // and `right`. For each component, returns 0 if false or `u64::MAX` // if true. + // _mm_cmp_pd is actually an AVX function where the operation is specified + // by a const parameter. + // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_pd are SSE2 functions + // with hard-coded operations. "cmp.pd" => { let [left, right, imm] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; - let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm( + let which = FloatBinOp::cmp_from_imm( this.read_scalar(imm)?.to_i8()?, "llvm.x86.sse2.cmp.pd", - )?); + )?; bin_op_simd_float_all::(this, which, left, right, dest)?; } diff --git a/tests/pass/intrinsics-x86-avx.rs b/tests/pass/intrinsics-x86-avx.rs new file mode 100644 index 0000000000..d5265cf217 --- /dev/null +++ b/tests/pass/intrinsics-x86-avx.rs @@ -0,0 +1,162 @@ +// Ignore everything except x86 and x86_64 +// Any additional target are added to CI should be ignored here +// (We cannot use `cfg`-based tricks here since the `target-feature` flags below only work on x86.) +//@ignore-target-aarch64 +//@ignore-target-arm +//@ignore-target-avr +//@ignore-target-s390x +//@ignore-target-thumbv7em +//@ignore-target-wasm32 +//@compile-flags: -C target-feature=+avx + +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; +use std::mem::transmute; + +fn main() { + assert!(is_x86_feature_detected!("avx")); + + unsafe { + test_avx(); + } +} + +#[target_feature(enable = "avx")] +unsafe fn test_avx() { + fn expected_cmp(imm: i32, lhs: F, rhs: F, if_t: F, if_f: F) -> F { + let res = match imm { + _CMP_EQ_OQ => lhs == rhs, + _CMP_LT_OS => lhs < rhs, + _CMP_LE_OS => lhs <= rhs, + _CMP_UNORD_Q => lhs.partial_cmp(&rhs).is_none(), + _CMP_NEQ_UQ => lhs != rhs, + _CMP_NLT_UQ => !(lhs < rhs), + _CMP_NLE_UQ => !(lhs <= rhs), + _CMP_ORD_Q => lhs.partial_cmp(&rhs).is_some(), + _CMP_EQ_UQ => lhs == rhs || lhs.partial_cmp(&rhs).is_none(), + _CMP_NGE_US => !(lhs >= rhs), + _CMP_NGT_US => !(lhs > rhs), + _CMP_FALSE_OQ => false, + _CMP_NEQ_OQ => lhs != rhs && lhs.partial_cmp(&rhs).is_some(), + _CMP_GE_OS => lhs >= rhs, + _CMP_GT_OS => lhs > rhs, + _CMP_TRUE_US => true, + _ => unreachable!(), + }; + if res { if_t } else { if_f } + } + fn expected_cmp_f32(imm: i32, lhs: f32, rhs: f32) -> f32 { + expected_cmp(imm, lhs, rhs, f32::from_bits(u32::MAX), 0.0) + } + fn expected_cmp_f64(imm: i32, lhs: f64, rhs: f64) -> f64 { + expected_cmp(imm, lhs, rhs, f64::from_bits(u64::MAX), 0.0) + } + + #[target_feature(enable = "avx")] + unsafe fn test_mm_cmp_ss() { + let values = [ + (1.0, 1.0), + (0.0, 1.0), + (1.0, 0.0), + (f32::NAN, 0.0), + (0.0, f32::NAN), + (f32::NAN, f32::NAN), + ]; + + for (lhs, rhs) in values { + let a = _mm_setr_ps(lhs, 2.0, 3.0, 4.0); + let b = _mm_setr_ps(rhs, 5.0, 6.0, 7.0); + let r: [u32; 4] = transmute(_mm_cmp_ss::(a, b)); + let e: [u32; 4] = + transmute(_mm_setr_ps(expected_cmp_f32(IMM, lhs, rhs), 2.0, 3.0, 4.0)); + assert_eq!(r, e); + } + } + + #[target_feature(enable = "avx")] + unsafe fn test_mm_cmp_ps() { + let values = [ + (1.0, 1.0), + (0.0, 1.0), + (1.0, 0.0), + (f32::NAN, 0.0), + (0.0, f32::NAN), + (f32::NAN, f32::NAN), + ]; + + for (lhs, rhs) in values { + let a = _mm_set1_ps(lhs); + let b = _mm_set1_ps(rhs); + let r: [u32; 4] = transmute(_mm_cmp_ps::(a, b)); + let e: [u32; 4] = transmute(_mm_set1_ps(expected_cmp_f32(IMM, lhs, rhs))); + assert_eq!(r, e); + } + } + + #[target_feature(enable = "avx")] + unsafe fn test_mm_cmp_sd() { + let values = [ + (1.0, 1.0), + (0.0, 1.0), + (1.0, 0.0), + (f64::NAN, 0.0), + (0.0, f64::NAN), + (f64::NAN, f64::NAN), + ]; + + for (lhs, rhs) in values { + let a = _mm_setr_pd(lhs, 2.0); + let b = _mm_setr_pd(rhs, 3.0); + let r: [u64; 2] = transmute(_mm_cmp_sd::(a, b)); + let e: [u64; 2] = transmute(_mm_setr_pd(expected_cmp_f64(IMM, lhs, rhs), 2.0)); + assert_eq!(r, e); + } + } + + #[target_feature(enable = "avx")] + unsafe fn test_mm_cmp_pd() { + let values = [ + (1.0, 1.0), + (0.0, 1.0), + (1.0, 0.0), + (f64::NAN, 0.0), + (0.0, f64::NAN), + (f64::NAN, f64::NAN), + ]; + + for (lhs, rhs) in values { + let a = _mm_set1_pd(lhs); + let b = _mm_set1_pd(rhs); + let r: [u64; 2] = transmute(_mm_cmp_pd::(a, b)); + let e: [u64; 2] = transmute(_mm_set1_pd(expected_cmp_f64(IMM, lhs, rhs))); + assert_eq!(r, e); + } + } + + #[target_feature(enable = "avx")] + unsafe fn test_cmp() { + test_mm_cmp_ss::(); + test_mm_cmp_ps::(); + test_mm_cmp_sd::(); + test_mm_cmp_pd::(); + } + + test_cmp::<_CMP_EQ_OQ>(); + test_cmp::<_CMP_LT_OS>(); + test_cmp::<_CMP_LE_OS>(); + test_cmp::<_CMP_UNORD_Q>(); + test_cmp::<_CMP_NEQ_UQ>(); + test_cmp::<_CMP_NLT_UQ>(); + test_cmp::<_CMP_NLE_UQ>(); + test_cmp::<_CMP_ORD_Q>(); + test_cmp::<_CMP_EQ_UQ>(); + test_cmp::<_CMP_NGE_US>(); + test_cmp::<_CMP_NGT_US>(); + test_cmp::<_CMP_FALSE_OQ>(); + test_cmp::<_CMP_NEQ_OQ>(); + test_cmp::<_CMP_GE_OS>(); + test_cmp::<_CMP_GT_OS>(); + test_cmp::<_CMP_TRUE_US>(); +}