
Commit

update docs
chenyushuo committed Dec 17, 2024
1 parent 62caefe commit 8dda1aa
Showing 2 changed files with 29 additions and 13 deletions.
14 changes: 7 additions & 7 deletions data_juicer/core/ray_data.py
@@ -61,15 +61,15 @@ def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset:
     columns = dataset.columns()
     if dataset_path:
         dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
-    # if Fields.stats not in columns:
+    if Fields.stats not in columns:
 
-    #     def process_batch_arrow(table: pa.Table) -> pa.Table:
-    #         new_column_data = [{} for _ in range(len(table))]
-    #         new_table = table.append_column(Fields.stats, [new_column_data])
-    #         return new_table
+        def process_batch_arrow(table: pa.Table) -> pa.Table:
+            new_column_data = [{} for _ in range(len(table))]
+            new_table = table.append_column(Fields.stats, [new_column_data])
+            return new_table
 
-    #     dataset = dataset.map_batches(process_batch_arrow,
-    #                                   batch_format='pyarrow')
+        dataset = dataset.map_batches(process_batch_arrow,
+                                      batch_format='pyarrow')
     return dataset


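For reference, the pattern re-enabled above can be exercised on its own. A minimal sketch, assuming only pyarrow; the literal column name 'stats' stands in for data_juicer's Fields.stats constant:

import pyarrow as pa

# Build a toy table and append an empty per-row struct column, the same
# append_column pattern process_batch_arrow applies to Fields.stats.
table = pa.Table.from_pydict({'text': ['a', 'b', 'c']})
new_column_data = [{} for _ in range(len(table))]  # one empty dict per row
table = table.append_column('stats', [new_column_data])
print(table.schema)  # 'stats' arrives as a struct-typed column

Downstream operators can then fill in each row's stats struct rather than failing on a missing column.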
28 changes: 22 additions & 6 deletions data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
@@ -7,13 +7,14 @@
 import ray
 import numpy as np
 import pyarrow as pa
+import pyarrow.parquet as pq
 import regex
 from loguru import logger
 from pydantic import Field, PositiveInt
 from typing_extensions import Annotated
+from typing import List, Union
 
-from data_juicer.utils.constant import HashKeys
+from data_juicer.utils.constant import HashKeys, Fields
 from data_juicer.utils.model_utils import prepare_sentencepiece_model
 
 from ..base_op import OPERATORS, Deduplicator
@@ -54,6 +55,12 @@ def get_edges(self, key):
 
 @ray.remote(scheduling_strategy="SPREAD")
 class BTSUnionFind:
+    """
+    A distributed implementation of Union-Find with load balancing.
+
+    The original paper on BTS Union-Find is available at:
+    https://ieeexplore.ieee.org/document/10598116
+    """
     def __init__(
         self,
         union_threshold,
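For context on the docstring added above: union-find (disjoint-set) maintains groups of connected items, which is how near-duplicate documents get clustered. A reference-only sketch of the classic single-process structure, not data_juicer's distributed implementation:

# Classic union-find with path halving; BTSUnionFind shards this idea
# across Ray actors and balances load between them.
class UnionFind:
    def __init__(self):
        self.parent = {}

    def find(self, x):
        # Walk to the root, shortening the path as we go.
        while self.parent.setdefault(x, x) != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, x, y):
        # Merge the two groups by linking their roots.
        self.parent[self.find(x)] = self.find(y)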
@@ -438,6 +445,7 @@ def tokenization_func(text):
         self.tmp_file_name = os.path.join(
             os.getcwd(), tmp_file_name, str(uuid.uuid4())
         )
+        os.makedirs(self.tmp_file_name)
 
         empty_hash_value = np.full(
             (self.num_rows_per_band,),
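The explicit os.makedirs is needed because the reworked MinHash stage (next hunk) writes its batches with pyarrow.parquet.write_table, which expects the target directory to already exist; the Ray write_parquet call it replaces created the directory itself.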
@@ -577,16 +585,24 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
                 HashKeys.uid,
                 pa.array(list(uid_list))
             )
-            return new_table
+            if not new_table[Fields.stats][0].as_py():
+                columns_to_keep = [
+                    name
+                    for name in new_table.column_names
+                    if name != Fields.stats
+                ]
+                new_table = new_table.select(columns_to_keep)
+            pq.write_table(
+                new_table,
+                os.path.join(self.tmp_file_name, f'{min_id}.parquet')
+            )
+            return pa.Table.from_arrays([], names=[])
 
         dataset.map_batches(
             minhash_with_uid,
             batch_format='pyarrow',
             zero_copy_batch=True,
-        ).write_parquet(
-            self.tmp_file_name,
-            force_ascii=False
-        )  # TODO: balance file size
+        ).materialize()
         dataset = ray.data.read_parquet(self.tmp_file_name)
         end_time = time.time()
         print(f'MinHash time = {end_time - start_time}')
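The reshaped flow above uses map_batches purely for its side effect: each batch is written to its own parquet file under tmp_file_name, an empty table is returned, and the directory is then re-read as a fresh dataset. A minimal sketch of the same pattern outside data_juicer (the uuid-based file naming here is illustrative; the operator derives names from each batch's minimum uid):

import os
import uuid

import pyarrow as pa
import pyarrow.parquet as pq
import ray

tmp_dir = os.path.join(os.getcwd(), 'tmp_batches', str(uuid.uuid4()))
os.makedirs(tmp_dir)

def write_batch(table: pa.Table) -> pa.Table:
    # Persist the batch ourselves instead of letting Ray write the output.
    pq.write_table(table, os.path.join(tmp_dir, f'{uuid.uuid4()}.parquet'))
    return pa.Table.from_arrays([], names=[])  # nothing flows downstream

ds = ray.data.range(100)
ds.map_batches(write_batch, batch_format='pyarrow',
               zero_copy_batch=True).materialize()
ds = ray.data.read_parquet(tmp_dir)

The materialize() call forces the lazy pipeline to execute so the files exist on disk before read_parquet runs; with one file per batch, file size is governed by batch size, which is presumably why the old '# TODO: balance file size' note was dropped.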
