Add Aarch64 simd support
CeleritasCelery committed Feb 6, 2023
1 parent e92188e commit f37a6d6
Showing 4 changed files with 145 additions and 16 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
@@ -23,6 +23,9 @@ simd = [] # Enable explicit SIMD optimizations on supported platforms.
proptest = "1.0"
criterion = { version = "0.3", features = ["html_reports"] }

[profile.release]
lto = "thin"

#-----------------------------------------

[[bench]]
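The new [profile.release] entry turns on thin link-time optimization, which lets the feature-gated intrinsic wrappers inline across crate boundaries in release builds. For context, a downstream user would opt into the SIMD paths through the `simd` feature; a hypothetical dependency entry (crate name and version are placeholders, not part of this commit):

[dependencies]
# Placeholder crate name; enables the feature-gated SIMD code paths.
the-crate = { version = "1.0", features = ["simd"] }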
131 changes: 122 additions & 9 deletions src/byte_chunk.rs
@@ -1,10 +1,18 @@
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64;

#[cfg(target_arch = "aarch64")]
use core::arch::aarch64;

// Which type to actually use at build time.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
pub(crate) type Chunk = x86_64::__m128i;
#[cfg(all(feature = "simd", target_arch = "aarch64"))]
pub(crate) type Chunk = aarch64::uint8x16_t;
#[cfg(any(
    not(feature = "simd"),
    not(any(target_arch = "x86_64", target_arch = "aarch64"))
))]
pub(crate) type Chunk = usize;
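// Aside (not part of this commit): generic code should key off Chunk::SIZE,
// since the alias is a 16-byte vector with the simd feature enabled but a
// plain usize without it.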

/// Interface for working with chunks of bytes at a time, providing the
@@ -253,6 +261,113 @@ impl ByteChunk for x86_64::__m128i {
    }
}

#[cfg(target_arch = "aarch64")]
impl ByteChunk for aarch64::uint8x16_t {
    const SIZE: usize = core::mem::size_of::<Self>();
    const MAX_ACC: usize = 255;

    #[inline(always)]
    fn zero() -> Self {
        unsafe { aarch64::vdupq_n_u8(0) }
    }

    #[inline(always)]
    fn splat(n: u8) -> Self {
        unsafe { aarch64::vdupq_n_u8(n) }
    }

    #[inline(always)]
    fn is_zero(&self) -> bool {
        // The horizontal max is zero only when every lane is zero.
        unsafe { aarch64::vmaxvq_u8(*self) == 0 }
    }

    #[inline(always)]
    fn shift_back_lex(&self, n: usize) -> Self {
        // `vextq_u8` needs a constant index, so dispatch on the only
        // shift amounts the callers use.
        unsafe {
            match n {
                1 => aarch64::vextq_u8(*self, Self::zero(), 1),
                2 => aarch64::vextq_u8(*self, Self::zero(), 2),
                _ => unreachable!(),
            }
        }
    }

    #[inline(always)]
    fn shr(&self, n: usize) -> Self {
        // Reinterpret as two u64 lanes and shift each one right;
        // callers only ever pass n == 1.
        unsafe {
            let u64_vec = aarch64::vreinterpretq_u64_u8(*self);
            let result = match n {
                1 => aarch64::vshrq_n_u64(u64_vec, 1),
                _ => unreachable!(),
            };
            aarch64::vreinterpretq_u8_u64(result)
        }
    }

    #[inline(always)]
    fn cmp_eq_byte(&self, byte: u8) -> Self {
        // `vceqq_u8` yields 0xFF in matching lanes; shifting right by 7
        // maps that to 0x01 so results can be accumulated with `add`.
        unsafe {
            let equal = aarch64::vceqq_u8(*self, Self::splat(byte));
            aarch64::vshrq_n_u8(equal, 7)
        }
    }

    #[inline(always)]
    fn bytes_between_127(&self, a: u8, b: u8) -> Self {
        // Signed comparisons are sound here because both bounds are
        // below 128, per this method's contract.
        use aarch64::vreinterpretq_s8_u8 as cast;
        unsafe {
            let gt_a = aarch64::vcgtq_s8(cast(*self), cast(Self::splat(a)));
            let lt_b = aarch64::vcltq_s8(cast(*self), cast(Self::splat(b)));
            let in_range = aarch64::vandq_u8(gt_a, lt_b);
            aarch64::vshrq_n_u8(in_range, 7)
        }
    }

    #[inline(always)]
    fn bitand(&self, other: Self) -> Self {
        unsafe { aarch64::vandq_u8(*self, other) }
    }

    #[inline(always)]
    fn add(&self, other: Self) -> Self {
        unsafe { aarch64::vaddq_u8(*self, other) }
    }

    #[inline(always)]
    fn sub(&self, other: Self) -> Self {
        unsafe { aarch64::vsubq_u8(*self, other) }
    }

    #[inline(always)]
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self {
        // Derive the lane index from this vector type rather than `Chunk`,
        // which aliases `usize` when the simd feature is disabled.
        const END: i32 = core::mem::size_of::<aarch64::uint8x16_t>() as i32 - 1;
        match n {
            0 => unsafe {
                let lane = aarch64::vgetq_lane_u8(*self, END);
                aarch64::vsetq_lane_u8(lane + 1, *self, END)
            },
            1 => unsafe {
                let lane = aarch64::vgetq_lane_u8(*self, END - 1);
                aarch64::vsetq_lane_u8(lane + 1, *self, END - 1)
            },
            _ => unreachable!(),
        }
    }

    #[inline(always)]
    fn dec_last_lex_byte(&self) -> Self {
        const END: i32 = core::mem::size_of::<aarch64::uint8x16_t>() as i32 - 1;
        unsafe {
            let last = aarch64::vgetq_lane_u8(*self, END);
            aarch64::vsetq_lane_u8(last - 1, *self, END)
        }
    }

    #[inline(always)]
    fn sum_bytes(&self) -> usize {
        // `vaddlvq_u8` widens to u16 while summing across lanes, so the
        // 16 single-byte lanes cannot overflow the accumulator.
        unsafe { aarch64::vaddlvq_u8(*self).into() }
    }
}
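// Illustration only (not part of this commit), assuming the ByteChunk
// methods above: count occurrences of a byte within one 16-byte chunk.
// Each matching lane becomes 0x01, so the horizontal sum is the count.
#[cfg(all(feature = "simd", target_arch = "aarch64"))]
#[allow(dead_code)]
fn count_byte_in_chunk(chunk: aarch64::uint8x16_t, byte: u8) -> usize {
    chunk.cmp_eq_byte(byte).sum_bytes()
}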

//=============================================================

#[cfg(test)]
@@ -277,17 +392,15 @@ mod tests {
        assert_eq!(0x00_01_00_00_00_00_00_00, v.bytes_between_127(0x08, 0x7E));
}

    #[cfg(all(feature = "simd", any(target_arch = "x86_64", target_arch = "aarch64")))]
    #[test]
    fn sum_bytes_simd() {
        let ones = Chunk::splat(1);
        let mut acc = Chunk::zero();
        for _ in 0..Chunk::MAX_ACC {
            acc = acc.add(ones);
        }

        assert_eq!(acc.sum_bytes(), Chunk::SIZE * Chunk::MAX_ACC);
    }
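    // A further check one might add (a sketch, not part of this commit):
    // cmp_eq_byte should mark every lane when the chunk is splatted with
    // the byte being compared against.
    #[cfg(all(feature = "simd", any(target_arch = "x86_64", target_arch = "aarch64")))]
    #[test]
    fn cmp_eq_byte_all_lanes() {
        let v = Chunk::splat(0x2A);
        assert_eq!(v.cmp_eq_byte(0x2A).sum_bytes(), Chunk::SIZE);
    }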
}
26 changes: 19 additions & 7 deletions src/chars.rs
@@ -128,15 +128,27 @@ pub(crate) fn count_impl<T: ByteChunk>(text: &[u8]) -> usize {
        inv_count += ((byte & 0xC0) == 0x80) as usize;
    }

    #[inline(always)]
    fn char_boundaries<T: ByteChunk>(val: &T) -> T {
        val.bitand(T::splat(0xc0)).cmp_eq_byte(0x80)
    }

    // Take care of the middle bytes in big chunks. The loop is unrolled by
    // four; each lane holds at most 4 before summing, well under MAX_ACC.
    for chunks in middle.chunks_exact(4) {
        let val1 = char_boundaries(&chunks[0]);
        let val2 = char_boundaries(&chunks[1]);
        let val3 = char_boundaries(&chunks[2]);
        let val4 = char_boundaries(&chunks[3]);
        inv_count += val1.add(val2).add(val3.add(val4)).sum_bytes();
    }

    // Take care of the remaining middle chunks.
    let mut acc = T::zero();
    for chunk in middle.chunks_exact(4).remainder() {
        acc = acc.add(char_boundaries(chunk));
    }
    inv_count += acc.sum_bytes();

    // Take care of unaligned bytes at the end.
    for byte in end.iter() {
        inv_count += ((byte & 0xC0) == 0x80) as usize;
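For intuition (an aside, not part of the diff): every UTF-8 continuation byte has the bit pattern 10xxxxxx, which is exactly what `(byte & 0xC0) == 0x80` tests, so `inv_count` ends up as the number of non-initial bytes. A minimal scalar sketch of the idea the SIMD path vectorizes (plain Rust, not this crate's API):

fn count_chars_scalar(text: &str) -> usize {
    text.as_bytes()
        .iter()
        .filter(|&&b| (b & 0xC0) != 0x80) // keep only non-continuation bytes
        .count()
}

// count_chars_scalar("héllo") == 5: 'é' encodes as 0xC3 0xA9, and the
// continuation byte 0xA9 is the only one filtered out of the 6 bytes.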
1 change: 1 addition & 0 deletions src/lines_crlf.rs
@@ -148,6 +148,7 @@ fn to_byte_idx_impl<T: ByteChunk>(text: &[u8], line_idx: usize) -> usize {
///
/// The following unicode sequences are considered newlines by this function:
/// - u{000A} (Line Feed)
/// - u{000D} (Carriage Return)
#[inline(always)]
fn count_breaks_impl<T: ByteChunk>(text: &[u8]) -> usize {
    // Get `middle` so we can do more efficient chunk-based counting.
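Since the diff cuts off here, the following is a hedged scalar sketch of the CRLF-aware semantics the updated doc comment describes, under the assumption (not confirmed by this hunk) that a "\r\n" pair counts as a single break:

fn count_breaks_scalar(text: &[u8]) -> usize {
    let mut count = 0;
    let mut i = 0;
    while i < text.len() {
        match text[i] {
            b'\r' => {
                count += 1;
                // Presumably a following LF belongs to the same break.
                if text.get(i + 1) == Some(&b'\n') {
                    i += 1;
                }
            }
            b'\n' => count += 1,
            _ => {}
        }
        i += 1;
    }
    count
}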
