From f4e6cc88889c87a7774e0b6cbb486b5a5ba12ed3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eduardo=20S=C3=A1nchez=20Mu=C3=B1oz?= Date: Tue, 10 Oct 2023 20:38:33 +0200 Subject: [PATCH] Implement `llvm.x86.sse41.*` intrinsics --- src/shims/x86/mod.rs | 6 + src/shims/x86/sse41.rs | 312 +++++++++++++++++++++++++++++ tests/pass/intrinsics-x86-sse41.rs | 252 +++++++++++++++++++++++ 3 files changed, 570 insertions(+) create mode 100644 src/shims/x86/sse41.rs create mode 100644 tests/pass/intrinsics-x86-sse41.rs diff --git a/src/shims/x86/mod.rs b/src/shims/x86/mod.rs index 394c955e4c..d88a3127ec 100644 --- a/src/shims/x86/mod.rs +++ b/src/shims/x86/mod.rs @@ -11,6 +11,7 @@ mod aesni; mod sse; mod sse2; mod sse3; +mod sse41; mod ssse3; impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {} @@ -101,6 +102,11 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: this, link_name, abi, args, dest, ); } + name if name.starts_with("sse41.") => { + return sse41::EvalContextExt::emulate_x86_sse41_intrinsic( + this, link_name, abi, args, dest, + ); + } name if name.starts_with("aesni.") => { return aesni::EvalContextExt::emulate_x86_aesni_intrinsic( this, link_name, abi, args, dest, diff --git a/src/shims/x86/sse41.rs b/src/shims/x86/sse41.rs new file mode 100644 index 0000000000..f91247edb3 --- /dev/null +++ b/src/shims/x86/sse41.rs @@ -0,0 +1,312 @@ +use rustc_apfloat::Float as _; +use rustc_middle::mir; +use rustc_span::Symbol; +use rustc_target::spec::abi::Abi; + +use crate::*; +use shims::foreign_items::EmulateForeignItemResult; + +impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {} +pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: + crate::MiriInterpCxExt<'mir, 'tcx> +{ + fn emulate_x86_sse41_intrinsic( + &mut self, + link_name: Symbol, + abi: Abi, + args: &[OpTy<'tcx, Provenance>], + dest: &PlaceTy<'tcx, Provenance>, + ) -> InterpResult<'tcx, EmulateForeignItemResult> { + let this = self.eval_context_mut(); + // Prefix should have already been checked. + let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.sse41.").unwrap(); + + match unprefixed_name { + // Used to implement the _mm_insert_ps function. + // Takes one element of `right` and inserts it into `left` and + // optionally zero some elements. Source index, destination index + // and zeroed indices are specified by `imm`. + "insertps" => { + let [left, right, imm] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (left, left_len) = this.operand_to_simd(left)?; + let (right, right_len) = this.operand_to_simd(right)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + assert_eq!(dest_len, left_len); + assert_eq!(dest_len, right_len); + + let imm = this.read_scalar(imm)?.to_u8()?; + let src_index = u64::from((imm >> 6) & 0b11); + let dst_index = u64::from((imm >> 4) & 0b11); + + for i in 0..dest_len { + let dest = this.project_index(&dest, i)?; + + if imm & (1 << i) != 0 { + // zeroed + this.write_scalar(Scalar::from_u32(0), &dest)?; + } else if i == dst_index { + // copy from `right` + this.copy_op( + &this.project_index(&right, src_index)?, + &dest, + /*allow_transmute*/ false, + )?; + } else { + // copy from `left` + this.copy_op( + &this.project_index(&left, i)?, + &dest, + /*allow_transmute*/ false, + )?; + } + } + } + // Used to implement the _mm_packus_epi32 function. + // Converts two 32-bit signed integer vectors to a single 16-bit + // unsigned integer vector with saturation. + "packusdw" => { + let [left, right] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (left, left_len) = this.operand_to_simd(left)?; + let (right, right_len) = this.operand_to_simd(right)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + assert_eq!(left_len, right_len); + assert_eq!(dest_len, left_len.checked_mul(2).unwrap()); + + for i in 0..left_len { + let left = this.read_scalar(&this.project_index(&left, i)?)?.to_i32()?; + let right = this.read_scalar(&this.project_index(&right, i)?)?.to_i32()?; + let left_dest = this.project_index(&dest, i)?; + let right_dest = this.project_index(&dest, i.checked_add(left_len).unwrap())?; + + let left_res = + u16::try_from(left).unwrap_or(if left < 0 { 0 } else { u16::MAX }); + let right_res = + u16::try_from(right).unwrap_or(if right < 0 { 0 } else { u16::MAX }); + + this.write_scalar(Scalar::from_u16(left_res), &left_dest)?; + this.write_scalar(Scalar::from_u16(right_res), &right_dest)?; + } + } + // Used to implement the _mm_dp_ps and _mm_dp_pd functions. + // Conditionally multiplies the packed floating-point elements in + // `left` and `right` using the high 4 bits in `imm`, sums the four + // products, and conditionally stores the sum in `dest` using the low + // 4 bits of `imm`. + "dpps" | "dppd" => { + let [left, right, imm] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (left, left_len) = this.operand_to_simd(left)?; + let (right, right_len) = this.operand_to_simd(right)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + assert_eq!(left_len, right_len); + + let imm = this.read_scalar(imm)?.to_u8()?; + + let element_layout = left.layout.field(this, 0); + + // Calculate dot product + // Elements are floating point number, but we can use `from_int` + // because the representation of 0.0 is all zero bits. + let mut sum = ImmTy::from_int(0u8, element_layout); + for i in 0..left_len { + if imm & (1 << i.checked_add(4).unwrap()) != 0 { + let left = this.read_immediate(&this.project_index(&left, i)?)?; + let right = this.read_immediate(&this.project_index(&right, i)?)?; + + let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?; + sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?; + } + } + + // Write to destination (conditioned to imm) + for i in 0..dest_len { + let dest = this.project_index(&dest, i)?; + + if imm & (1 << i) != 0 { + this.write_immediate(*sum, &dest)?; + } else { + this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?; + } + } + } + // Used to implement the _mm_floor_ss, _mm_ceil_ss and _mm_round_ss + // functions. Rounds the first element of `right` according to `rounding` + // and copies the remaining from `left`. + "round.ss" => { + let [left, right, rounding] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (left, left_len) = this.operand_to_simd(left)?; + let (right, right_len) = this.operand_to_simd(right)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + assert_eq!(dest_len, left_len); + assert_eq!(dest_len, right_len); + + let rounding = match this.read_scalar(rounding)?.to_i32()? & !0x80 { + 0x00 => rustc_apfloat::Round::NearestTiesToEven, + 0x01 => rustc_apfloat::Round::TowardNegative, + 0x02 => rustc_apfloat::Round::TowardPositive, + 0x03 => rustc_apfloat::Round::TowardZero, + rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"), + }; + + let op0 = this.read_scalar(&this.project_index(&right, 0)?)?.to_f32()?; + let res = op0.round_to_integral(rounding).value; + this.write_scalar(Scalar::from_f32(res), &this.project_index(&dest, 0)?)?; + + for i in 1..dest_len { + this.copy_op( + &this.project_index(&left, i)?, + &this.project_index(&dest, i)?, + /*allow_transmute*/ false, + )?; + } + } + // Used to implement the _mm_floor_sd, _mm_ceil_sd and _mm_round_sd + // functions. Rounds the first element of `right` according to `rounding` + // and copies the remaining from `left`. + "round.sd" => { + let [left, right, rounding] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (left, left_len) = this.operand_to_simd(left)?; + let (right, right_len) = this.operand_to_simd(right)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + assert_eq!(dest_len, left_len); + assert_eq!(dest_len, right_len); + + let rounding = match this.read_scalar(rounding)?.to_i32()? { + 0x00 | 0x80 => rustc_apfloat::Round::NearestTiesToEven, + 0x01 | 0x81 => rustc_apfloat::Round::TowardNegative, + 0x02 | 0x82 => rustc_apfloat::Round::TowardPositive, + 0x03 | 0x83 => rustc_apfloat::Round::TowardZero, + rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"), + }; + + let op0 = this.read_scalar(&this.project_index(&right, 0)?)?.to_f64()?; + let res = op0.round_to_integral(rounding).value; + this.write_scalar(Scalar::from_f64(res), &this.project_index(&dest, 0)?)?; + + for i in 1..dest_len { + this.copy_op( + &this.project_index(&left, i)?, + &this.project_index(&dest, i)?, + /*allow_transmute*/ false, + )?; + } + } + // Used to implement the _mm_minpos_epu16 function. + // Find the minimum unsinged 16-bit integer in `op` and + // returns its value and position. + "phminposuw" => { + let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (op, op_len) = this.operand_to_simd(op)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + // Find minimum + let mut min_value = u16::MAX; + let mut min_index = 0; + for i in 0..op_len { + let op = this.read_scalar(&this.project_index(&op, i)?)?.to_u16()?; + if op < min_value { + min_value = op; + min_index = i; + } + } + + // Write value and index + this.write_scalar(Scalar::from_u16(min_value), &this.project_index(&dest, 0)?)?; + this.write_scalar( + Scalar::from_u16(min_index.try_into().unwrap()), + &this.project_index(&dest, 1)?, + )?; + // Fill remaining with zeros + for i in 2..dest_len { + this.write_scalar(Scalar::from_u16(0), &this.project_index(&dest, i)?)?; + } + } + // Used to implement the _mm_mpsadbw_epu8 function. + // Compute the sum of absolute differences of quadruplets of unsigned + // 8-bit integers in `left` and `right`, and store the 16-bit results + // in `right`. Quadruplets are selected from `left` and `right` with + // offsets specified in `imm`. + // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_mpsadbw_epu8 + "mpsadbw" => { + let [left, right, imm] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (left, left_len) = this.operand_to_simd(left)?; + let (right, right_len) = this.operand_to_simd(right)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + assert_eq!(left_len, right_len); + assert_eq!(left_len, dest_len.checked_mul(2).unwrap()); + + let imm = this.read_scalar(imm)?.to_u8()?; + let left_offset = u64::from((imm >> 2) & 1).checked_mul(4).unwrap(); + let right_offset = u64::from(imm & 0b11).checked_mul(4).unwrap(); + + for i in 0..dest_len { + let left_offset = left_offset.checked_add(i).unwrap(); + let mut res: u16 = 0; + for j in 0..4 { + let left = this + .read_scalar( + &this.project_index(&left, left_offset.checked_add(j).unwrap())?, + )? + .to_u8()?; + let right = this + .read_scalar( + &this + .project_index(&right, right_offset.checked_add(j).unwrap())?, + )? + .to_u8()?; + res = res.checked_add(left.abs_diff(right).into()).unwrap(); + } + this.write_scalar(Scalar::from_u16(res), &this.project_index(&dest, i)?)?; + } + } + // Used to implement the _mm_testz_si128, _mm_testc_si128 + // and _mm_testnzc_si128 functions. + // Tests `op & mask == 0`, `op & mask == mask` or + // `op & mask != 0 && op & mask != mask` + "ptestz" | "ptestc" | "ptestnzc" => { + let [op, mask] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (op, op_len) = this.operand_to_simd(op)?; + let (mask, mask_len) = this.operand_to_simd(mask)?; + + assert_eq!(op_len, mask_len); + + let f = match unprefixed_name { + "ptestz" => |op, mask| op & mask == 0, + "ptestc" => |op, mask| op & mask == mask, + "ptestnzc" => |op, mask| op & mask != 0 && op & mask != mask, + _ => unreachable!(), + }; + + let mut all_zero = true; + for i in 0..op_len { + let op = this.read_scalar(&this.project_index(&op, i)?)?.to_u64()?; + let mask = this.read_scalar(&this.project_index(&mask, i)?)?.to_u64()?; + all_zero &= f(op, mask); + } + + this.write_scalar(Scalar::from_i32(all_zero.into()), dest)?; + } + _ => return Ok(EmulateForeignItemResult::NotSupported), + } + Ok(EmulateForeignItemResult::NeedsJumping) + } +} diff --git a/tests/pass/intrinsics-x86-sse41.rs b/tests/pass/intrinsics-x86-sse41.rs new file mode 100644 index 0000000000..fe18847887 --- /dev/null +++ b/tests/pass/intrinsics-x86-sse41.rs @@ -0,0 +1,252 @@ +// Ignore everything except x86 and x86_64 +// Any additional target are added to CI should be ignored here +// (We cannot use `cfg`-based tricks here since the `target-feature` flags below only work on x86.) +//@ignore-target-aarch64 +//@ignore-target-arm +//@ignore-target-avr +//@ignore-target-s390x +//@ignore-target-thumbv7em +//@ignore-target-wasm32 +//@compile-flags: -C target-feature=+sse4.1 + +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; +use std::mem::transmute; + +fn main() { + assert!(is_x86_feature_detected!("sse4.1")); + + unsafe { + test_sse41(); + } +} + +#[target_feature(enable = "sse4.1")] +unsafe fn test_sse41() { + // Mostly copied from library/stdarch/crates/core_arch/src/x86/sse41.rs + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_insert_ps() { + let a = _mm_set1_ps(1.0); + let b = _mm_setr_ps(1.0, 2.0, 3.0, 4.0); + let r = _mm_insert_ps::<0b11_00_1100>(a, b); + let e = _mm_setr_ps(4.0, 1.0, 0.0, 0.0); + assert_eq_m128(r, e); + } + test_mm_insert_ps(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_packus_epi32() { + let a = _mm_setr_epi32(1, 2, 3, 4); + let b = _mm_setr_epi32(-1, -2, -3, -4); + let r = _mm_packus_epi32(a, b); + let e = _mm_setr_epi16(1, 2, 3, 4, 0, 0, 0, 0); + assert_eq_m128i(r, e); + } + test_mm_packus_epi32(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_dp_pd() { + let a = _mm_setr_pd(2.0, 3.0); + let b = _mm_setr_pd(1.0, 4.0); + let e = _mm_setr_pd(14.0, 0.0); + assert_eq_m128d(_mm_dp_pd::<0b00110001>(a, b), e); + } + test_mm_dp_pd(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_dp_ps() { + let a = _mm_setr_ps(2.0, 3.0, 1.0, 10.0); + let b = _mm_setr_ps(1.0, 4.0, 0.5, 10.0); + let e = _mm_setr_ps(14.5, 0.0, 14.5, 0.0); + assert_eq_m128(_mm_dp_ps::<0b01110101>(a, b), e); + } + test_mm_dp_ps(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_floor_sd() { + let a = _mm_setr_pd(2.5, 4.5); + let b = _mm_setr_pd(-1.5, -3.5); + let r = _mm_floor_sd(a, b); + let e = _mm_setr_pd(-2.0, 4.5); + assert_eq_m128d(r, e); + } + test_mm_floor_sd(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_floor_ss() { + let a = _mm_setr_ps(2.5, 4.5, 8.5, 16.5); + let b = _mm_setr_ps(-1.5, -3.5, -7.5, -15.5); + let r = _mm_floor_ss(a, b); + let e = _mm_setr_ps(-2.0, 4.5, 8.5, 16.5); + assert_eq_m128(r, e); + } + test_mm_floor_ss(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_ceil_sd() { + let a = _mm_setr_pd(1.5, 3.5); + let b = _mm_setr_pd(-2.5, -4.5); + let r = _mm_ceil_sd(a, b); + let e = _mm_setr_pd(-2.0, 3.5); + assert_eq_m128d(r, e); + } + test_mm_ceil_sd(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_ceil_ss() { + let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); + let b = _mm_setr_ps(-2.5, -4.5, -8.5, -16.5); + let r = _mm_ceil_ss(a, b); + let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5); + assert_eq_m128(r, e); + } + test_mm_ceil_ss(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_round_sd() { + let a = _mm_setr_pd(1.5, 3.5); + let b = _mm_setr_pd(-2.5, -4.5); + let r = _mm_round_sd::<_MM_FROUND_TO_NEAREST_INT>(a, b); + let e = _mm_setr_pd(-2.0, 3.5); + assert_eq_m128d(r, e); + } + test_mm_round_sd(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_round_ss() { + let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); + let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5); + let r = _mm_round_ss::<_MM_FROUND_TO_NEAREST_INT>(a, b); + let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5); + assert_eq_m128(r, e); + } + test_mm_round_ss(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_minpos_epu16() { + let a = _mm_setr_epi16(23, 18, 44, 97, 50, 13, 67, 66); + let r = _mm_minpos_epu16(a); + let e = _mm_setr_epi16(13, 5, 0, 0, 0, 0, 0, 0); + assert_eq_m128i(r, e); + + let a = _mm_setr_epi16(0, 18, 44, 97, 50, 13, 67, 66); + let r = _mm_minpos_epu16(a); + let e = _mm_setr_epi16(0, 0, 0, 0, 0, 0, 0, 0); + assert_eq_m128i(r, e); + } + test_mm_minpos_epu16(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_mpsadbw_epu8() { + let a = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + + let r = _mm_mpsadbw_epu8::<0b000>(a, a); + let e = _mm_setr_epi16(0, 4, 8, 12, 16, 20, 24, 28); + assert_eq_m128i(r, e); + + let r = _mm_mpsadbw_epu8::<0b001>(a, a); + let e = _mm_setr_epi16(16, 12, 8, 4, 0, 4, 8, 12); + assert_eq_m128i(r, e); + + let r = _mm_mpsadbw_epu8::<0b100>(a, a); + let e = _mm_setr_epi16(16, 20, 24, 28, 32, 36, 40, 44); + assert_eq_m128i(r, e); + + let r = _mm_mpsadbw_epu8::<0b101>(a, a); + let e = _mm_setr_epi16(0, 4, 8, 12, 16, 20, 24, 28); + assert_eq_m128i(r, e); + + let r = _mm_mpsadbw_epu8::<0b111>(a, a); + let e = _mm_setr_epi16(32, 28, 24, 20, 16, 12, 8, 4); + assert_eq_m128i(r, e); + } + test_mm_mpsadbw_epu8(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_testz_si128() { + let a = _mm_set1_epi8(1); + let mask = _mm_set1_epi8(0); + let r = _mm_testz_si128(a, mask); + assert_eq!(r, 1); + + let a = _mm_set1_epi8(0b101); + let mask = _mm_set1_epi8(0b110); + let r = _mm_testz_si128(a, mask); + assert_eq!(r, 0); + + let a = _mm_set1_epi8(0b011); + let mask = _mm_set1_epi8(0b100); + let r = _mm_testz_si128(a, mask); + assert_eq!(r, 1); + } + test_mm_testz_si128(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_testc_si128() { + let a = _mm_set1_epi8(-1); + let mask = _mm_set1_epi8(0); + let r = _mm_testc_si128(a, mask); + assert_eq!(r, 1); + + let a = _mm_set1_epi8(0b101); + let mask = _mm_set1_epi8(0b110); + let r = _mm_testc_si128(a, mask); + assert_eq!(r, 0); + + let a = _mm_set1_epi8(0b101); + let mask = _mm_set1_epi8(0b100); + let r = _mm_testc_si128(a, mask); + assert_eq!(r, 1); + } + test_mm_testc_si128(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_mm_testnzc_si128() { + let a = _mm_set1_epi8(0); + let mask = _mm_set1_epi8(1); + let r = _mm_testnzc_si128(a, mask); + assert_eq!(r, 0); + + let a = _mm_set1_epi8(-1); + let mask = _mm_set1_epi8(0); + let r = _mm_testnzc_si128(a, mask); + assert_eq!(r, 0); + + let a = _mm_set1_epi8(0b101); + let mask = _mm_set1_epi8(0b110); + let r = _mm_testnzc_si128(a, mask); + assert_eq!(r, 1); + + let a = _mm_set1_epi8(0b101); + let mask = _mm_set1_epi8(0b101); + let r = _mm_testnzc_si128(a, mask); + assert_eq!(r, 0); + } + test_mm_testnzc_si128(); +} + +#[track_caller] +#[target_feature(enable = "sse")] +unsafe fn assert_eq_m128(a: __m128, b: __m128) { + let r = _mm_cmpeq_ps(a, b); + if _mm_movemask_ps(r) != 0b1111 { + panic!("{:?} != {:?}", a, b); + } +} + +#[track_caller] +#[target_feature(enable = "sse2")] +pub unsafe fn assert_eq_m128d(a: __m128d, b: __m128d) { + if _mm_movemask_pd(_mm_cmpeq_pd(a, b)) != 0b11 { + panic!("{:?} != {:?}", a, b); + } +} + +#[track_caller] +#[target_feature(enable = "sse2")] +pub unsafe fn assert_eq_m128i(a: __m128i, b: __m128i) { + assert_eq!(transmute::<_, [u64; 2]>(a), transmute::<_, [u64; 2]>(b)) +}