diff --git a/src/datatrove/pipeline/dedup/fast_mh3/Cargo.toml b/src/datatrove/pipeline/dedup/fast_mh3/Cargo.toml
index f16ded56..3b9e5759 100644
--- a/src/datatrove/pipeline/dedup/fast_mh3/Cargo.toml
+++ b/src/datatrove/pipeline/dedup/fast_mh3/Cargo.toml
@@ -4,8 +4,6 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
-indicatif = "0.17.7"
-tokio-retry = "0.3"
 # AWS SDK
 aws-config = { version = "1.1.1", features = ["behavior-version-latest"] }
 aws-sdk-s3 = "1.1.1"
@@ -19,9 +17,11 @@ anyhow = "1.0.75"
 # Byte reading/writing
 byteorder = "1.5.0"
 
+# Progress bars
+indicatif = "0.17.7"
+
 # Async runtime and utilities
 tokio = { version = "1.33.0", features = ["full"] }
-futures = "0.3"
 
-# Sorting
-itertools = "0.11.0"
\ No newline at end of file
+# Retries
+tokio-retry = "0.3"
\ No newline at end of file
diff --git a/src/datatrove/pipeline/dedup/fast_mh3/src/main.rs b/src/datatrove/pipeline/dedup/fast_mh3/src/main.rs
index 51092eaa..3b4b9cf6 100644
--- a/src/datatrove/pipeline/dedup/fast_mh3/src/main.rs
+++ b/src/datatrove/pipeline/dedup/fast_mh3/src/main.rs
@@ -12,6 +12,7 @@ use std::sync::{Arc, Mutex};
 use tokio_retry::Retry;
 use tokio_retry::strategy::{ExponentialBackoff, jitter};
 use std::time::Duration;
+use tokio::sync::Semaphore;
 
 async fn with_retry<T, F, Fut>(f: F) -> Result<T>
 where
@@ -42,6 +43,10 @@ struct Args {
     /// Total number of files to process
     #[arg(long)]
     total_files: usize,
+
+    /// Maximum number of concurrent downloads (-1 for unlimited)
+    #[arg(long, default_value = "-1")]
+    downloads: isize,
 }
 
 #[derive(Debug, Clone)]
@@ -68,51 +73,24 @@ impl S3Path {
 }
 
 #[derive(Debug)]
-struct UnionFind {
+struct UnionFindData {
     union_set: HashMap<(u32, u32), (u32, u32)>,
     set_size: HashMap<(u32, u32), usize>,
 }
 
+#[derive(Debug)]
+struct UnionFind {
+    data: Arc<Mutex<UnionFindData>>,
+}
+
 impl UnionFind {
     fn new() -> Self {
         UnionFind {
-            union_set: HashMap::new(),
-            set_size: HashMap::new(),
-        }
-    }
-
-    fn find_parent(&mut self, x: (u32, u32)) -> (u32, u32) {
-        if !self.union_set.contains_key(&x) || self.union_set.get(&x) == Some(&x) {
-            self.union_set.insert(x, x);
-            return x;
+            data: Arc::new(Mutex::new(UnionFindData {
+                union_set: HashMap::new(),
+                set_size: HashMap::new(),
+            })),
         }
-
-        // Get parent and recurse
-        let parent = *self.union_set.get(&x).unwrap();
-        let root = self.find_parent(parent);
-        self.union_set.insert(x, root);
-        root
-    }
-
-    fn union(&mut self, v_a: (u32, u32), v_b: (u32, u32)) {
-        let mut root_a = self.find_parent(v_a);
-        let mut root_b = self.find_parent(v_b);
-
-        if root_a == root_b {
-            return;
-        }
-
-        let size_a = *self.set_size.get(&root_a).unwrap_or(&1);
-        let size_b = *self.set_size.get(&root_b).unwrap_or(&1);
-
-        if size_a < size_b {
-            std::mem::swap(&mut root_a, &mut root_b);
-        }
-
-        self.union_set.insert(root_b, root_a);
-        let new_size = size_a + size_b;
-        self.set_size.insert(root_a, new_size);
-        self.set_size.remove(&root_b);
     }
 }
@@ -239,7 +217,6 @@ async fn list_s3_files(client: &Client, s3_path: &S3Path, total_files: usize) -> Result<Vec<String>> {
         .map(|key| format!("s3://{}/{}", s3_path.bucket, key)))
         .collect();
 
-    // Sort files lexicographically
     files.sort();
 
     if files.len() != total_files {
@@ -284,77 +261,44 @@ async fn download_and_parse_file(client: &Client, file_path: &str) -> Result<Vec<(u32, u32, u32, u32)>> {
 
-async fn process_post_union(
-    client: &Client,
-    output_path: &S3Path,
-    union_find: &mut UnionFind,
-) -> Result<(usize, usize)> {
-    // Group docs by file number (just the doc IDs)
-    let mut nodes_by_file: HashMap<u32, Vec<u32>> = HashMap::new();
-    for &(file, doc) in union_find.union_set.keys() {
-        nodes_by_file.entry(file).or_default().push(doc);
-    }
-
-    // Wrap our maps in Arc for shared read-only access
-    let union_set = Arc::new(union_find.union_set.clone()); // Clone happens only once
-    let sizes = Arc::new(union_find.set_size.clone()); // Clone happens only once
-
-    println!("Processing {} files in parallel...", nodes_by_file.len());
-    let pb = ProgressBar::new(union_find.union_set.len() as u64);
-    pb.set_style(ProgressStyle::default_bar()
-        .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
-        .unwrap()
-        .progress_chars("#>-"));
-
-    let mut handles = Vec::new();
-    for (file_number, docs) in nodes_by_file {
-        let client = client.clone();
-        let output_path = output_path.clone();
-        let union_set = Arc::clone(&union_set); // Just clones the Arc, not the data
-        let sizes = Arc::clone(&sizes); // Just clones the Arc, not the data
-        let pb = pb.clone();
-
-        let handle = task::spawn(async move {
-            process_single_file(
-                &client,
-                &output_path,
-                file_number,
-                docs,
-                &union_set,
-                &sizes,
-                &pb,
-            ).await
-        });
-        handles.push(handle);
-    }
-
-    let mut total_to_remove = 0;
-    let mut total_clusters = 0;
-
-    for handle in handles {
-        let (to_remove, clusters) = handle.await??;
-        total_to_remove += to_remove;
-        total_clusters += clusters;
-    }
-
-    pb.finish_with_message("Output writing complete");
-
-    Ok((total_to_remove, total_clusters))
-}
-
 async fn process_single_file(
     client: &Client,
     output_path: &S3Path,
     file_number: u32,
-    docs: Vec<u32>, // just the doc_ids for this file
-    union_set: &HashMap<(u32, u32), (u32, u32)>,
-    sizes: &HashMap<(u32, u32), usize>,
+    union_find: &UnionFind,
     pb: &ProgressBar,
 ) -> Result<(usize, usize)> {
     let mut to_remove = 0;
     let mut clusters = 0;
     const BUFFER_THRESHOLD: usize = 5 * 1024 * 1024;
 
+    // Collect all the data we need under one lock
+    let nodes_data = {
+        let data = union_find.data.lock().unwrap();
+
+        // Collect docs and their data
+        let mut docs = data.union_set.keys()
+            .filter(|(f, _)| *f == file_number)
+            .map(|(_, d)| *d)
+            .collect::<Vec<_>>();
+        docs.sort_unstable();
+
+        // For each doc, collect its root and size
+        docs.into_iter().map(|doc| {
+            let node = (file_number, doc);
+            let mut current = node;
+            while let Some(&parent) = data.union_set.get(&current) {
+                if parent == current {
+                    break;
+                }
+                current = parent;
+            }
+            let root = current;
+            let size = *data.set_size.get(&root).unwrap_or(&1);
+            (doc, root, size)
+        }).collect::<Vec<_>>()
+    }; // Lock is released here
+
     let mut sizes_writer = S3StreamWriter::new(
         client,
         &output_path.bucket,
@@ -369,22 +313,9 @@ async fn process_single_file(
         BUFFER_THRESHOLD,
     ).await?;
 
-    // Find root function
-    let find_root = |node: (u32, u32)| {
-        let mut current = node;
-        while let Some(&parent) = union_set.get(&current) {
-            if parent == current {
-                break;
-            }
-            current = parent;
-        }
-        current
-    };
-
-    for doc in docs {
+    // Process the collected data without holding the lock
+    for (doc, root, size) in nodes_data {
         let node = (file_number, doc);
-        let parent = find_root(node);
-        let size = *sizes.get(&parent).unwrap_or(&1);
 
         // Write sizes
         let mut buffer = Vec::new();
@@ -393,14 +324,14 @@ async fn process_single_file(
         sizes_writer.write(&buffer).await?;
 
         // Handle removal markers
-        if node != parent {
+        if node != root {
             let mut remove_buffer = Vec::new();
             remove_buffer.write_u32::<LittleEndian>(doc)?;
             remove_writer.write(&remove_buffer).await?;
             to_remove += 1;
         }
 
-        if node == parent {
+        if node == root {
             clusters += 1;
         }
 
@@ -413,6 +344,62 @@ async fn process_single_file(
     Ok((to_remove, clusters))
 }
 
+async fn process_post_union(
+    client: &Client,
+    output_path: &S3Path,
+    union_find: &UnionFind,
+) -> Result<(usize, usize)> {
+    // Get list of unique file numbers
+    let data = union_find.data.lock().unwrap();
+    let files: Vec<_> = data.union_set.keys()
+        .map(|(f, _)| *f)
+        .collect::<HashSet<_>>()
+        .into_iter()
+        .collect();
+//     files.sort_unstable();
+    let total_nodes = data.union_set.len();
+    drop(data);
+
+    println!("Processing {} files in parallel...", files.len());
+    let pb = ProgressBar::new(total_nodes as u64);
+    pb.set_style(ProgressStyle::default_bar()
+        .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
+        .unwrap()
+        .progress_chars("#>-"));
+
+    let mut handles = Vec::new();
+    for file_number in files {
+        let client = client.clone();
+        let output_path = output_path.clone();
+        let union_find = Arc::clone(&union_find.data);
+        let pb = pb.clone();
+
+        let handle = task::spawn(async move {
+            process_single_file(
+                &client,
+                &output_path,
+                file_number,
+                &UnionFind { data: union_find },
+                &pb,
+            ).await
+        });
+        handles.push(handle);
+    }
+
+    let mut total_to_remove = 0;
+    let mut total_clusters = 0;
+
+    for handle in handles {
+        let (to_remove, clusters) = handle.await??;
+        total_to_remove += to_remove;
+        total_clusters += clusters;
+    }
+
+    pb.finish_with_message("Output writing complete");
+
+    Ok((total_to_remove, total_clusters))
+}
+
 #[tokio::main]
 async fn main() -> Result<()> {
     let args = Args::parse();
@@ -425,8 +412,12 @@
     let files = list_s3_files(&client, &input_path, args.total_files).await?;
 
-    let union_find = Arc::new(Mutex::new(UnionFind::new()));
-
+    let union_find = UnionFind::new();
+    let semaphore = Arc::new(if args.downloads == -1 {
+        Semaphore::new(args.total_files) // Effectively unlimited
+    } else {
+        Semaphore::new(args.downloads as usize)
+    });
     println!("Processing {} input files...", files.len());
     let pb = ProgressBar::new(files.len() as u64);
     pb.set_style(ProgressStyle::default_bar()
@@ -434,45 +425,90 @@
         .unwrap()
         .progress_chars("#>-"));
 
-//    let mut handles = Vec::new();
+    let mut handles = Vec::new();
     for file_path in files {
         let client = client.clone();
-        let union_find = Arc::clone(&union_find);
+        let union_find = Arc::clone(&union_find.data);
         let pb = pb.clone();
+        let semaphore = Arc::clone(&semaphore);
 
-        let tuples = download_and_parse_file(&client, &file_path).await?;
-        let mut uf = union_find.lock().unwrap();
-        for (f1, d1, f2, d2) in tuples {
-            uf.union((f1, d1), (f2, d2));
-        }
-        pb.inc(1);
-//
-//
-//        let handle = task::spawn(async move {
-//            let tuples = download_and_parse_file(&client, &file_path).await?;
-//            let mut uf = union_find.lock().unwrap();
-//            for (f1, d1, f2, d2) in tuples {
-//                uf.union((f1, d1), (f2, d2));
-//            }
-//            pb.inc(1);
-//            Ok::<(), anyhow::Error>(())
-//        });
-
-//        handles.push(handle);
+        let handle = task::spawn(async move {
+            let _permit = semaphore.acquire().await?;
+            let tuples = download_and_parse_file(&client, &file_path).await?;
+
+            let mut data = union_find.lock().unwrap();
+            for (f1, d1, f2, d2) in tuples {
+                let v_a = (f1, d1);
+                let v_b = (f2, d2);
+
+                let root_a = {
+                    let mut current = v_a;
+                    let mut path = Vec::new();
+                    while let Some(&parent) = data.union_set.get(&current) {
+                        if parent == current {
+                            break;
+                        }
+                        path.push(current);
+                        current = parent;
+                    }
+                    if !data.union_set.contains_key(&current) {
+                        data.union_set.insert(current, current);
+                    }
+                    for node in path {
+                        data.union_set.insert(node, current);
+                    }
+                    current
+                };
+
+                let root_b = {
+                    let mut current = v_b;
+                    let mut path = Vec::new();
+                    while let Some(&parent) = data.union_set.get(&current) {
+                        if parent == current {
+                            break;
+                        }
+                        path.push(current);
+                        current = parent;
+                    }
+                    if !data.union_set.contains_key(&current) {
+                        data.union_set.insert(current, current);
+                    }
+                    for node in path {
+                        data.union_set.insert(node, current);
+                    }
+                    current
+                };
+
+                if root_a != root_b {
+                    let size_a = *data.set_size.get(&root_a).unwrap_or(&1);
+                    let size_b = *data.set_size.get(&root_b).unwrap_or(&1);
+
+                    let (big_root, small_root) = if size_a >= size_b {
+                        (root_a, root_b)
+                    } else {
+                        (root_b, root_a)
+                    };
+
+                    data.union_set.insert(small_root, big_root);
+                    data.set_size.insert(big_root, size_a + size_b);
+                    data.set_size.remove(&small_root);
+                }
+            }
+            drop(data);
+            pb.inc(1);
+            Ok::<(), anyhow::Error>(())
+        });
+
+        handles.push(handle);
     }
 
-//    for handle in handles {
-//        handle.await??;
-//    }
+    for handle in handles {
+        handle.await??;
+    }
     pb.finish_with_message("File processing complete");
 
-    let mut union_find = match Arc::try_unwrap(union_find) {
-        Ok(mutex) => mutex.into_inner().unwrap(),
-        Err(_) => panic!("Failed to unwrap Arc, some references still exist"),
-    };
-
-    let (to_remove, clusters) = process_post_union(&client, &output_path, &mut union_find).await?;
+    let (to_remove, clusters) = process_post_union(&client, &output_path, &union_find).await?;
 
     println!("Processing complete:");
     println!("  Total clusters: {}", clusters);
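
The download loop above bounds concurrency with a shared `tokio::sync::Semaphore` sized by the new `--downloads` flag. Below is a minimal, standalone sketch of that pattern, assuming only `tokio` (with the `full` feature) and `anyhow`; the sleep is a stand-in for `download_and_parse_file`, and the counts are arbitrary.

```rust
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tokio::task;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Stand-in for --downloads: at most 8 "downloads" in flight at once.
    let semaphore = Arc::new(Semaphore::new(8));

    let mut handles = Vec::new();
    for i in 0..64u32 {
        let semaphore = Arc::clone(&semaphore);
        handles.push(task::spawn(async move {
            // The permit is held until the end of the task body.
            let _permit = semaphore.acquire().await?;
            tokio::time::sleep(Duration::from_millis(50)).await; // fake download
            println!("finished file {i}");
            Ok::<(), anyhow::Error>(())
        }));
    }

    for handle in handles {
        handle.await??;
    }
    Ok(())
}
```

Acquiring the permit inside the spawned task, as the patch does, lets every task be queued immediately while the semaphore alone decides how many downloads actually run concurrently.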