
Commit

fix in Fields.stats
chenyushuo committed Dec 20, 2024
1 parent 8dda1aa commit 9d893be
Showing 3 changed files with 17 additions and 25 deletions.
20 changes: 10 additions & 10 deletions data_juicer/core/ray_data.py
@@ -58,18 +58,8 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg):
 
 
 def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset:
-    columns = dataset.columns()
     if dataset_path:
         dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
-    if Fields.stats not in columns:
-
-        def process_batch_arrow(table: pa.Table) -> pa.Table:
-            new_column_data = [{} for _ in range(len(table))]
-            new_table = table.append_column(Fields.stats, [new_column_data])
-            return new_table
-
-        dataset = dataset.map_batches(process_batch_arrow,
-                                      batch_format='pyarrow')
     return dataset


@@ -123,6 +113,16 @@ def _run_single_op(self, op):
                                                   batch_format='pyarrow',
                                                   num_gpus=num_gpus)
             elif isinstance(op, Filter):
+                columns = self.data.columns()
+                if Fields.stats not in columns:
+
+                    def process_batch_arrow(table: pa.Table) -> pa.Table:
+                        new_column_data = [{} for _ in range(len(table))]
+                        new_table = table.append_column(Fields.stats, [new_column_data])
+                        return new_table
+
+                    self.data = self.data.map_batches(process_batch_arrow,
+                                                      batch_format='pyarrow')
                 self.data = self.data.map_batches(op.compute_stats,
                                                   batch_size=batch_size,
                                                   batch_format='pyarrow',
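The block that moved initializes an empty per-row stats column on each pyarrow batch, so a Filter's compute_stats always has a Fields.stats column to write into. A minimal standalone sketch of the same pattern, assuming plain pyarrow (the literal column name 'stats' stands in for data-juicer's Fields.stats constant):

import pyarrow as pa

table = pa.table({'text': ['hello', 'world']})

# One empty dict per row; pyarrow infers an (empty) struct column type.
new_column_data = [{} for _ in range(len(table))]
table = table.append_column('stats', [new_column_data])

print(table.column_names)  # ['text', 'stats']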
4 changes: 2 additions & 2 deletions data_juicer/core/ray_executor.py
@@ -68,10 +68,10 @@ def run(self, load_data_np=None):
         logger.info('Processing data...')
         tstart = time.time()
         dataset.process(ops)
-        tend = time.time()
-        logger.info(f'All Ops are done in {tend - tstart:.3f}s.')
 
         # 4. data export
         logger.info('Exporting dataset to disk...')
         dataset.data.write_json(self.cfg.export_path, force_ascii=False)
+        tend = time.time()
+        logger.info(f'All Ops are done in {tend - tstart:.3f}s.')
         return dataset
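Moving the timer below the export is presumably about Ray's lazy execution: dataset.process(ops) mainly builds the operator pipeline, and most of the work runs only when write_json consumes the dataset, so stopping the clock before the export would under-count. A rough illustration of the effect, assuming the standard ray.data API (the path /tmp/demo_out is made up):

import time

import ray

ds = ray.data.range(1_000_000)

tstart = time.time()
ds = ds.map(lambda row: {'id': row['id'] * 2})  # lazy; returns almost instantly
ds.write_json('/tmp/demo_out')  # execution is actually triggered here
tend = time.time()
print(f'All Ops are done in {tend - tstart:.3f}s.')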
18 changes: 5 additions & 13 deletions data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
@@ -585,24 +585,16 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
                 HashKeys.uid,
                 pa.array(list(uid_list))
             )
-            if not new_table[Fields.stats][0].as_py():
-                columns_to_keep = [
-                    name
-                    for name in new_table.column_names
-                    if name != Fields.stats
-                ]
-                new_table = new_table.select(columns_to_keep)
-            pq.write_table(
-                new_table,
-                os.path.join(self.tmp_file_name, f'{min_id}.parquet')
-            )
-            return pa.Table.from_arrays([])
+            return new_table
 
         dataset.map_batches(
             minhash_with_uid,
             batch_format='pyarrow',
             zero_copy_batch=True,
-        ).materialize()
+        ).write_parquet(
+            self.tmp_file_name,
+            force_ascii=False
+        )
         dataset = ray.data.read_parquet(self.tmp_file_name)
         end_time = time.time()
         print(f'MinHash time = {end_time - start_time}')
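With this change, minhash_with_uid simply returns each uid-tagged batch, and a dataset-level write_parquet / read_parquet round-trip replaces both the in-memory materialize() and the hand-rolled per-batch pq.write_table files. A rough standalone sketch of that round-trip, with hypothetical names (tag_uid, the 'uid' column, /tmp/minhash_tmp):

import pyarrow as pa
import ray

def tag_uid(table: pa.Table) -> pa.Table:
    # Toy stand-in: attach a per-batch uid column to each pyarrow table.
    return table.append_column('uid', pa.array(range(len(table))))

ds = ray.data.from_items([{'text': 'a'}, {'text': 'b'}])
ds.map_batches(tag_uid, batch_format='pyarrow',
               zero_copy_batch=True).write_parquet('/tmp/minhash_tmp')
ds = ray.data.read_parquet('/tmp/minhash_tmp')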
