diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py
index 54418441f..621e68cd9 100644
--- a/data_juicer/core/ray_data.py
+++ b/data_juicer/core/ray_data.py
@@ -61,15 +61,15 @@ def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset:
     columns = dataset.columns()
     if dataset_path:
         dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
-    # if Fields.stats not in columns:
+    if Fields.stats not in columns:
 
-    #     def process_batch_arrow(table: pa.Table) -> pa.Table:
-    #         new_column_data = [{} for _ in range(len(table))]
-    #         new_talbe = table.append_column(Fields.stats, [new_column_data])
-    #         return new_talbe
+        def process_batch_arrow(table: pa.Table) -> pa.Table:
+            new_column_data = [{} for _ in range(len(table))]
+            new_table = table.append_column(Fields.stats, [new_column_data])
+            return new_table
 
-    #     dataset = dataset.map_batches(process_batch_arrow,
-    #                                   batch_format='pyarrow')
+        dataset = dataset.map_batches(process_batch_arrow,
+                                      batch_format='pyarrow')
 
     return dataset
 
diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
index 20e998e0f..ba87edda9 100644
--- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
+++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
@@ -7,13 +7,14 @@
 import ray
 import numpy as np
 import pyarrow as pa
+import pyarrow.parquet as pq
 import regex
 from loguru import logger
 from pydantic import Field, PositiveInt
 from typing_extensions import Annotated
 from typing import List, Union
 
-from data_juicer.utils.constant import HashKeys
+from data_juicer.utils.constant import HashKeys, Fields
 from data_juicer.utils.model_utils import prepare_sentencepiece_model
 
 from ..base_op import OPERATORS, Deduplicator
@@ -54,6 +55,12 @@ def get_edges(self, key):
 
 @ray.remote(scheduling_strategy="SPREAD")
 class BTSUnionFind:
+    """
+    A distributed implementation of Union-Find with load balancing.
+
+    The original paper on BTS Union-Find is available at:
+    https://ieeexplore.ieee.org/document/10598116
+    """
     def __init__(
         self,
         union_threshold,
@@ -438,6 +445,7 @@ def tokenization_func(text):
         self.tmp_file_name = os.path.join(
             os.getcwd(), tmp_file_name, str(uuid.uuid4())
         )
+        os.makedirs(self.tmp_file_name)
 
         empty_hash_value = np.full(
             (self.num_rows_per_band,),
@@ -577,16 +585,24 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
                 HashKeys.uid,
                 pa.array(list(uid_list))
             )
-            return new_table
+            if not new_table[Fields.stats][0].as_py():
+                columns_to_keep = [
+                    name
+                    for name in new_table.column_names
+                    if name != Fields.stats
+                ]
+                new_table = new_table.select(columns_to_keep)
+            pq.write_table(
+                new_table,
+                os.path.join(self.tmp_file_name, f'{min_id}.parquet')
+            )
+            return pa.Table.from_arrays([])
 
         dataset.map_batches(
             minhash_with_uid,
             batch_format='pyarrow',
             zero_copy_batch=True,
-        ).write_parquet(
-            self.tmp_file_name,
-            force_ascii=False
-        )  # TODO: balance file size
+        ).materialize()
         dataset = ray.data.read_parquet(self.tmp_file_name)
         end_time = time.time()
         print(f'MinHash time = {end_time - start_time}')
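
Reviewer note: the two hunks above interact through the `Fields.stats` column, so here is a minimal standalone sketch of that interaction, assuming only `pyarrow` is installed. It is not part of the patch; the `STATS` constant, the demo table, and the output directory are illustrative placeholders standing in for `Fields.stats`, a real batch, and `self.tmp_file_name`.

```python
import os
import tempfile

import pyarrow as pa
import pyarrow.parquet as pq

STATS = '__dj__stats__'  # placeholder for Fields.stats

batch = pa.table({'text': ['a', 'b', 'c']})

# One empty dict per row, appended as a single chunk: the same pattern
# process_batch_arrow uses to initialize the stats column.
batch = batch.append_column(STATS, [[{} for _ in range(len(batch))]])

# Mirrors the new branch in minhash_with_uid: if the stats are still empty,
# drop the column before writing the batch to its own Parquet file.
if not batch[STATS][0].as_py():
    batch = batch.select([n for n in batch.column_names if n != STATS])

out_dir = tempfile.mkdtemp()
pq.write_table(batch, os.path.join(out_dir, '0.parquet'))
print(pq.read_table(out_dir).column_names)  # -> ['text']
```

Writing each batch to its own file inside `minhash_with_uid` and then calling `.materialize()` replaces the previous `write_parquet` call, so the subsequent `ray.data.read_parquet(self.tmp_file_name)` reads back exactly those per-batch files.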