Commit

chore: support split by text len (#1002)
* chore: support split by text len

* chore: update docs

* chore: update tests
appflowy authored Nov 17, 2024
1 parent dcbc84d commit d798c81
Showing 7 changed files with 393 additions and 300 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions services/appflowy-collaborate/Cargo.toml
@@ -88,6 +88,8 @@ itertools = "0.12.0"
validator = "0.16.1"
rayon.workspace = true
tiktoken-rs = "0.6.0"
unicode-segmentation = "1.9.0"


[dev-dependencies]
rand = "0.8.5"
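
The unicode-segmentation dependency added above is presumably what the new length-based splitter uses to avoid cutting multibyte characters in half; the actual split_text_by_max_content_len lives in services/appflowy-collaborate/src/indexer/open_ai.rs and is not shown in this diff (its call site below suggests it takes ownership of the String and returns a Result). A minimal, hypothetical sketch of such a splitter, assuming it packs extended grapheme clusters into chunks under a byte budget; the _sketch suffix marks it as made up for illustration:

use unicode_segmentation::UnicodeSegmentation;

// Hypothetical sketch, not the commit's implementation: pack extended grapheme
// clusters into chunks of at most `max_content_len` bytes so no character is
// ever split mid-sequence.
fn split_text_by_max_content_len_sketch(content: &str, max_content_len: usize) -> Vec<String> {
  let mut chunks = Vec::new();
  let mut current = String::new();
  for grapheme in content.graphemes(true) {
    // Flush the current chunk before it would overflow the byte budget.
    if !current.is_empty() && current.len() + grapheme.len() > max_content_len {
      chunks.push(std::mem::take(&mut current));
    }
    current.push_str(grapheme);
  }
  if !current.is_empty() {
    chunks.push(current);
  }
  chunks
}
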
322 changes: 25 additions & 297 deletions services/appflowy-collaborate/src/indexer/document_indexer.rs
@@ -15,6 +15,8 @@ use collab_document::error::DocumentError;
use collab_entity::CollabType;
use database_entity::dto::{AFCollabEmbeddingParams, AFCollabEmbeddings, EmbeddingContentType};

use crate::config::get_env_var;
use crate::indexer::open_ai::{split_text_by_max_content_len, split_text_by_max_tokens};
use tiktoken_rs::CoreBPE;
use tracing::trace;
use uuid::Uuid;
@@ -54,12 +56,11 @@ impl Indexer for DocumentIndexer {
match result {
Ok(document_data) => {
let content = document_data.to_plain_text();
let max_tokens = self.embedding_model.default_dimensions() as usize;
create_embedding(
object_id,
content,
CollabType::Document,
max_tokens,
&self.embedding_model,
self.tokenizer.clone(),
)
.await
@@ -129,47 +130,35 @@
}
}

/// ## Execution Time Comparison Results
///
/// The following results were observed when running `execution_time_comparison_tests`:
///
/// | Content Size (chars) | Direct Time (ms) | spawn_blocking Time (ms) |
/// |-----------------------|------------------|--------------------------|
/// | 500 | 1 | 1 |
/// | 1000 | 2 | 2 |
/// | 2000 | 5 | 5 |
/// | 5000 | 11 | 11 |
/// | 20000 | 49 | 48 |
///
/// ## Guidelines for Using `spawn_blocking`
///
/// - **Short Tasks (< 1 ms)**:
/// Use direct execution on the async runtime. The minimal execution time has negligible impact.
///
/// - **Moderate Tasks (1–10 ms)**:
/// - For infrequent or low-concurrency tasks, direct execution is acceptable.
/// - For frequent or high-concurrency tasks, consider using `spawn_blocking` to avoid delays.
///
/// - **Long Tasks (> 10 ms)**:
/// Always offload to a blocking thread with `spawn_blocking` to maintain runtime efficiency and responsiveness.
///
/// Related blog:
/// https://tokio.rs/blog/2020-04-preemption
/// https://ryhl.io/blog/async-what-is-blocking/
async fn create_embedding(
object_id: String,
content: String,
collab_type: CollabType,
max_tokens: usize,
embedding_model: &EmbeddingModel,
tokenizer: Arc<CoreBPE>,
) -> Result<Vec<AFCollabEmbeddingParams>, AppError> {
let split_contents = if content.len() < 500 {
split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())?
let use_tiktoken = get_env_var("APPFLOWY_AI_CONTENT_SPLITTER_TIKTOKEN", "false")
.parse::<bool>()
.unwrap_or(false);

let split_contents = if use_tiktoken {
let max_tokens = embedding_model.default_dimensions() as usize;
if content.len() < 500 {
split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())?
} else {
tokio::task::spawn_blocking(move || {
split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())
})
.await??
}
} else {
tokio::task::spawn_blocking(move || {
split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())
})
.await??
debug_assert!(matches!(
embedding_model,
EmbeddingModel::TextEmbedding3Small
));
// We assume that every token is ~4 bytes. We're going to split document content into fragments
// of ~2000 tokens each.
split_text_by_max_content_len(content, 8000)?
};

Ok(
@@ -186,264 +175,3 @@ async fn create_embedding(
.collect(),
)
}

fn split_text_by_max_tokens(
content: String,
max_tokens: usize,
tokenizer: &CoreBPE,
) -> Result<Vec<String>, AppError> {
if content.is_empty() {
return Ok(vec![]);
}

let token_ids = tokenizer.encode_ordinary(&content);
let total_tokens = token_ids.len();
if total_tokens <= max_tokens {
return Ok(vec![content]);
}

let mut chunks = Vec::new();
let mut start_idx = 0;
while start_idx < total_tokens {
let mut end_idx = (start_idx + max_tokens).min(total_tokens);
let mut decoded = false;
// Try to decode the chunk, adjust end_idx if decoding fails
while !decoded {
let token_chunk = &token_ids[start_idx..end_idx];
// Attempt to decode the current chunk
match tokenizer.decode(token_chunk.to_vec()) {
Ok(chunk_text) => {
chunks.push(chunk_text);
start_idx = end_idx;
decoded = true;
},
Err(_) => {
// If we can extend the chunk, do so
if end_idx < total_tokens {
end_idx += 1;
} else if start_idx + 1 < total_tokens {
// Skip the problematic token at start_idx
start_idx += 1;
end_idx = (start_idx + max_tokens).min(total_tokens);
} else {
// Cannot decode any further, break to avoid infinite loop
start_idx = total_tokens;
break;
}
},
}
}
}

Ok(chunks)
}

#[cfg(test)]
mod tests {
use crate::indexer::document_indexer::split_text_by_max_tokens;

use tiktoken_rs::cl100k_base;

#[test]
fn test_split_at_non_utf8() {
let max_tokens = 10; // Small number for testing

// Content with multibyte characters (emojis)
let content = "Hello 😃 World 🌍! This is a test 🚀.".to_string();
let tokenizer = cl100k_base().unwrap();
let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap();

// Ensure that we didn't split in the middle of a multibyte character
for content in params {
assert!(content.is_char_boundary(0));
assert!(content.is_char_boundary(content.len()));
}
}
#[test]
fn test_exact_boundary_split() {
let max_tokens = 5; // Set to 5 tokens for testing
let content = "The quick brown fox jumps over the lazy dog".to_string();
let tokenizer = cl100k_base().unwrap();
let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap();

let total_tokens = tokenizer.encode_ordinary(&content).len();
let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens;
assert_eq!(params.len(), expected_fragments);
}

#[test]
fn test_content_shorter_than_max_len() {
let max_tokens = 100;
let content = "Short content".to_string();
let tokenizer = cl100k_base().unwrap();
let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap();

assert_eq!(params.len(), 1);
assert_eq!(params[0], content);
}

#[test]
fn test_empty_content() {
let max_tokens = 10;
let content = "".to_string();
let tokenizer = cl100k_base().unwrap();
let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap();

assert_eq!(params.len(), 0);
}

#[test]
fn test_content_with_only_multibyte_characters() {
let max_tokens = 1; // Set to 1 token for testing
let content = "😀😃😄😁😆".to_string();
let tokenizer = cl100k_base().unwrap();
let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap();

let emojis: Vec<String> = content.chars().map(|c| c.to_string()).collect();
for (param, emoji) in params.iter().zip(emojis.iter()) {
assert_eq!(param, emoji);
}
}

#[test]
fn test_split_with_combining_characters() {
let max_tokens = 1; // Set to 1 token for testing
let content = "a\u{0301}e\u{0301}i\u{0301}o\u{0301}u\u{0301}".to_string(); // "áéíóú"
let tokenizer = cl100k_base().unwrap();
let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap();

let total_tokens = tokenizer.encode_ordinary(&content).len();
assert_eq!(params.len(), total_tokens);

let reconstructed_content = params.join("");
assert_eq!(reconstructed_content, content);
}

#[test]
fn test_large_content() {
let max_tokens = 1000;
let content = "a".repeat(5000); // 5000 characters
let tokenizer = cl100k_base().unwrap();
let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap();

let total_tokens = tokenizer.encode_ordinary(&content).len();
let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens;
assert_eq!(params.len(), expected_fragments);
}

#[test]
fn test_non_ascii_characters() {
let max_tokens = 2;
let content = "áéíóú".to_string();
let tokenizer = cl100k_base().unwrap();
let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap();

let total_tokens = tokenizer.encode_ordinary(&content).len();
let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens;
assert_eq!(params.len(), expected_fragments);

let reconstructed_content: String = params.concat();
assert_eq!(reconstructed_content, content);
}

#[test]
fn test_content_with_leading_and_trailing_whitespace() {
let max_tokens = 3;
let content = " abcde ".to_string();
let tokenizer = cl100k_base().unwrap();
let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap();

let total_tokens = tokenizer.encode_ordinary(&content).len();
let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens;
assert_eq!(params.len(), expected_fragments);

let reconstructed_content: String = params.concat();
assert_eq!(reconstructed_content, content);
}

#[test]
fn test_content_with_multiple_zero_width_joiners() {
let max_tokens = 1;
let content = "👩‍👩‍👧‍👧👨‍👨‍👦‍👦".to_string();
let tokenizer = cl100k_base().unwrap();
let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap();

let reconstructed_content: String = params.concat();
assert_eq!(reconstructed_content, content);
}

#[test]
fn test_content_with_long_combining_sequences() {
let max_tokens = 1;
let content = "a\u{0300}\u{0301}\u{0302}\u{0303}\u{0304}".to_string();
let tokenizer = cl100k_base().unwrap();
let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap();

let reconstructed_content: String = params.concat();
assert_eq!(reconstructed_content, content);
}
}

// #[cfg(test)]
// mod execution_time_comparison_tests {
// use crate::indexer::document_indexer::split_text_by_max_tokens;
// use rand::distributions::Alphanumeric;
// use rand::{thread_rng, Rng};
// use std::sync::Arc;
// use std::time::Instant;
// use tiktoken_rs::{cl100k_base, CoreBPE};
//
// #[tokio::test]
// async fn test_execution_time_comparison() {
// let tokenizer = Arc::new(cl100k_base().unwrap());
// let max_tokens = 100;
//
// let sizes = vec![500, 1000, 2000, 5000, 20000]; // Content sizes to test
// for size in sizes {
// let content = generate_random_string(size);
//
// // Measure direct execution time
// let direct_time = measure_direct_execution(content.clone(), max_tokens, &tokenizer);
//
// // Measure spawn_blocking execution time
// let spawn_blocking_time =
// measure_spawn_blocking_execution(content, max_tokens, Arc::clone(&tokenizer)).await;
//
// println!(
// "Content Size: {} | Direct Time: {}ms | spawn_blocking Time: {}ms",
// size, direct_time, spawn_blocking_time
// );
// }
// }
//
// // Measure direct execution time
// fn measure_direct_execution(content: String, max_tokens: usize, tokenizer: &CoreBPE) -> u128 {
// let start = Instant::now();
// split_text_by_max_tokens(content, max_tokens, tokenizer).unwrap();
// start.elapsed().as_millis()
// }
//
// // Measure `spawn_blocking` execution time
// async fn measure_spawn_blocking_execution(
// content: String,
// max_tokens: usize,
// tokenizer: Arc<CoreBPE>,
// ) -> u128 {
// let start = Instant::now();
// tokio::task::spawn_blocking(move || {
// split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref()).unwrap()
// })
// .await
// .unwrap();
// start.elapsed().as_millis()
// }
//
// pub fn generate_random_string(len: usize) -> String {
// let rng = thread_rng();
// rng
// .sample_iter(&Alphanumeric)
// .take(len)
// .map(char::from)
// .collect()
// }
// }
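
The doc comment kept above create_embedding spells out when to offload work with spawn_blocking. As a standalone illustration of that guideline, not part of the commit (expensive_split and split_respecting_runtime are made-up names), the pattern is: run short inputs inline on the async runtime and push longer ones onto tokio's blocking pool.

use std::sync::Arc;
use tiktoken_rs::CoreBPE;

// Stand-in for CPU-bound work; here it just tokenizes and returns the token count.
fn expensive_split(content: &str, tokenizer: &CoreBPE) -> usize {
  tokenizer.encode_ordinary(content).len()
}

async fn split_respecting_runtime(
  content: String,
  tokenizer: Arc<CoreBPE>,
) -> Result<usize, tokio::task::JoinError> {
  if content.len() < 500 {
    // Short input (well under 1 ms of work): run it directly on the async runtime.
    Ok(expensive_split(&content, tokenizer.as_ref()))
  } else {
    // Longer input: offload it so async worker threads stay responsive.
    tokio::task::spawn_blocking(move || expensive_split(&content, tokenizer.as_ref())).await
  }
}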
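
The non-tiktoken branch above relies on the comment's assumption that a token averages roughly 4 bytes, so the 8000-byte budget passed to split_text_by_max_content_len works out to about 2000 tokens per fragment. One way to sanity-check that heuristic with the same cl100k_base tokenizer the tests already use (the sample text below is arbitrary, not from the commit):

use tiktoken_rs::cl100k_base;

fn main() {
  // Arbitrary English sample; typical prose lands around 3-5 bytes per token.
  let sample = "AppFlowy splits document content into fragments before requesting \
                embeddings so that each fragment stays within the model's limits. "
    .repeat(50);
  let tokenizer = cl100k_base().unwrap();
  let tokens = tokenizer.encode_ordinary(&sample).len();
  println!(
    "{} bytes / {} tokens = {:.2} bytes per token",
    sample.len(),
    tokens,
    sample.len() as f64 / tokens as f64
  );
}
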
2 changes: 2 additions & 0 deletions services/appflowy-collaborate/src/indexer/mod.rs
@@ -1,6 +1,8 @@
mod document_indexer;
mod ext;
mod open_ai;
mod provider;

pub use document_indexer::DocumentIndexer;
pub use ext::DocumentDataExt;
pub use provider::*;