Skip to content

Commit

Permalink
Implement all 16 AVX compare operators
Browse files Browse the repository at this point in the history
`_mm_cmp_{ss,ps,sd,pd}` functions are AVX functions that use `llvm.x86.sse{,2}` prefixed intrinsics, so they were "accidentally" partially implemented when SSE and SSE2 intrinsics were implemented.

The 16 AVX compare operators are now implemented and tested.
  • Loading branch information
eduardosm committed Nov 18, 2023
1 parent 02baccc commit bc96882
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 67 deletions.
124 changes: 70 additions & 54 deletions src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,53 +119,27 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
}
}

/// Floating point comparison operation
///
/// Each variant is one of the eight comparison predicates that can be
/// selected through the `imm` operand of the SSE/SSE2 compare intrinsics.
///
/// <https://www.felixcloutier.com/x86/cmpss>
/// <https://www.felixcloutier.com/x86/cmpps>
/// <https://www.felixcloutier.com/x86/cmpsd>
/// <https://www.felixcloutier.com/x86/cmppd>
#[derive(Copy, Clone)]
enum FloatCmpOp {
/// Equal (`imm` 0)
Eq,
/// Less-than (`imm` 1)
Lt,
/// Less-or-equal (`imm` 2)
Le,
/// Unordered, i.e. at least one of them is NaN (`imm` 3)
Unord,
/// Not equal (`imm` 4)
Neq,
/// Not less-than (`imm` 5)
Nlt,
/// Not less-or-equal (`imm` 6)
Nle,
/// Ordered, i.e. neither of them is NaN (`imm` 7)
Ord,
}

impl FloatCmpOp {
    /// Decode the comparison predicate carried by the `imm` operand of
    /// intrinsics such as `llvm.x86.sse.cmp.ss`.
    ///
    /// Values `0..=7` each select one predicate. Any other value is reported
    /// as unsupported, with `intrinsic` named in the error message.
    fn from_intrinsic_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> {
        let op = match imm {
            0 => Self::Eq,
            1 => Self::Lt,
            2 => Self::Le,
            3 => Self::Unord,
            4 => Self::Neq,
            5 => Self::Nlt,
            6 => Self::Nle,
            7 => Self::Ord,
            _ => {
                throw_unsup_format!("invalid `imm` parameter of {intrinsic}: {imm}");
            }
        };
        Ok(op)
    }
}

