Add Aarch64 SIMD support #12

Merged (1 commit, Feb 6, 2023)
3 changes: 3 additions & 0 deletions Cargo.toml
@@ -23,6 +23,9 @@ simd = [] # Enable explicit SIMD optimizations on supported platforms.
proptest = "1.0"
criterion = { version = "0.3", features = ["html_reports"] }

[profile.release]
lto = "thin"

#-----------------------------------------

[[bench]]
131 changes: 122 additions & 9 deletions src/byte_chunk.rs
@@ -1,10 +1,18 @@
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64;

+#[cfg(target_arch = "aarch64")]
+use core::arch::aarch64;
+
// Which type to actually use at build time.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
pub(crate) type Chunk = x86_64::__m128i;
-#[cfg(any(not(feature = "simd"), not(any(target_arch = "x86_64"))))]
+#[cfg(all(feature = "simd", target_arch = "aarch64"))]
+pub(crate) type Chunk = aarch64::uint8x16_t;
+#[cfg(any(
+    not(feature = "simd"),
+    not(any(target_arch = "x86_64", target_arch = "aarch64"))
+))]
pub(crate) type Chunk = usize;

/// Interface for working with chunks of bytes at a time, providing the
@@ -253,6 +261,113 @@ impl ByteChunk for x86_64::__m128i {
}
}

#[cfg(target_arch = "aarch64")]
impl ByteChunk for aarch64::uint8x16_t {
const SIZE: usize = core::mem::size_of::<Self>();
const MAX_ACC: usize = 255;

#[inline(always)]
fn zero() -> Self {
unsafe { aarch64::vdupq_n_u8(0) }
}

#[inline(always)]
fn splat(n: u8) -> Self {
unsafe { aarch64::vdupq_n_u8(n) }
}

#[inline(always)]
fn is_zero(&self) -> bool {
unsafe { aarch64::vmaxvq_u8(*self) == 0 }
}

#[inline(always)]
fn shift_back_lex(&self, n: usize) -> Self {
unsafe {
match n {
1 => aarch64::vextq_u8(*self, Self::zero(), 1),
2 => aarch64::vextq_u8(*self, Self::zero(), 2),
_ => unreachable!(),
}
}
}

#[inline(always)]
fn shr(&self, n: usize) -> Self {
unsafe {
let u64_vec = aarch64::vreinterpretq_u64_u8(*self);
let result = match n {
1 => aarch64::vshrq_n_u64(u64_vec, 1),
_ => unreachable!(),
};
aarch64::vreinterpretq_u8_u64(result)
}
}

#[inline(always)]
fn cmp_eq_byte(&self, byte: u8) -> Self {
unsafe {
let equal = aarch64::vceqq_u8(*self, Self::splat(byte));
aarch64::vshrq_n_u8(equal, 7)
}
}

#[inline(always)]
fn bytes_between_127(&self, a: u8, b: u8) -> Self {
use aarch64::vreinterpretq_s8_u8 as cast;
unsafe {
let a_gt = aarch64::vcgtq_s8(cast(*self), cast(Self::splat(a)));
let b_gt = aarch64::vcltq_s8(cast(*self), cast(Self::splat(b)));
let in_range = aarch64::vandq_u8(a_gt, b_gt);
aarch64::vshrq_n_u8(in_range, 7)
}
}

#[inline(always)]
fn bitand(&self, other: Self) -> Self {
unsafe { aarch64::vandq_u8(*self, other) }
}

#[inline(always)]
fn add(&self, other: Self) -> Self {
unsafe { aarch64::vaddq_u8(*self, other) }
}

#[inline(always)]
fn sub(&self, other: Self) -> Self {
unsafe { aarch64::vsubq_u8(*self, other) }
}

#[inline(always)]
fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self {
const END: i32 = Chunk::SIZE as i32 - 1;
match n {
0 => unsafe {
let lane = aarch64::vgetq_lane_u8(*self, END);
aarch64::vsetq_lane_u8(lane + 1, *self, END)
},
1 => unsafe {
let lane = aarch64::vgetq_lane_u8(*self, END - 1);
aarch64::vsetq_lane_u8(lane + 1, *self, END - 1)
},
_ => unreachable!(),
}
}

#[inline(always)]
fn dec_last_lex_byte(&self) -> Self {
const END: i32 = Chunk::SIZE as i32 - 1;
unsafe {
let last = aarch64::vgetq_lane_u8(*self, END);
aarch64::vsetq_lane_u8(last - 1, *self, END)
}
}

#[inline(always)]
fn sum_bytes(&self) -> usize {
unsafe { aarch64::vaddlvq_u8(*self).into() }
}
}

//=============================================================

#[cfg(test)]
@@ -277,17 +392,15 @@ mod tests {
assert_eq!(0x00_01_00_00_00_00_00_00, v.bytes_between_127(0x08, 0x7E));
}

-    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
+    #[cfg(all(feature = "simd", any(target_arch = "x86_64", target_arch = "aarch64")))]
    #[test]
-    fn sum_bytes_x86_64() {
-        use core::arch::x86_64::__m128i as T;
-
-        let ones = T::splat(1);
-        let mut acc = T::zero();
-        for _ in 0..T::MAX_ACC {
+    fn sum_bytes_simd() {
+        let ones = Chunk::splat(1);
+        let mut acc = Chunk::zero();
+        for _ in 0..Chunk::MAX_ACC {
            acc = acc.add(ones);
        }

-        assert_eq!(acc.sum_bytes(), T::SIZE * T::MAX_ACC);
+        assert_eq!(acc.sum_bytes(), Chunk::SIZE * Chunk::MAX_ACC);
}
}
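A note on the NEON lane-mask convention the new impl leans on: `vceqq_u8` produces an all-ones (0xFF) lane for every match, so `cmp_eq_byte` shifts each lane right by 7 to turn the mask into a 0/1 counter that can be accumulated up to `MAX_ACC` (255) times and then reduced with the widening horizontal add `vaddlvq_u8`. A rough scalar model of the semantics these methods are expected to have (illustrative only, not part of the crate):

```rust
/// Scalar stand-ins for the 16-lane behaviour the NEON impl above targets.
/// Purely illustrative; the real code operates on `uint8x16_t`.
fn cmp_eq_byte_scalar(chunk: [u8; 16], byte: u8) -> [u8; 16] {
    // 1 where the lane equals `byte`, 0 elsewhere (i.e. the 0xFF mask >> 7).
    chunk.map(|b| (b == byte) as u8)
}

fn add_scalar(a: [u8; 16], b: [u8; 16]) -> [u8; 16] {
    // Lane-wise wrapping add, like vaddq_u8.
    let mut out = [0u8; 16];
    for i in 0..16 {
        out[i] = a[i].wrapping_add(b[i]);
    }
    out
}

fn sum_bytes_scalar(chunk: [u8; 16]) -> usize {
    // Widening horizontal add across all lanes, like vaddlvq_u8.
    chunk.iter().map(|&b| b as usize).sum()
}
```

With that model, the renamed `sum_bytes_simd` test reads directly: accumulate the all-ones chunk `MAX_ACC` times and expect `SIZE * MAX_ACC`.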
26 changes: 19 additions & 7 deletions src/chars.rs
@@ -128,15 +128,27 @@ pub(crate) fn count_impl<T: ByteChunk>(text: &[u8]) -> usize {
inv_count += ((byte & 0xC0) == 0x80) as usize;
}

-    // Take care of the middle bytes in big chunks.
-    for chunks in middle.chunks(T::MAX_ACC) {
-        let mut acc = T::zero();
-        for chunk in chunks.iter() {
-            acc = acc.add(chunk.bitand(T::splat(0xc0)).cmp_eq_byte(0x80));
-        }
-        inv_count += acc.sum_bytes();
+    #[inline(always)]
+    fn char_boundaries<T: ByteChunk>(val: &T) -> T {
+        val.bitand(T::splat(0xc0)).cmp_eq_byte(0x80)
    }
+
+    // Take care of the middle bytes in big chunks. Loop unrolled
+    for chunks in middle.chunks_exact(4) {
+        let val1 = char_boundaries(&chunks[0]);
+        let val2 = char_boundaries(&chunks[1]);
+        let val3 = char_boundaries(&chunks[2]);
+        let val4 = char_boundaries(&chunks[3]);
+        inv_count += val1.add(val2).add(val3.add(val4)).sum_bytes();
+    }
+
+    // Take care of the rest of the chunk
+    let mut acc = T::zero();
+    for chunk in middle.chunks_exact(4).remainder() {
+        acc = acc.add(char_boundaries(chunk));
+    }
+    inv_count += acc.sum_bytes();

// Take care of unaligned bytes at the end.
for byte in end.iter() {
inv_count += ((byte & 0xC0) == 0x80) as usize;
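For context on the counting logic: a UTF-8 continuation byte always has the form `10xxxxxx`, so `char_boundaries` flags lanes where `byte & 0xC0 == 0x80`, and the character count falls out as total bytes minus that tally. Summing each unrolled group of four chunks right away keeps every u8 lane at 4 or less, so the per-lane accumulator cannot wrap. A scalar sketch of the same idea (illustrative only, not the crate's code):

```rust
/// Scalar version of the continuation-byte trick that the SIMD path accelerates.
fn count_chars_scalar(text: &str) -> usize {
    let bytes = text.as_bytes();
    // Continuation bytes are exactly the bytes matching 0b10xx_xxxx.
    let continuations = bytes.iter().filter(|&&b| (b & 0xC0) == 0x80).count();
    bytes.len() - continuations
}

fn main() {
    // 'é' is two bytes but one char, so six bytes count as five chars.
    assert_eq!(count_chars_scalar("héllo"), 5);
}
```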
1 change: 1 addition & 0 deletions src/lines_crlf.rs
@@ -148,6 +148,7 @@ fn to_byte_idx_impl<T: ByteChunk>(text: &[u8], line_idx: usize) -> usize {
///
/// The following unicode sequences are considered newlines by this function:
/// - u{000A} (Line Feed)
/// - u{000D} (Carriage Return)
#[inline(always)]
fn count_breaks_impl<T: ByteChunk>(text: &[u8]) -> usize {
// Get `middle` so we can do more efficient chunk-based counting.
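On the doc fix above: in this module a lone LF and a lone CR each count as a break, which is why the carriage return needed to be listed. A scalar sketch of the counting rule (illustrative only; it assumes the usual lines_crlf semantics where a CR immediately followed by LF collapses to a single break):

```rust
/// Scalar sketch of CRLF-aware break counting; not the crate's actual code.
fn count_breaks_scalar(text: &[u8]) -> usize {
    let mut count = 0;
    let mut i = 0;
    while i < text.len() {
        match text[i] {
            0x0A => count += 1, // lone LF
            0x0D => {
                count += 1; // CR, possibly the start of a CRLF pair
                if text.get(i + 1) == Some(&0x0A) {
                    i += 1; // skip the LF so the pair counts once
                }
            }
            _ => {}
        }
        i += 1;
    }
    count
}

fn main() {
    assert_eq!(count_breaks_scalar(b"a\r\nb\nc\r"), 3); // CRLF, LF, CR
}
```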