Optimize lines_crlf
This change is not ready to be merged because it does not yet support x86. I
don't know how to do the operation we need with SSE intrinsics.

This new approach adds another method to ByteChunk that lets you shift a value
across two vectors. This lets us turn a 2-pass counting algorithm into a
single pass. Making this change alone resulted in a 40% speedup.

There is also some loop unrolling, but unrolling beyond a factor of 2 did not
make a large difference.
CeleritasCelery committed Jun 10, 2024
1 parent dff3f96 commit b31888a
Showing 3 changed files with 121 additions and 76 deletions.
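
To illustrate the single-pass idea described in the commit message, here is a minimal, self-contained sketch that uses plain 8-byte arrays in place of SIMD chunks. The helper names (cmp_eq_byte, shift_across) only mirror the ByteChunk trait methods for readability; this is not the crate's code, which is generic over ByteChunk and sums flags with sum_bytes.

use std::array;

// One 0x01 flag per byte lane that equals `byte` (the real code does this with
// branchless bit tricks or SIMD compares).
fn cmp_eq_byte(chunk: [u8; 8], byte: u8) -> [u8; 8] {
    chunk.map(|b| (b == byte) as u8)
}

// The last flag of `prev` slides in ahead of `next`: result lane 0 comes from
// prev lane 7, result lanes 1..8 come from next lanes 0..7.
fn shift_across(prev: [u8; 8], next: [u8; 8]) -> [u8; 8] {
    let mut out = [0u8; 8];
    out[0] = prev[7];
    out[1..].copy_from_slice(&next[..7]);
    out
}

// Single-pass count of CR, LF, and CRLF line breaks over 8-byte chunks.
// For brevity this ignores any trailing partial chunk.
fn count_breaks(text: &[u8]) -> usize {
    let mut breaks = 0;
    let mut prev_cr = [0u8; 8]; // CR flags carried over from the previous chunk
    for chunk in text.chunks_exact(8) {
        let chunk: [u8; 8] = chunk.try_into().unwrap();
        let lf = cmp_eq_byte(chunk, b'\n');
        let cr = cmp_eq_byte(chunk, b'\r');
        // A CRLF pair is an LF whose previous byte (possibly at the end of the
        // previous chunk) was a CR; shifting the CR flags across lines them up.
        let cr_before = shift_across(prev_cr, cr);
        let crlf: [u8; 8] = array::from_fn(|i| cr_before[i] & lf[i]);
        let sum = |flags: [u8; 8]| flags.iter().map(|&b| b as usize).sum::<usize>();
        // Every CR and LF counts as a break, minus one per CRLF pair.
        breaks += sum(lf) + sum(cr) - sum(crlf);
        prev_cr = cr;
    }
    breaks
}

fn main() {
    // The CRLF spanning the boundary between the two 8-byte chunks ("d\r" | "\nx...")
    // is still counted as a single break.
    assert_eq!(count_breaks(b"ab\r\nc\nd\r\nxxxxxxx"), 3);
}

In the actual change below, to_byte_idx_impl additionally unrolls this loop by two chunks and bails out early once the requested line index is reached.
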
27 changes: 27 additions & 0 deletions src/byte_chunk.rs
@@ -37,6 +37,9 @@ pub(crate) trait ByteChunk: Copy + Clone {
/// Shifts bytes back lexicographically by n bytes.
fn shift_back_lex(&self, n: usize) -> Self;

/// Shifts the last byte of self into the first byte of n, shifting n forward lexicographically by one byte.
fn shift_across(&self, n: Self) -> Self;

/// Shifts bits to the right by n bits.
fn shr(&self, n: usize) -> Self;

@@ -100,6 +103,16 @@ impl ByteChunk for usize {
}
}

#[inline(always)]
fn shift_across(&self, n: Self) -> Self {
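// For example, on a 64-bit little-endian target (illustrative only):
//   self = [1, 2, 3, 4, 5, 6, 7, 8]      (bytes, lowest memory index first)
//   n    = [9, 10, 11, 12, 13, 14, 15, 16]
//   self.shift_across(n) == [8, 9, 10, 11, 12, 13, 14, 15]
// The last byte of `self` slides in ahead of `n`, and `n` moves up by one byte.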
let size = (Self::SIZE - 1) * 8;
if cfg!(target_endian = "little") {
(*self >> size) | (n << 8)
} else {
(*self << size) | (n >> 8)
}
}

#[inline(always)]
fn shr(&self, n: usize) -> Self {
*self >> n
@@ -198,6 +211,15 @@ impl ByteChunk for x86_64::__m128i {
}
}

#[inline(always)]
fn shift_across(&self, n: Self) -> Self {
unsafe {
let bottom_byte = x86_64::_mm_srli_si128(*self, 15);
let rest_shifted = x86_64::_mm_slli_si128(n, 1);
x86_64::_mm_or_si128(bottom_byte, rest_shifted)
}
}

#[inline(always)]
fn shr(&self, n: usize) -> Self {
match n {
@@ -292,6 +314,11 @@ impl ByteChunk for aarch64::uint8x16_t {
}
}

#[inline(always)]
fn shift_across(&self, n: Self) -> Self {
unsafe { aarch64::vextq_u8(*self, n, 15) }
}

#[inline(always)]
fn shr(&self, n: usize) -> Self {
unsafe {
166 changes: 90 additions & 76 deletions src/lines_crlf.rs
@@ -55,6 +55,8 @@ pub fn to_byte_idx(text: &str, line_idx: usize) -> usize {
}

//-------------------------------------------------------------
const LF: u8 = b'\n';
const CR: u8 = b'\r';

#[inline(always)]
fn to_byte_idx_impl<T: ByteChunk>(text: &[u8], line_idx: usize) -> usize {
@@ -67,76 +69,79 @@ fn to_byte_idx_impl<T: ByteChunk>(text: &[u8], line_idx: usize) -> usize {
let mut byte_count = 0;
let mut break_count = 0;

// Take care of any unaligned bytes at the beginning.
for byte in start.iter() {
let mut last_was_cr = false;
for byte in start.iter().copied() {
let is_lf = byte == LF;
let is_cr = byte == CR;
if break_count == line_idx {
break;
if last_was_cr && is_lf {
byte_count += 1;
}
return byte_count;
}
break_count +=
(*byte == 0x0A || (*byte == 0x0D && text.get(byte_count + 1) != Some(&0x0A))) as usize;
if is_cr || (is_lf && !last_was_cr) {
break_count += 1;
}
last_was_cr = is_cr;
byte_count += 1;
}

// Process chunks in the fast path.
let mut chunks = middle;
let mut max_round_len = (line_idx - break_count) / T::MAX_ACC;
while max_round_len > 0 && !chunks.is_empty() {
// Choose the largest number of chunks we can do this round
// that will neither overflow `max_acc` nor blast past the
// remaining line breaks we're looking for.
let round_len = T::MAX_ACC.min(max_round_len).min(chunks.len());
max_round_len -= round_len;
let round = &chunks[..round_len];
chunks = &chunks[round_len..];

// Process the chunks in this round.
let mut acc = T::zero();
for chunk in round.iter() {
let lf_flags = chunk.cmp_eq_byte(0x0A);
let cr_flags = chunk.cmp_eq_byte(0x0D);
let crlf_flags = cr_flags.bitand(lf_flags.shift_back_lex(1));
acc = acc.add(lf_flags).add(cr_flags.sub(crlf_flags));
}
break_count += acc.sum_bytes();

// Handle CRLFs at chunk boundaries in this round.
let mut i = byte_count;
while i < (byte_count + T::SIZE * round_len) {
i += T::SIZE;
break_count -= (text[i - 1] == 0x0D && text.get(i) == Some(&0x0A)) as usize;
// Process the chunks 2 at a time
let mut chunk_count = 0;
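// Seed the carried CR flags with the prefix's final state so a CR at the end of
// the unaligned prefix merges with an LF at the start of the first chunk.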
let mut prev = T::splat(last_was_cr as u8);
for chunks in middle.chunks_exact(2) {
let lf_flags0 = chunks[0].cmp_eq_byte(LF);
let cr_flags0 = chunks[0].cmp_eq_byte(CR);
let crlf_flags0 = prev.shift_across(cr_flags0).bitand(lf_flags0);

let lf_flags1 = chunks[1].cmp_eq_byte(LF);
let cr_flags1 = chunks[1].cmp_eq_byte(CR);
let crlf_flags1 = cr_flags0.shift_across(cr_flags1).bitand(lf_flags1);
let new_break_count = break_count
+ lf_flags0
.add(cr_flags0)
.add(lf_flags1)
.add(cr_flags1)
.sub(crlf_flags0)
.sub(crlf_flags1)
.sum_bytes();
if new_break_count >= line_idx {
break;
}

byte_count += T::SIZE * round_len;
break_count = new_break_count;
byte_count += T::SIZE * 2;
chunk_count += 2;
prev = cr_flags1;
}

// Process chunks in the slow path.
for chunk in chunks.iter() {
let breaks = {
let lf_flags = chunk.cmp_eq_byte(0x0A);
let cr_flags = chunk.cmp_eq_byte(0x0D);
let crlf_flags = cr_flags.bitand(lf_flags.shift_back_lex(1));
lf_flags.add(cr_flags.sub(crlf_flags)).sum_bytes()
};
let boundary_crlf = {
let i = byte_count + T::SIZE;
(text[i - 1] == 0x0D && text.get(i) == Some(&0x0A)) as usize
};
let new_break_count = break_count + breaks - boundary_crlf;
// Process the rest of the chunks
for chunk in middle[chunk_count..].iter() {
let lf_flags = chunk.cmp_eq_byte(LF);
let cr_flags = chunk.cmp_eq_byte(CR);
let crlf_flags = prev.shift_across(cr_flags).bitand(lf_flags);
let new_break_count = break_count + lf_flags.add(cr_flags).sub(crlf_flags).sum_bytes();
if new_break_count >= line_idx {
break;
}
break_count = new_break_count;
byte_count += T::SIZE;
prev = cr_flags;
}

// Take care of any unaligned bytes at the end.
let end = &text[byte_count..];
for byte in end.iter() {
last_was_cr = text.get(byte_count.saturating_sub(1)) == Some(&CR);
for byte in text[byte_count..].iter().copied() {
let is_lf = byte == LF;
let is_cr = byte == CR;
if break_count == line_idx {
if last_was_cr && is_lf {
byte_count += 1;
}
break;
}
break_count +=
(*byte == 0x0A || (*byte == 0x0D && text.get(byte_count + 1) != Some(&0x0A))) as usize;
if is_cr || (is_lf && !last_was_cr) {
break_count += 1;
}
last_was_cr = is_cr;
byte_count += 1;
}

@@ -159,39 +164,48 @@ fn count_breaks_impl<T: ByteChunk>(text: &[u8]) -> usize {
// Take care of unaligned bytes at the beginning.
let mut last_was_cr = false;
for byte in start.iter().copied() {
let is_lf = byte == 0x0A;
let is_cr = byte == 0x0D;
count += (is_cr | (is_lf & !last_was_cr)) as usize;
let is_lf = byte == LF;
let is_cr = byte == CR;
if is_cr || (is_lf && !last_was_cr) {
count += 1;
}
last_was_cr = is_cr;
}

// Take care of the middle bytes in big chunks.
for chunks in middle.chunks(T::MAX_ACC) {
let mut acc = T::zero();
for chunk in chunks.iter() {
let lf_flags = chunk.cmp_eq_byte(0x0A);
let cr_flags = chunk.cmp_eq_byte(0x0D);
let crlf_flags = cr_flags.bitand(lf_flags.shift_back_lex(1));
acc = acc.add(lf_flags).add(cr_flags.sub(crlf_flags));
}
count += acc.sum_bytes();
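// Process the middle chunks two at a time, carrying CR flags across each chunk
// boundary (and in from the unaligned prefix) so boundary-spanning CRLF pairs
// count as a single break.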
let mut prev = T::splat(last_was_cr as u8);
for chunks in middle.chunks_exact(2) {
let lf_flags0 = chunks[0].cmp_eq_byte(LF);
let cr_flags0 = chunks[0].cmp_eq_byte(CR);
let crlf_flags0 = prev.shift_across(cr_flags0).bitand(lf_flags0);

let lf_flags1 = chunks[1].cmp_eq_byte(LF);
let cr_flags1 = chunks[1].cmp_eq_byte(CR);
let crlf_flags1 = cr_flags0.shift_across(cr_flags1).bitand(lf_flags1);
count += lf_flags0
.add(cr_flags0)
.sub(crlf_flags0)
.add(lf_flags1)
.add(cr_flags1)
.sub(crlf_flags1)
.sum_bytes();
prev = cr_flags1;
}

// Check chunk boundaries for CRLF.
let mut i = start.len();
while i < (text.len() - end.len()) {
if text[i] == 0x0A {
count -= (text.get(i.saturating_sub(1)) == Some(&0x0D)) as usize;
}
i += T::SIZE;
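// Count breaks in the odd chunk left over when `middle` has an odd number of chunks.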
if let Some(chunk) = middle.chunks_exact(2).remainder().iter().next() {
let lf_flags = chunk.cmp_eq_byte(LF);
let cr_flags = chunk.cmp_eq_byte(CR);
let crlf_flags = prev.shift_across(cr_flags).bitand(lf_flags);
count += lf_flags.add(cr_flags).sub(crlf_flags).sum_bytes();
}

// Take care of unaligned bytes at the end.
let mut last_was_cr = text.get((text.len() - end.len()).saturating_sub(1)) == Some(&0x0D);
last_was_cr = text.get((text.len() - end.len()).saturating_sub(1)) == Some(&CR);
for byte in end.iter().copied() {
let is_lf = byte == 0x0A;
let is_cr = byte == 0x0D;
count += (is_cr | (is_lf & !last_was_cr)) as usize;
let is_lf = byte == LF;
let is_cr = byte == CR;
if is_cr || (is_lf && !last_was_cr) {
count += 1;
}
last_was_cr = is_cr;
}

4 changes: 4 additions & 0 deletions tests/proptests_lines_crlf.proptest-regressions
@@ -6,3 +6,7 @@
# everyone who runs the test benefits from these saved cases.
cc 7332c47dc044064cfb86ac6d94c45441f6dcdab699402718874116b512062e0a # shrinks to ref text = "a\n\rあ\r\n\rあa\n\rあ\n\nあ\r\nああ\r\nあ🐸a\ra\r🐸", idx = 48
cc f9a60685ebb7fc415f3d941598e75c90a0a7db0f0cd6673e092cf78f43a60fa3 # shrinks to ref text = "あああ\r\r\r\n\naaあ\naあ\r🐸🐸🐸\n\naあaaa\n🐸\na\n\na\n\n\n\r\n🐸\r\nあ\raaああ🐸a\naa🐸\n🐸\ra\na🐸a🐸\n\ra\n🐸🐸🐸\r\ra🐸あ\n\n🐸🐸aaあ\r\rあ🐸\rあa\n\r\n🐸🐸\n\nあ\rあa\nあa\rあaa\nあa\r\r\r\n\r\n\r\nあa🐸\r\n\r\r\na\r"
cc 3e5415f317b24c22a479d0e56a4705ad4c4c0cb060ee5610f94ef78fe3fda588 # shrinks to ref text = "\n\r\raa🐸\ra\n\raa🐸あ\r\n\n\nあ\nあ🐸🐸あ🐸\n🐸あああ🐸aaあ🐸a\n\rあああ🐸あ🐸\n\n\r🐸a\rああ\r🐸ああ\n\r\r\r🐸🐸あ🐸\r\r\n🐸🐸あaa\r\na\r\n🐸\na🐸\raあ\r\naあa\nああ\r\r🐸a\n\raa\rあ🐸\rあ\n\n\rあ\n\r\n🐸ああ\n🐸aあ\r\r\n\nあ🐸a\naaa🐸🐸あa", idx = 272
cc 9fe71a1c06b7b4791fab564d43a9fe7e0d3047302e6c9228852e91f638833aae # shrinks to ref text = "\n\n\nあああaa\n🐸\rああ\r🐸🐸あ\n\nああ🐸\n\rああ\r\n🐸\nああ🐸\n\ra\r\r\naあ\n🐸\rあ\n\r", idx = 96
cc b06906825a8db53b794a8dfbd9fc17ee80d3f722528a86644ad4457ba8ab12d8 # shrinks to ref text = "🐸\n\ra🐸\n🐸aaa\r🐸aaaa🐸🐸\nあ\r\naあ\rあa\raあ\rあ\rあa\ra\n\nあa\rあ\r\rあa\nああaa\n\r🐸\na\n🐸あ🐸a\r\n\n\naあa\raaa\r🐸🐸\r\rあaあ🐸\n🐸🐸\rあ\nああああaa\na\r\n\raa\n🐸\naあ\raaaa\n\n\naaあa🐸🐸🐸\ra\nあ\n\nあ🐸ああ🐸a🐸a\rあ\rあ🐸🐸あ🐸あa\n\raaあ\r🐸あ\ra\na🐸🐸\r\rあa\r\n\n\nあ\r\r\r🐸あ🐸🐸🐸\r\rあaあ🐸\n\n\r🐸\n🐸あ\r"
cc 8e3c6ad1951e8839368157cf7f595661bbf5b004ff19800499d42387e6b8fe56 # shrinks to ref text = "🐸あ🐸🐸🐸\ra\nあ🐸aああ\rあ🐸ああ\n\rあa\r\raa\n\naa🐸\ra\n\r🐸\nあ\n\r🐸\raaa🐸🐸🐸aあ\raあ🐸\na\n🐸aaあ\naあ\n🐸🐸あa\r\r🐸\r🐸a🐸aaaああ🐸\naあ\na\n\naa\n\r🐸\r\ra\r🐸\n\n\n\n\n\r\nあ\rあa\nあ\ra🐸\r\n🐸\n\ra\r\r\n\n\nあa\r\rあ🐸\n🐸\n🐸\r\rあ🐸\r🐸🐸あ🐸a\nあああ🐸🐸aaaa🐸\n🐸\r\r\na\n\r🐸a🐸aa\n\n\r\r\r\r🐸\r🐸🐸", idx = 69