#[derive(Copy, Clone)]
enum FloatBinOp {
/// Arithmetic operation
Arith(mir::BinOp),
/// Comparison
Cmp(FloatCmpOp),
/// AVX supports all 16 combinations, SSE only a subset
///
/// <https://www.felixcloutier.com/x86/cmpss>
/// <https://www.felixcloutier.com/x86/cmpps>
/// <https://www.felixcloutier.com/x86/cmpsd>
/// <https://www.felixcloutier.com/x86/cmppd>
Cmp {
/// Result when lhs > rhs
gt: bool,
/// Result when lhs < rhs
lt: bool,
/// Result when lhs == rhs
eq: bool,
/// Result when lhs is NaN or rhs is NaN
unord: bool,
},
/// Minimum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/minss>
Expand All @@ -182,6 +156,53 @@ enum FloatBinOp {
Max,
}

impl FloatBinOp {
    /// Convert from the `imm` argument used to specify the comparison
    /// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
    ///
    /// The returned `Cmp` value stores, for each possible relation between
    /// the operands (greater, less, equal, unordered), the boolean result the
    /// selected predicate produces.
    ///
    /// `intrinsic` is only used to name the intrinsic in the error message.
    ///
    /// Fails with an "unsupported" error when `imm`, ignoring bit 4, is not
    /// one of the 16 predicates `0x0..=0xF`.
    fn cmp_from_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> {
        // Bit 4 specifies whether the operation is quiet or signaling, which
        // we do not care about in Miri.
        let op = match imm & !0x10 {
            // EQ_O
            0x0 => Self::Cmp { gt: false, lt: false, eq: true, unord: false },
            // LT_O
            0x1 => Self::Cmp { gt: false, lt: true, eq: false, unord: false },
            // LE_O
            0x2 => Self::Cmp { gt: false, lt: true, eq: true, unord: false },
            // UNORD
            0x3 => Self::Cmp { gt: false, lt: false, eq: false, unord: true },
            // NEQ_U
            0x4 => Self::Cmp { gt: true, lt: true, eq: false, unord: true },
            // NLT_U
            0x5 => Self::Cmp { gt: true, lt: false, eq: true, unord: true },
            // NLE_U
            0x6 => Self::Cmp { gt: true, lt: false, eq: false, unord: true },
            // ORD
            0x7 => Self::Cmp { gt: true, lt: true, eq: true, unord: false },
            // The following are only accessible through stdarch AVX functions.
            // EQ_U
            0x8 => Self::Cmp { gt: false, lt: false, eq: true, unord: true },
            // NGE_U
            0x9 => Self::Cmp { gt: false, lt: true, eq: false, unord: true },
            // NGT_U
            0xA => Self::Cmp { gt: false, lt: true, eq: true, unord: true },
            // FALSE_O
            0xB => Self::Cmp { gt: false, lt: false, eq: false, unord: false },
            // NEQ_O
            0xC => Self::Cmp { gt: true, lt: true, eq: false, unord: false },
            // GE_O
            0xD => Self::Cmp { gt: true, lt: false, eq: true, unord: false },
            // GT_O
            0xE => Self::Cmp { gt: true, lt: false, eq: false, unord: false },
            // TRUE_U
            0xF => Self::Cmp { gt: true, lt: true, eq: true, unord: true },
            // Deliberately do not bind the matched value: the message must
            // show the `imm` the caller passed, not the one with bit 4 masked
            // off.
            _ => {
                throw_unsup_format!("invalid `imm` parameter of {intrinsic}: 0x{imm:x}");
            }
        };
        Ok(op)
    }
}

/// Performs `which` scalar operation on `left` and `right` and returns
/// the result.
fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
Expand All @@ -195,20 +216,15 @@ fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
let res = this.wrapping_binary_op(which, left, right)?;
Ok(res.to_scalar())
}
FloatBinOp::Cmp(which) => {
FloatBinOp::Cmp { gt, lt, eq, unord } => {
let left = left.to_scalar().to_float::<F>()?;
let right = right.to_scalar().to_float::<F>()?;
// FIXME: Make sure that these operations match the semantics
// of cmpps/cmpss/cmppd/cmpsd
let res = match which {
FloatCmpOp::Eq => left == right,
FloatCmpOp::Lt => left < right,
FloatCmpOp::Le => left <= right,
FloatCmpOp::Unord => left.is_nan() || right.is_nan(),
FloatCmpOp::Neq => left != right,
FloatCmpOp::Nlt => !(left < right),
FloatCmpOp::Nle => !(left <= right),
FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(),

let res = match left.partial_cmp(&right) {
None => unord,
Some(std::cmp::Ordering::Less) => lt,
Some(std::cmp::Ordering::Equal) => eq,
Some(std::cmp::Ordering::Greater) => gt,
};
Ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
}
Expand Down
22 changes: 15 additions & 7 deletions src/shims/x86/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use rustc_target::spec::abi::Abi;

use rand::Rng as _;

use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp};
use crate::*;
use shims::foreign_items::EmulateForeignItemResult;

Expand Down Expand Up @@ -95,33 +95,41 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:

unary_op_ps(this, which, op, dest)?;
}
// Used to implement the _mm_cmp_ss function.
// Used to implement the _mm_cmp*_ss functions.
// Performs a comparison operation on the first component of `left`
// and `right`, returning 0 if false or `u32::MAX` if true. The remaining
// components are copied from `left`.
// _mm_cmp_ss is actually an AVX function where the operation is specified
// by a const parameter.
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ss are SSE functions
// with hard-coded operations.
"cmp.ss" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
let which = FloatBinOp::cmp_from_imm(
this.read_scalar(imm)?.to_i8()?,
"llvm.x86.sse.cmp.ss",
)?);
)?;

bin_op_simd_float_first::<Single>(this, which, left, right, dest)?;
}
// Used to implement the _mm_cmp_ps function.
// Used to implement the _mm_cmp*_ps functions.
// Performs a comparison operation on each component of `left`
// and `right`. For each component, returns 0 if false or u32::MAX
// if true.
// _mm_cmp_ps is actually an AVX function where the operation is specified
// by a const parameter.
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ps are SSE functions
// with hard-coded operations.
"cmp.ps" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
let which = FloatBinOp::cmp_from_imm(
this.read_scalar(imm)?.to_i8()?,
"llvm.x86.sse.cmp.ps",
)?);
)?;

bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
}
Expand Down
20 changes: 14 additions & 6 deletions src/shims/x86/sse2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rustc_middle::ty::Ty;
use rustc_span::Symbol;
use rustc_target::spec::abi::Abi;

use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp};
use crate::*;
use shims::foreign_items::EmulateForeignItemResult;

Expand Down Expand Up @@ -461,33 +461,41 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
this.write_scalar(res, &dest)?;
}
}
// Used to implement the _mm_cmp*_sd function.
// Used to implement the _mm_cmp*_sd functions.
// Performs a comparison operation on the first component of `left`
// and `right`, returning 0 if false or `u64::MAX` if true. The remaining
// components are copied from `left`.
// _mm_cmp_sd is actually an AVX function where the operation is specified
// by a const parameter.
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_sd are SSE2 functions
// with hard-coded operations.
"cmp.sd" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
let which = FloatBinOp::cmp_from_imm(
this.read_scalar(imm)?.to_i8()?,
"llvm.x86.sse2.cmp.sd",
)?);
)?;

bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
}
// Used to implement the _mm_cmp*_pd functions.
// Performs a comparison operation on each component of `left`
// and `right`. For each component, returns 0 if false or `u64::MAX`
// if true.
// _mm_cmp_pd is actually an AVX function where the operation is specified
// by a const parameter.
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_pd are SSE2 functions
// with hard-coded operations.
"cmp.pd" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
let which = FloatBinOp::cmp_from_imm(
this.read_scalar(imm)?.to_i8()?,
"llvm.x86.sse2.cmp.pd",
)?);
)?;

bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
}
Expand Down
Loading

0 comments on commit bc96882

Please sign in to comment.