optimize to_byte_idx #17

Merged · 6 commits · Oct 15, 2023
Changes from 5 commits
63 changes: 14 additions & 49 deletions benches/all.rs
@@ -1,65 +1,30 @@
 #![allow(clippy::uninlined_format_args)]
-use std::fs;
+use std::{fs, path::Path};
 
 use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
 use str_indices::{chars, lines, lines_crlf, lines_lf, utf16};
 
 fn all(c: &mut Criterion) {
+    let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("benches/text");
+    let read_text =
+        |name: &str| fs::read_to_string(root.join(name)).expect("cannot find benchmark text at");
     // Load benchmark strings.
     let test_strings = vec![
         ("en_0001", "E".into()),
-        (
-            "en_0010",
-            fs::read_to_string("benches/text/en_10.txt").expect("Cannot find benchmark text."),
-        ),
-        (
-            "en_0100",
-            fs::read_to_string("benches/text/en_100.txt").expect("Cannot find benchmark text."),
-        ),
-        (
-            "en_1000",
-            fs::read_to_string("benches/text/en_1000.txt").expect("Cannot find benchmark text."),
-        ),
-        (
-            "en_10000",
-            fs::read_to_string("benches/text/en_1000.txt")
-                .expect("Cannot find benchmark text.")
-                .repeat(10),
-        ),
+        ("en_0010", read_text("en_10.txt")),
+        ("en_0100", read_text("en_100.txt")),
+        ("en_1000", read_text("en_1000.txt")),
+        ("en_10000", read_text("en_1000.txt").repeat(10)),
         ("jp_0003", "日".into()),
-        (
-            "jp_0102",
-            fs::read_to_string("benches/text/jp_102.txt").expect("Cannot find benchmark text."),
-        ),
-        (
-            "jp_1001",
-            fs::read_to_string("benches/text/jp_1001.txt").expect("Cannot find benchmark text."),
-        ),
-        (
-            "jp_10000",
-            fs::read_to_string("benches/text/jp_1001.txt")
-                .expect("Cannot find benchmark text.")
-                .repeat(10),
-        ),
+        ("jp_0102", read_text("jp_102.txt")),
+        ("jp_1001", read_text("jp_1001.txt")),
+        ("jp_10000", read_text("jp_1001.txt").repeat(10)),
     ];
 
     let line_strings = vec![
-        (
-            "lines_100",
-            fs::read_to_string("benches/text/lines.txt").expect("Cannot find benchmark text."),
-        ),
-        (
-            "lines_1000",
-            fs::read_to_string("benches/text/lines.txt")
-                .expect("Cannot find benchmark text.")
-                .repeat(10),
-        ),
-        (
-            "lines_10000",
-            fs::read_to_string("benches/text/lines.txt")
-                .expect("Cannot find benchmark text.")
-                .repeat(100),
-        ),
+        ("lines_100", read_text("lines.txt")),
+        ("lines_1000", read_text("lines.txt").repeat(10)),
+        ("lines_10000", read_text("lines.txt").repeat(100)),
     ];
 
     //---------------------------------------------------------
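The benchmark change above collapses each repeated `fs::read_to_string("benches/text/...")` call into a `read_text` closure rooted at `CARGO_MANIFEST_DIR`. A minimal standalone sketch of the same pattern (the `main` wrapper and the printed output are illustrative, not part of the PR):

```rust
use std::{fs, path::Path};

fn main() {
    // CARGO_MANIFEST_DIR is the crate root at compile time, so the path
    // stays valid regardless of the working directory `cargo bench` runs in.
    let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("benches/text");
    let read_text =
        |name: &str| fs::read_to_string(root.join(name)).expect("cannot find benchmark text");
    let text = read_text("en_10.txt");
    println!("loaded {} bytes", text.len());
}
```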
141 changes: 74 additions & 67 deletions src/chars.rs
@@ -25,7 +25,7 @@ pub fn from_byte_idx(text: &str, byte_idx: usize) -> usize {
     // Ensure the index is either a char boundary or is off the end of
     // the text.
     let mut i = byte_idx;
-    while Some(true) == bytes.get(i).map(|byte| (*byte & 0xC0) == 0x80) {
+    while Some(1) == bytes.get(i).map(is_trailing_byte) {
         i -= 1;
     }
 
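The new `is_trailing_byte` helper encodes the standard UTF-8 rule that continuation bytes match the bit pattern `10xxxxxx`. A self-contained sketch of the test (the `bool`-returning wrapper below is illustrative; the PR's helper returns `usize` so the results can be summed directly):

```rust
// In UTF-8, every continuation (trailing) byte has the form 10xxxxxx,
// so masking the top two bits with 0xC0 and comparing to 0x80 detects it.
fn is_trailing(byte: u8) -> bool {
    (byte & 0xC0) == 0x80
}

fn main() {
    let bytes = "é".as_bytes(); // [0xC3, 0xA9]
    assert!(!is_trailing(bytes[0])); // 0xC3 = 0b1100_0011: leading byte
    assert!(is_trailing(bytes[1])); // 0xA9 = 0b1010_1001: continuation byte
    println!("ok");
}
```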
@@ -39,56 +39,60 @@ pub fn from_byte_idx(text: &str, byte_idx: usize) -> usize {
 /// Runs in O(N) time.
 #[inline]
 pub fn to_byte_idx(text: &str, char_idx: usize) -> usize {
-    to_byte_idx_impl::<Chunk>(text, char_idx)
+    to_byte_idx_impl::<Chunk>(text.as_bytes(), char_idx)
 }
 
 //-------------------------------------------------------------
 
 #[inline(always)]
-fn to_byte_idx_impl<T: ByteChunk>(text: &str, char_idx: usize) -> usize {
+fn to_byte_idx_impl<T: ByteChunk>(text: &[u8], char_idx: usize) -> usize {
+    if text.len() <= T::SIZE {
+        // Bypass the more complex routine for short strings, where the
+        // complexity hurts performance.
+        let mut char_count = 0;
+        for (i, byte) in text.iter().enumerate() {
+            char_count += is_leading_byte(byte);
+            if char_count > char_idx {
+                return i;
+            }
+        }
+        return text.len();
+    }
     // Get `middle` so we can do more efficient chunk-based counting.
     // We can't use this to get `end`, however, because the start index of
     // `end` actually depends on the accumulating char counts during the
     // counting process.
-    let (start, middle, _) = unsafe { text.as_bytes().align_to::<T>() };
+    let (start, middle, _) = unsafe { text.align_to::<T>() };
 
     let mut byte_count = 0;
     let mut char_count = 0;
 
     // Take care of any unaligned bytes at the beginning.
     for byte in start.iter() {
-        char_count += ((*byte & 0xC0) != 0x80) as usize;
+        char_count += is_leading_byte(byte);
         if char_count > char_idx {
-            break;
+            return byte_count;
         }
         byte_count += 1;
     }
 
-    // Process chunks in the fast path.
-    let mut chunks = middle;
-    let mut max_round_len = char_idx.saturating_sub(char_count) / T::MAX_ACC;
-    while max_round_len > 0 && !chunks.is_empty() {
-        // Choose the largest number of chunks we can do this round
-        // that will neither overflow `max_acc` nor blast past the
-        // char we're looking for.
-        let round_len = T::MAX_ACC.min(max_round_len).min(chunks.len());
-        max_round_len -= round_len;
-        let round = &chunks[..round_len];
-        chunks = &chunks[round_len..];
-
-        // Process the chunks in this round.
-        let mut acc = T::zero();
-        for chunk in round.iter() {
-            acc = acc.add(chunk.bitand(T::splat(0xc0)).cmp_eq_byte(0x80));
-        }
-        char_count += (T::SIZE * round_len) - acc.sum_bytes();
-        byte_count += T::SIZE * round_len;
+    // Process chunks in the fast path. Ensure that we don't go past the number
+    // of chars we are counting towards
+    let fast_path_chunks = middle.len().min((char_idx - char_count) / T::SIZE);
+    let bytes = T::SIZE * 4;
+    for chunks in middle[..fast_path_chunks].chunks_exact(4) {
+        let val1 = count_trailing_chunk(chunks[0]);
+        let val2 = count_trailing_chunk(chunks[1]);
+        let val3 = count_trailing_chunk(chunks[2]);
+        let val4 = count_trailing_chunk(chunks[3]);
+        char_count += bytes - val1.add(val2).add(val3.add(val4)).sum_bytes();
+        byte_count += bytes;
     }
 
-    // Process chunks in the slow path.
-    for chunk in chunks.iter() {
-        let new_char_count =
-            char_count + T::SIZE - chunk.bitand(T::splat(0xc0)).cmp_eq_byte(0x80).sum_bytes();
+    // Process the rest of chunks in the slow path.
+    let rest_idx = fast_path_chunks - fast_path_chunks % 4;
+    for chunk in middle[rest_idx..].iter() {
+        let new_char_count = char_count + T::SIZE - count_trailing_chunk(*chunk).sum_bytes();
         if new_char_count >= char_idx {
             break;
         }
@@ -97,9 +101,9 @@ fn to_byte_idx_impl<T: ByteChunk>(text: &str, char_idx: usize) -> usize {
     }
 
     // Take care of any unaligned bytes at the end.
-    let end = &text.as_bytes()[byte_count..];
+    let end = &text[byte_count..];
     for byte in end.iter() {
-        char_count += ((*byte & 0xC0) != 0x80) as usize;
+        char_count += is_leading_byte(byte);
         if char_count > char_idx {
             break;
         }
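The fast-path bound above works because a chunk of `T::SIZE` bytes can start at most `T::SIZE` characters, so consuming `(char_idx - char_count) / T::SIZE` whole chunks can never step past the target character; the slow path then advances chunk by chunk, and the final byte loop finishes the search. A scalar model of that two-phase structure (illustrative only; the real code counts per chunk with `ByteChunk` operations, and the name `to_byte_idx_scalar` is invented here):

```rust
// Scalar sketch of the fast/slow split in to_byte_idx_impl.
// SIZE stands in for T::SIZE; leading bytes are those not matching 10xxxxxx.
fn to_byte_idx_scalar(text: &[u8], char_idx: usize) -> usize {
    const SIZE: usize = 8;
    let mut byte_count = 0;
    let mut char_count = 0;

    // Fast path: whole chunks that provably cannot overshoot, because
    // each SIZE-byte chunk starts at most SIZE characters.
    let safe_chunks = (text.len() / SIZE).min(char_idx / SIZE);
    for chunk in text[..safe_chunks * SIZE].chunks_exact(SIZE) {
        char_count += chunk.iter().filter(|&&b| (b & 0xC0) != 0x80).count();
        byte_count += SIZE;
    }

    // Slow path: finish byte by byte, stopping at the target char's start.
    for &b in &text[byte_count..] {
        if (b & 0xC0) != 0x80 {
            if char_count == char_idx {
                return byte_count;
            }
            char_count += 1;
        }
        byte_count += 1;
    }
    byte_count // char_idx was past the end; return the byte length
}

fn main() {
    let s = "abc日本語def";
    assert_eq!(to_byte_idx_scalar(s.as_bytes(), 4), 6); // '本' starts at byte 6
    println!("ok");
}
```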
@@ -114,46 +118,49 @@ pub(crate) fn count_impl<T: ByteChunk>(text: &[u8]) -> usize {
     if text.len() < T::SIZE {
         // Bypass the more complex routine for short strings, where the
         // complexity hurts performance.
-        text.iter()
-            .map(|byte| ((byte & 0xC0) != 0x80) as usize)
-            .sum()
-    } else {
-        // Get `middle` for more efficient chunk-based counting.
-        let (start, middle, end) = unsafe { text.align_to::<T>() };
-
-        let mut inv_count = 0;
-
-        // Take care of unaligned bytes at the beginning.
-        for byte in start.iter() {
-            inv_count += ((byte & 0xC0) == 0x80) as usize;
-        }
-
-        #[inline(always)]
-        fn char_boundaries<T: ByteChunk>(val: T) -> T {
-            val.bitand(T::splat(0xc0)).cmp_eq_byte(0x80)
-        }
+        return text.iter().map(is_leading_byte).sum();
+    }
+    // Get `middle` for more efficient chunk-based counting.
+    let (start, middle, end) = unsafe { text.align_to::<T>() };
 
-        // Take care of the middle bytes in big chunks. Loop unrolled.
-        for chunks in middle.chunks_exact(4) {
-            let val1 = char_boundaries(chunks[0]);
-            let val2 = char_boundaries(chunks[1]);
-            let val3 = char_boundaries(chunks[2]);
-            let val4 = char_boundaries(chunks[3]);
-            inv_count += val1.add(val2).add(val3.add(val4)).sum_bytes();
-        }
-        let mut acc = T::zero();
-        for chunk in middle.chunks_exact(4).remainder() {
-            acc = acc.add(char_boundaries(*chunk));
-        }
-        inv_count += acc.sum_bytes();
+    let mut inv_count = 0;
 
-        // Take care of unaligned bytes at the end.
-        for byte in end.iter() {
-            inv_count += ((byte & 0xC0) == 0x80) as usize;
-        }
+    // Take care of unaligned bytes at the beginning.
+    inv_count += start.iter().map(is_trailing_byte).sum::<usize>();
 
-        text.len() - inv_count
-    }
+    // Take care of the middle bytes in big chunks. Loop unrolled.
+    for chunks in middle.chunks_exact(4) {
+        let val1 = count_trailing_chunk(chunks[0]);
+        let val2 = count_trailing_chunk(chunks[1]);
+        let val3 = count_trailing_chunk(chunks[2]);
+        let val4 = count_trailing_chunk(chunks[3]);
+        inv_count += val1.add(val2).add(val3.add(val4)).sum_bytes();
+    }
+    let mut acc = T::zero();
+    for chunk in middle.chunks_exact(4).remainder() {
+        acc = acc.add(count_trailing_chunk(*chunk));
+    }
+    inv_count += acc.sum_bytes();
+
+    // Take care of unaligned bytes at the end.
+    inv_count += end.iter().map(is_trailing_byte).sum::<usize>();
+
+    text.len() - inv_count
 }
 
+#[inline(always)]
+fn is_leading_byte(byte: &u8) -> usize {
+    ((byte & 0xC0) != 0x80) as usize
+}
+
+#[inline(always)]
+fn is_trailing_byte(byte: &u8) -> usize {
+    ((byte & 0xC0) == 0x80) as usize
+}
+
+#[inline(always)]
+fn count_trailing_chunk<T: ByteChunk>(val: T) -> T {
+    val.bitand(T::splat(0xc0)).cmp_eq_byte(0x80)
+}

//=============================================================
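The strategy of `count_impl` is unchanged by the refactor: count continuation bytes and subtract, since every character contributes exactly one leading byte. A scalar model of that invariant (the crate's code performs the same count per `ByteChunk`, unrolled four chunks at a time; `count_chars` is an invented name for illustration):

```rust
// Scalar sketch of count_impl: chars = total bytes - continuation bytes.
fn count_chars(text: &[u8]) -> usize {
    let inv_count = text.iter().filter(|&&b| (b & 0xC0) == 0x80).count();
    text.len() - inv_count
}

fn main() {
    assert_eq!(count_chars("abc".as_bytes()), 3);
    assert_eq!(count_chars("日本語".as_bytes()), 3); // 9 bytes, 6 continuations
    println!("ok");
}
```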