Skip to content

Commit

Permalink
+ add two arguments to control whether to keep stats or hashes in the…
Browse files · Browse the repository at this point in the history
… result dataset
  • Loading branch information
HYLcool committed Nov 29, 2023
1 parent 886d9ea commit c5465de
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 13 deletions.
14 changes: 14 additions & 0 deletions data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ def init_configs(args=None):
'due to the IO blocking, especially for very large datasets. '
'When this happens, False is a better choice, although it takes '
'more time.')
# NOTE(review): `type=bool` is a well-known pitfall with stdlib argparse —
# bool() on any non-empty string (including 'False' and '0') yields True, so
# `--keep_stats_in_res_ds False` would enable the flag. This is safe only if
# this parser is jsonargparse (which parses booleans specially); confirm which
# parser class `init_configs` instantiates — TODO verify.
parser.add_argument(
'--keep_stats_in_res_ds',
type=bool,
# Default False: intermediate Fields.stats columns are stripped from the
# exported result dataset (see Exporter._export_impl).
default=False,
help='Whether to keep the computed stats in the result dataset. If '
'it\'s False, the intermediate fields to store the stats '
'computed by Filters will be removed. Default: False.')
# Same `type=bool` caveat as above applies here.
parser.add_argument(
'--keep_hashes_in_res_ds',
type=bool,
# Default False: hash/minhash/simhash/imagehash columns written by
# Deduplicators are stripped from the exported result dataset.
default=False,
help='Whether to keep the computed hashes in the result dataset. If '
'it\'s False, the intermediate fields to store the hashes '
'computed by Deduplicators will be removed. Default: False.')
parser.add_argument('--np',
type=PositiveInt,
default=4,
Expand Down
10 changes: 7 additions & 3 deletions data_juicer/core/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,13 @@ def __init__(self, cfg=None):

# prepare exporter and check export path suffix
logger.info('Preparing exporter...')
self.exporter = Exporter(self.cfg.export_path,
self.cfg.export_shard_size,
self.cfg.export_in_parallel, self.cfg.np)
self.exporter = Exporter(
self.cfg.export_path,
self.cfg.export_shard_size,
self.cfg.export_in_parallel,
self.cfg.np,
keep_stats_in_res_ds=self.cfg.keep_stats_in_res_ds,
keep_hashes_in_res_ds=self.cfg.keep_hashes_in_res_ds)

# setup tracer
self.open_tracer = self.cfg.open_tracer
Expand Down
42 changes: 32 additions & 10 deletions data_juicer/core/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from loguru import logger

from data_juicer.utils.constant import Fields
from data_juicer.utils.constant import Fields, HashKeys


class Exporter:
Expand All @@ -21,6 +21,8 @@ def __init__(self,
export_in_parallel=True,
num_proc=1,
export_ds=True,
keep_stats_in_res_ds=False,
keep_hashes_in_res_ds=False,
export_stats=True):
"""
Initialization method.
Expand All @@ -31,12 +33,18 @@ def __init__(self,
to a single file.
:param num_proc: number of process to export the dataset.
:param export_ds: whether to export the dataset contents.
:param keep_stats_in_res_ds: whether to keep stats in the result
dataset.
:param keep_hashes_in_res_ds: whether to keep hashes in the result
dataset.
:param export_stats: whether to export the stats of dataset.
"""
self.export_path = export_path
self.export_shard_size = export_shard_size
self.export_in_parallel = export_in_parallel
self.export_ds = export_ds
self.keep_stats_in_res_ds = keep_stats_in_res_ds
self.keep_hashes_in_res_ds = keep_hashes_in_res_ds
self.export_stats = export_stats
self.suffix = self._get_suffix(export_path)
self.num_proc = num_proc
Expand Down Expand Up @@ -98,8 +106,31 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True):
:param export_stats: whether to export stats of dataset.
:return:
"""
if Fields.stats in dataset.features and export_stats:
# export stats of datasets into a single file.
logger.info('Exporting computed stats into a single file...')
ds_stats = dataset.select_columns(Fields.stats)
stats_file = export_path.replace('.' + suffix, '_stats.jsonl')
Exporter.to_jsonl(
ds_stats,
stats_file,
num_proc=self.num_proc if self.export_in_parallel else 1)

if self.export_ds:
# fetch the corresponding export method according to the suffix
if not self.keep_stats_in_res_ds:
extra_fields = {Fields.stats}
feature_fields = set(dataset.features.keys())
removed_fields = extra_fields.intersection(feature_fields)
dataset = dataset.remove_columns(removed_fields)
if not self.keep_hashes_in_res_ds:
extra_fields = {
HashKeys.hash, HashKeys.minhash, HashKeys.simhash,
HashKeys.imagehash
}
feature_fields = set(dataset.features.keys())
removed_fields = extra_fields.intersection(feature_fields)
dataset = dataset.remove_columns(removed_fields)
export_method = Exporter._router()[suffix]
if self.export_shard_size <= 0:
# export the whole dataset into one single file.
Expand Down Expand Up @@ -154,15 +185,6 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True):
pool.close()
pool.join()

if Fields.stats in dataset.features and export_stats:
# export stats of datasets into a single file.
ds_stats = dataset.select_columns(Fields.stats)
stats_file = export_path.replace('.' + suffix, '_stats.jsonl')
Exporter.to_jsonl(
ds_stats,
stats_file,
num_proc=self.num_proc if self.export_in_parallel else 1)

def export(self, dataset):
"""
Export method for a dataset.
Expand Down

0 comments on commit c5465de

Please sign in to comment.