diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py
index 621e68cd9..2966e75e8 100644
--- a/data_juicer/core/ray_data.py
+++ b/data_juicer/core/ray_data.py
@@ -58,18 +58,8 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg):
 
 
 def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset:
-    columns = dataset.columns()
     if dataset_path:
         dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
-    if Fields.stats not in columns:
-
-        def process_batch_arrow(table: pa.Table) -> pa.Table:
-            new_column_data = [{} for _ in range(len(table))]
-            new_talbe = table.append_column(Fields.stats, [new_column_data])
-            return new_talbe
-
-        dataset = dataset.map_batches(process_batch_arrow,
-                                      batch_format='pyarrow')
     return dataset
 
 
@@ -123,6 +113,16 @@ def _run_single_op(self, op):
                                                       batch_format='pyarrow',
                                                       num_gpus=num_gpus)
             elif isinstance(op, Filter):
+                columns = self.data.columns()
+                if Fields.stats not in columns:
+
+                    def process_batch_arrow(table: pa.Table) -> pa.Table:
+                        new_column_data = [{} for _ in range(len(table))]
+                        new_talbe = table.append_column(Fields.stats, [new_column_data])
+                        return new_talbe
+
+                    self.data = self.data.map_batches(process_batch_arrow,
+                                                      batch_format='pyarrow')
                 self.data = self.data.map_batches(op.compute_stats,
                                                   batch_size=batch_size,
                                                   batch_format='pyarrow',
diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py
index 6b93fd3dd..f146ffc02 100644
--- a/data_juicer/core/ray_executor.py
+++ b/data_juicer/core/ray_executor.py
@@ -68,10 +68,10 @@ def run(self, load_data_np=None):
         logger.info('Processing data...')
         tstart = time.time()
         dataset.process(ops)
-        tend = time.time()
-        logger.info(f'All Ops are done in {tend - tstart:.3f}s.')
 
         # 4. data export
         logger.info('Exporting dataset to disk...')
         dataset.data.write_json(self.cfg.export_path, force_ascii=False)
+        tend = time.time()
+        logger.info(f'All Ops are done in {tend - tstart:.3f}s.')
         return dataset
diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
index ba87edda9..676aa0185 100644
--- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
+++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
@@ -585,24 +585,16 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
                 HashKeys.uid,
                 pa.array(list(uid_list))
             )
-            if not new_table[Fields.stats][0].as_py():
-                columns_to_keep = [
-                    name
-                    for name in new_table.column_names
-                    if name != Fields.stats
-                ]
-                new_table = new_table.select(columns_to_keep)
-            pq.write_table(
-                new_table,
-                os.path.join(self.tmp_file_name, f'{min_id}.parquet')
-            )
-            return pa.Table.from_arrays([])
+            return new_table
 
         dataset.map_batches(
             minhash_with_uid,
             batch_format='pyarrow',
             zero_copy_batch=True,
-        ).materialize()
+        ).write_parquet(
+            self.tmp_file_name,
+            force_ascii=False
+        )
         dataset = ray.data.read_parquet(self.tmp_file_name)
         end_time = time.time()
         print(f'MinHash time = {end_time - start_time}')