Skip to content

Commit

Permalink
Auto merge of #3176 - eduardosm:cmp, r=RalfJung
Browse files Browse the repository at this point in the history
Implement all 16 AVX compare operators for 128-bit SIMD vectors

`_mm_cmp_{ss,ps,sd,pd}` functions are AVX functions that use `llvm.x86.sse{,2}.` prefixed intrinsics, so they were "accidentally" partially implemented when SSE and SSE2 intrinsics were implemented.

The 16 AVX compare operators are now implemented and tested.
  • Loading branch information
bors committed Nov 20, 2023
2 parents c71353d + 477e2fc commit d251208
Show file tree
Hide file tree
Showing 8 changed files with 261 additions and 71 deletions.
120 changes: 66 additions & 54 deletions src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,53 +119,32 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
}
}

/// Floating point comparison operation
///
/// <https://www.felixcloutier.com/x86/cmpss>
/// <https://www.felixcloutier.com/x86/cmpps>
/// <https://www.felixcloutier.com/x86/cmpsd>
/// <https://www.felixcloutier.com/x86/cmppd>
#[derive(Copy, Clone)]
enum FloatCmpOp {
Eq,
Lt,
Le,
Unord,
Neq,
/// Not less-than
Nlt,
/// Not less-or-equal
Nle,
/// Ordered, i.e. neither of them is NaN
Ord,
}

impl FloatCmpOp {
/// Convert from the `imm` argument used to specify the comparison
/// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
fn from_intrinsic_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> {
match imm {
0 => Ok(Self::Eq),
1 => Ok(Self::Lt),
2 => Ok(Self::Le),
3 => Ok(Self::Unord),
4 => Ok(Self::Neq),
5 => Ok(Self::Nlt),
6 => Ok(Self::Nle),
7 => Ok(Self::Ord),
imm => {
throw_unsup_format!("invalid `imm` parameter of {intrinsic}: {imm}");
}
}
}
}

#[derive(Copy, Clone)]
enum FloatBinOp {
/// Arithmetic operation
Arith(mir::BinOp),
/// Comparison
Cmp(FloatCmpOp),
///
/// The semantics of this operator is a case distinction: we compare the two operands,
/// and then we return one of the four booleans `gt`, `lt`, `eq`, `unord` depending on
/// which class they fall into.
///
/// AVX supports all 16 combinations, SSE only a subset
///
/// <https://www.felixcloutier.com/x86/cmpss>
/// <https://www.felixcloutier.com/x86/cmpps>
/// <https://www.felixcloutier.com/x86/cmpsd>
/// <https://www.felixcloutier.com/x86/cmppd>
Cmp {
/// Result when lhs < rhs
gt: bool,
/// Result when lhs > rhs
lt: bool,
/// Result when lhs == rhs
eq: bool,
/// Result when lhs is NaN or rhs is NaN
unord: bool,
},
/// Minimum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/minss>
Expand All @@ -182,6 +161,44 @@ enum FloatBinOp {
Max,
}

impl FloatBinOp {
/// Convert from the `imm` argument used to specify the comparison
/// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
fn cmp_from_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> {
// Only bits 0..=4 are used, remaining should be zero.
if imm & !0b1_1111 != 0 {
throw_unsup_format!("invalid `imm` parameter of {intrinsic}: 0x{imm:x}");
}
// Bit 4 specifies whether the operation is quiet or signaling, which
// we do not care in Miri.
// Bits 0..=2 specifies the operation.
// `gt` indicates the result to be returned when the LHS is strictly
// greater than the RHS, and so on.
let (gt, lt, eq, unord) = match imm & 0b111 {
// Equal
0x0 => (false, false, true, false),
// Less-than
0x1 => (false, true, false, false),
// Less-or-equal
0x2 => (false, true, true, false),
// Unordered (either is NaN)
0x3 => (false, false, false, true),
// Not equal
0x4 => (true, true, false, true),
// Not less-than
0x5 => (true, false, true, true),
// Not less-or-equal
0x6 => (true, false, false, true),
// Ordered (neither is NaN)
0x7 => (true, true, true, false),
_ => unreachable!(),
};
// When bit 3 is 1 (only possible in AVX), unord is toggled.
let unord = unord ^ (imm & 0b1000 != 0);
Ok(Self::Cmp { gt, lt, eq, unord })
}
}

/// Performs `which` scalar operation on `left` and `right` and returns
/// the result.
fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
Expand All @@ -195,20 +212,15 @@ fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
let res = this.wrapping_binary_op(which, left, right)?;
Ok(res.to_scalar())
}
FloatBinOp::Cmp(which) => {
FloatBinOp::Cmp { gt, lt, eq, unord } => {
let left = left.to_scalar().to_float::<F>()?;
let right = right.to_scalar().to_float::<F>()?;
// FIXME: Make sure that these operations match the semantics
// of cmpps/cmpss/cmppd/cmpsd
let res = match which {
FloatCmpOp::Eq => left == right,
FloatCmpOp::Lt => left < right,
FloatCmpOp::Le => left <= right,
FloatCmpOp::Unord => left.is_nan() || right.is_nan(),
FloatCmpOp::Neq => left != right,
FloatCmpOp::Nlt => !(left < right),
FloatCmpOp::Nle => !(left <= right),
FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(),

let res = match left.partial_cmp(&right) {
None => unord,
Some(std::cmp::Ordering::Less) => lt,
Some(std::cmp::Ordering::Equal) => eq,
Some(std::cmp::Ordering::Greater) => gt,
};
Ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
}
Expand Down
22 changes: 15 additions & 7 deletions src/shims/x86/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use rustc_target::spec::abi::Abi;

use rand::Rng as _;

use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp};
use crate::*;
use shims::foreign_items::EmulateForeignItemResult;

Expand Down Expand Up @@ -95,33 +95,41 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:

unary_op_ps(this, which, op, dest)?;
}
// Used to implement the _mm_cmp_ss function.
// Used to implement the _mm_cmp*_ss functions.
// Performs a comparison operation on the first component of `left`
// and `right`, returning 0 if false or `u32::MAX` if true. The remaining
// components are copied from `left`.
// _mm_cmp_ss is actually an AVX function where the operation is specified
// by a const parameter.
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ss are SSE functions
// with hard-coded operations.
"cmp.ss" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
let which = FloatBinOp::cmp_from_imm(
this.read_scalar(imm)?.to_i8()?,
"llvm.x86.sse.cmp.ss",
)?);
)?;

