
Commit

added check
guipenedo committed Nov 26, 2024
1 parent 53b652e commit 1598cce
Showing 1 changed file with 0 additions and 195 deletions.
src/datatrove/pipeline/dedup/fast_mh3/src/check.rs — 195 changes: 0 additions & 195 deletions
@@ -66,10 +66,6 @@ impl S3Path {
prefix: parts[1..].join("/"),
})
}

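// Joins `key` onto the path's prefix, trimming any trailing slash so the result has a single separator.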
fn with_key(&self, key: &str) -> String {
format!("{}/{}", self.prefix.trim_end_matches('/'), key)
}
}

#[derive(Debug)]
@@ -94,110 +90,6 @@ impl UnionFind {
}
}

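// Streams output to S3 via a multipart upload: writes are buffered in memory and
// flushed as individual upload parts once `buffer_threshold` bytes accumulate.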
struct S3StreamWriter {
client: Client,
bucket: String,
key: String,
upload_id: String,
buffer: Vec<u8>,
part_number: i32,
completed_parts: Vec<CompletedPart>,
buffer_threshold: usize,
}

impl S3StreamWriter {
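// Opens the multipart upload (with retries) and returns a writer bound to the given bucket and key.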
async fn new(
client: &Client,
bucket: &str,
key: &str,
buffer_threshold: usize,
) -> Result<Self> {
let create_multipart_upload_output = with_retry(|| async {
client
.create_multipart_upload()
.bucket(bucket)
.key(key)
.send()
.await
.context("Failed to create multipart upload")
}).await?;

Ok(Self {
client: client.clone(),
bucket: bucket.to_string(),
key: key.to_string(),
upload_id: create_multipart_upload_output.upload_id().unwrap().to_string(),
buffer: Vec::new(),
part_number: 1,
completed_parts: Vec::new(),
buffer_threshold,
})
}

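// Appends `data` to the in-memory buffer; once the buffer reaches the threshold it is uploaded as the next part.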
async fn write(&mut self, data: &[u8]) -> Result<()> {
self.buffer.extend_from_slice(data);

if self.buffer.len() >= self.buffer_threshold {
self.flush().await?;
}

Ok(())
}

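// Uploads the buffered bytes as part `part_number` (skipping empty buffers) and
// records the returned ETag for the final completion call.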
async fn flush(&mut self) -> Result<()> {
if self.buffer.is_empty() {
return Ok(());
}

let buffer_clone = self.buffer.clone();
let upload_part_output = with_retry(|| async {
let part_body = ByteStream::from(buffer_clone.clone());
self.client
.upload_part()
.bucket(&self.bucket)
.key(&self.key)
.upload_id(&self.upload_id)
.part_number(self.part_number)
.body(part_body)
.send()
.await
.context("Failed to upload part")
}).await?;

let completed_part = CompletedPart::builder()
.e_tag(upload_part_output.e_tag().unwrap_or_default())
.part_number(self.part_number)
.build();

self.completed_parts.push(completed_part);
self.part_number += 1;
self.buffer.clear();

Ok(())
}

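// Flushes any remaining bytes and completes the multipart upload with the collected parts.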
async fn finalize(mut self) -> Result<()> {
self.flush().await?;

let completed_multipart_upload = CompletedMultipartUpload::builder()
.set_parts(Some(self.completed_parts.clone()))
.build();

with_retry(|| async {
self.client
.complete_multipart_upload()
.bucket(&self.bucket)
.key(&self.key)
.upload_id(&self.upload_id)
.multipart_upload(completed_multipart_upload.clone())
.send()
.await
.context("Failed to complete multipart upload")
}).await?;

Ok(())
}
}

async fn list_s3_files(client: &Client, s3_path: &S3Path, total_files: usize) -> Result<Vec<String>> {
let resp = with_retry(|| async {
@@ -261,90 +153,6 @@ async fn download_and_parse_file(client: &Client, file_path: &str) -> Result<Vec
Ok(tuples)
}

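// For a single input file: resolves each document's cluster root in the union-find,
// streams (doc id, cluster size) pairs to `{file_number:06}.sizes`, and writes the
// doc ids of non-root (duplicate) documents to `{file_number:06}.remove` on S3.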
async fn process_single_file(
client: &Client,
output_path: &S3Path,
file_number: u32,
union_find: &UnionFind,
pb: &ProgressBar,
) -> Result<(usize, usize)> {
let mut to_remove = 0;
let mut clusters = 0;
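// 5 MiB: S3's minimum size for every multipart upload part except the last.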
const BUFFER_THRESHOLD: usize = 5 * 1024 * 1024;

// Collect all the data we need under one lock
let nodes_data = {
let data = union_find.data.lock().unwrap();

// Collect docs and their data
let mut docs = data.union_set.keys()
.filter(|(f, _)| *f == file_number)
.map(|(_, d)| *d)
.collect::<Vec<_>>();
docs.sort_unstable();

// For each doc, collect its root and size
docs.into_iter().map(|doc| {
let node = (file_number, doc);
let mut current = node;
while let Some(&parent) = data.union_set.get(&current) {
if parent == current {
break;
}
current = parent;
}
let root = current;
let size = *data.set_size.get(&root).unwrap_or(&1);
(doc, root, size)
}).collect::<Vec<_>>()
}; // Lock is released here

let mut sizes_writer = S3StreamWriter::new(
client,
&output_path.bucket,
&output_path.with_key(&format!("{:06}.sizes", file_number)),
BUFFER_THRESHOLD,
).await?;

let mut remove_writer = S3StreamWriter::new(
client,
&output_path.bucket,
&output_path.with_key(&format!("{:06}.remove", file_number)),
BUFFER_THRESHOLD,
).await?;

// Process the collected data without holding the lock
for (doc, root, size) in nodes_data {
let node = (file_number, doc);

// Write sizes
let mut buffer = Vec::new();
buffer.write_u32::<LittleEndian>(doc)?;
buffer.write_u32::<LittleEndian>(size as u32)?;
sizes_writer.write(&buffer).await?;

// Handle removal markers
if node != root {
let mut remove_buffer = Vec::new();
remove_buffer.write_u32::<LittleEndian>(doc)?;
remove_writer.write(&remove_buffer).await?;
to_remove += 1;
}

if node == root {
clusters += 1;
}

pb.inc(1);
}

sizes_writer.finalize().await?;
remove_writer.finalize().await?;

Ok((to_remove, clusters))
}


async fn process_single_remove_file(
client: &Client,
remove_file: String,
@@ -495,7 +303,6 @@ async fn main() -> Result<()> {
let client = Client::new(&config);

let input_path = S3Path::from_path(&args.input_folder)?;
let output_path = S3Path::from_path(&args.output_folder)?;

let files = list_s3_files(&client, &input_path, args.total_files).await?;

@@ -600,8 +407,6 @@ async fn main() -> Result<()> {
process_post_union(&client, remove_path, &union_find, max_concurrent).await?;

println!("Processing complete:");
println!(" Total clusters: {}", clusters);
println!(" Documents to remove: {}", to_remove);

Ok(())
}
