From ae4a47778d88c5ab6a605fe240725320c8bb6060 Mon Sep 17 00:00:00 2001 From: Troy Hinckley Date: Fri, 3 Feb 2023 13:23:05 -0700 Subject: [PATCH] Add Aarch64 simd support --- Cargo.toml | 3 ++ src/byte_chunk.rs | 131 ++++++++++++++++++++++++++++++++++++++++++---- src/chars.rs | 26 ++++++--- src/lines_crlf.rs | 1 + 4 files changed, 146 insertions(+), 15 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 453d3e7..6aff63e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,9 @@ simd = [] # Enable explicit SIMD optimizations on supported platforms. proptest = "1.0" criterion = { version = "0.3", features = ["html_reports"] } +[profile.release] +lto = "thin" + #----------------------------------------- [[bench]] diff --git a/src/byte_chunk.rs b/src/byte_chunk.rs index 57a09de..2b7fc74 100644 --- a/src/byte_chunk.rs +++ b/src/byte_chunk.rs @@ -1,10 +1,18 @@ #[cfg(target_arch = "x86_64")] use core::arch::x86_64; +#[cfg(target_arch = "aarch64")] +use core::arch::aarch64; + // Which type to actually use at build time. #[cfg(all(feature = "simd", target_arch = "x86_64"))] pub(crate) type Chunk = x86_64::__m128i; -#[cfg(any(not(feature = "simd"), not(any(target_arch = "x86_64"))))] +#[cfg(all(feature = "simd", target_arch = "aarch64"))] +pub(crate) type Chunk = aarch64::uint8x16_t; +#[cfg(any( + not(feature = "simd"), + not(any(target_arch = "x86_64", target_arch = "aarch64")) +))] pub(crate) type Chunk = usize; /// Interface for working with chunks of bytes at a time, providing the @@ -253,6 +261,113 @@ impl ByteChunk for x86_64::__m128i { } } +#[cfg(target_arch = "aarch64")] +impl ByteChunk for aarch64::uint8x16_t { + const SIZE: usize = core::mem::size_of::(); + const MAX_ACC: usize = 255; + + #[inline(always)] + fn zero() -> Self { + unsafe { aarch64::vdupq_n_u8(0) } + } + + #[inline(always)] + fn splat(n: u8) -> Self { + unsafe { aarch64::vdupq_n_u8(n) } + } + + #[inline(always)] + fn is_zero(&self) -> bool { + unsafe { aarch64::vmaxvq_u8(*self) == 0 } + } + + #[inline(always)] + fn shift_back_lex(&self, n: usize) -> Self { + unsafe { + match n { + 1 => aarch64::vextq_u8(*self, Self::zero(), 1), + 2 => aarch64::vextq_u8(*self, Self::zero(), 2), + _ => unreachable!(), + } + } + } + + #[inline(always)] + fn shr(&self, n: usize) -> Self { + unsafe { + let u64_vec = aarch64::vreinterpretq_u64_u8(*self); + let result = match n { + 1 => aarch64::vshrq_n_u64(u64_vec, 1), + _ => unreachable!(), + }; + aarch64::vreinterpretq_u8_u64(result) + } + } + + #[inline(always)] + fn cmp_eq_byte(&self, byte: u8) -> Self { + unsafe { + let equal = aarch64::vceqq_u8(*self, Self::splat(byte)); + aarch64::vshrq_n_u8(equal, 7) + } + } + + #[inline(always)] + fn bytes_between_127(&self, a: u8, b: u8) -> Self { + use aarch64::vreinterpretq_s8_u8 as cast; + unsafe { + let a_gt = aarch64::vcgtq_s8(cast(*self), cast(Self::splat(a))); + let b_gt = aarch64::vcltq_s8(cast(*self), cast(Self::splat(b))); + let in_range = aarch64::vandq_u8(a_gt, b_gt); + aarch64::vshrq_n_u8(in_range, 7) + } + } + + #[inline(always)] + fn bitand(&self, other: Self) -> Self { + unsafe { aarch64::vandq_u8(*self, other) } + } + + fn add(&self, other: Self) -> Self { + unsafe { aarch64::vaddq_u8(*self, other) } + } + + #[inline(always)] + fn sub(&self, other: Self) -> Self { + unsafe { aarch64::vsubq_u8(*self, other) } + } + + #[inline(always)] + fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self { + const END: i32 = Chunk::SIZE as i32 - 1; + match n { + 0 => unsafe { + let lane = aarch64::vgetq_lane_u8(*self, END); + aarch64::vsetq_lane_u8(lane + 1, *self, END) + }, + 1 => unsafe { + let lane = aarch64::vgetq_lane_u8(*self, END - 1); + aarch64::vsetq_lane_u8(lane + 1, *self, END - 1) + }, + _ => unreachable!(), + } + } + + #[inline(always)] + fn dec_last_lex_byte(&self) -> Self { + const END: i32 = Chunk::SIZE as i32 - 1; + unsafe { + let last = aarch64::vgetq_lane_u8(*self, END); + aarch64::vsetq_lane_u8(last - 1, *self, END) + } + } + + #[inline(always)] + fn sum_bytes(&self) -> usize { + unsafe { aarch64::vaddlvq_u8(*self).into() } + } +} + //============================================================= #[cfg(test)] @@ -277,17 +392,15 @@ mod tests { assert_eq!(0x00_01_00_00_00_00_00_00, v.bytes_between_127(0x08, 0x7E)); } - #[cfg(all(feature = "simd", target_arch = "x86_64"))] + #[cfg(all(feature = "simd", any(target_arch = "x86_64", target_arch = "aarch64")))] #[test] - fn sum_bytes_x86_64() { - use core::arch::x86_64::__m128i as T; - - let ones = T::splat(1); - let mut acc = T::zero(); - for _ in 0..T::MAX_ACC { + fn sum_bytes_simd() { + let ones = Chunk::splat(1); + let mut acc = Chunk::zero(); + for _ in 0..Chunk::MAX_ACC { acc = acc.add(ones); } - assert_eq!(acc.sum_bytes(), T::SIZE * T::MAX_ACC); + assert_eq!(acc.sum_bytes(), Chunk::SIZE * Chunk::MAX_ACC); } } diff --git a/src/chars.rs b/src/chars.rs index 39ab2d3..60ac959 100644 --- a/src/chars.rs +++ b/src/chars.rs @@ -128,15 +128,29 @@ pub(crate) fn count_impl(text: &[u8]) -> usize { inv_count += ((byte & 0xC0) == 0x80) as usize; } - // Take care of the middle bytes in big chunks. - for chunks in middle.chunks(T::MAX_ACC) { - let mut acc = T::zero(); - for chunk in chunks.iter() { - acc = acc.add(chunk.bitand(T::splat(0xc0)).cmp_eq_byte(0x80)); + let char_boundaries = |val: &T| val.bitand(T::splat(0xc0)).cmp_eq_byte(0x80); + + // Take care of the middle bytes in big chunks. Loop unrolled + for chunks in middle.chunks_exact(4) { + let mut iter = chunks.iter(); + while let Some(chunk) = iter.next() { + let val1 = char_boundaries(chunk); + let val2 = char_boundaries(iter.next().unwrap()); + let val3 = char_boundaries(iter.next().unwrap()); + let val4 = char_boundaries(iter.next().unwrap()); + let val1_2 = val1.add(val2); + let val3_4 = val3.add(val4); + inv_count += val1_2.add(val3_4).sum_bytes(); } - inv_count += acc.sum_bytes(); } + // Take care of the rest of the chunk + let mut acc = T::zero(); + for chunk in middle.chunks_exact(4).remainder() { + acc = acc.add(char_boundaries(chunk)); + } + inv_count += acc.sum_bytes(); + // Take care of unaligned bytes at the end. for byte in end.iter() { inv_count += ((byte & 0xC0) == 0x80) as usize; diff --git a/src/lines_crlf.rs b/src/lines_crlf.rs index c4f3066..435012b 100644 --- a/src/lines_crlf.rs +++ b/src/lines_crlf.rs @@ -148,6 +148,7 @@ fn to_byte_idx_impl(text: &[u8], line_idx: usize) -> usize { /// /// The following unicode sequences are considered newlines by this function: /// - u{000A} (Line Feed) +/// - u{000D} (Carriage Return) #[inline(always)] fn count_breaks_impl(text: &[u8]) -> usize { // Get `middle` so we can do more efficient chunk-based counting.