bin_op_simd_float_first::<Single>(this, which, left, right, dest)?;
}
// Used to implement the _mm_cmp_ps function.
// Used to implement the _mm_cmp*_ps functions.
// Performs a comparison operation on each component of `left`
// and `right`. For each component, returns 0 if false or u32::MAX
// if true.
// _mm_cmp_ps is actually an AVX function where the operation is specified
// by a const parameter.
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ps are SSE functions
// with hard-coded operations.
"cmp.ps" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
let which = FloatBinOp::cmp_from_imm(
this.read_scalar(imm)?.to_i8()?,
"llvm.x86.sse.cmp.ps",
)?);
)?;

bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
}
Expand Down
20 changes: 14 additions & 6 deletions src/shims/x86/sse2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rustc_middle::ty::Ty;
use rustc_span::Symbol;
use rustc_target::spec::abi::Abi;

use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp};
use crate::*;
use shims::foreign_items::EmulateForeignItemResult;

Expand Down Expand Up @@ -461,33 +461,41 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
this.write_scalar(res, &dest)?;
}
}
// Used to implement the _mm_cmp*_sd function.
// Used to implement the _mm_cmp*_sd functions.
// Performs a comparison operation on the first component of `left`
// and `right`, returning 0 if false or `u64::MAX` if true. The remaining
// components are copied from `left`.
// _mm_cmp_sd is actually an AVX function where the operation is specified
// by a const parameter.
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_sd are SSE2 functions
// with hard-coded operations.
"cmp.sd" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
let which = FloatBinOp::cmp_from_imm(
this.read_scalar(imm)?.to_i8()?,
"llvm.x86.sse2.cmp.sd",
)?);
)?;

bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
}
// Used to implement the _mm_cmp*_pd functions.
// Performs a comparison operation on each component of `left`
// and `right`. For each component, returns 0 if false or `u64::MAX`
// if true.
// _mm_cmp_pd is actually an AVX function where the operation is specified
// by a const parameter.
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_pd are SSE2 functions
// with hard-coded operations.
"cmp.pd" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
let which = FloatBinOp::cmp_from_imm(
this.read_scalar(imm)?.to_i8()?,
"llvm.x86.sse2.cmp.pd",
)?);
)?;

bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
}
Expand Down
2 changes: 1 addition & 1 deletion tests/pass/intrinsics-x86-aes-vaes.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Ignore everything except x86 and x86_64
// Any additional target are added to CI should be ignored here
// Any new targets that are added to CI should be ignored here.
// (We cannot use `cfg`-based tricks here since the `target-feature` flags below only work on x86.)
//@ignore-target-aarch64
//@ignore-target-arm
Expand Down
Loading

0 comments on commit d251208

Please sign in to comment.