diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index c6eef4329..a3611c56f 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -89,6 +89,20 @@ def init_configs(args=None): 'due to the IO blocking, especially for very large datasets. ' 'When this happens, False is a better choice, although it takes ' 'more time.') + parser.add_argument( + '--keep_stats_in_res_ds', + type=bool, + default=False, + help='Whether to keep the computed stats in the result dataset. If ' + 'it\'s False, the intermediate fields to store the stats ' + 'computed by Filters will be removed. Default: False.') + parser.add_argument( + '--keep_hashes_in_res_ds', + type=bool, + default=False, + help='Whether to keep the computed hashes in the result dataset. If ' + 'it\'s False, the intermediate fields to store the hashes ' + 'computed by Deduplicators will be removed. Default: False.') parser.add_argument('--np', type=PositiveInt, default=4, diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py index b724f7106..93d4cb30b 100644 --- a/data_juicer/core/executor.py +++ b/data_juicer/core/executor.py @@ -66,9 +66,13 @@ def __init__(self, cfg=None): # prepare exporter and check export path suffix logger.info('Preparing exporter...') - self.exporter = Exporter(self.cfg.export_path, - self.cfg.export_shard_size, - self.cfg.export_in_parallel, self.cfg.np) + self.exporter = Exporter( + self.cfg.export_path, + self.cfg.export_shard_size, + self.cfg.export_in_parallel, + self.cfg.np, + keep_stats_in_res_ds=self.cfg.keep_stats_in_res_ds, + keep_hashes_in_res_ds=self.cfg.keep_hashes_in_res_ds) # setup tracer self.open_tracer = self.cfg.open_tracer diff --git a/data_juicer/core/exporter.py b/data_juicer/core/exporter.py index 7377e5225..a8c7c35f9 100644 --- a/data_juicer/core/exporter.py +++ b/data_juicer/core/exporter.py @@ -3,7 +3,7 @@ from loguru import logger -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, HashKeys class Exporter: @@ -21,6 +21,8 @@ def __init__(self, export_in_parallel=True, num_proc=1, export_ds=True, + keep_stats_in_res_ds=False, + keep_hashes_in_res_ds=False, export_stats=True): """ Initialization method. @@ -31,12 +33,18 @@ def __init__(self, to a single file. :param num_proc: number of process to export the dataset. :param export_ds: whether to export the dataset contents. + :param keep_stats_in_res_ds: whether to keep stats in the result + dataset. + :param keep_hashes_in_res_ds: whether to keep hashes in the result + dataset. :param export_stats: whether to export the stats of dataset. """ self.export_path = export_path self.export_shard_size = export_shard_size self.export_in_parallel = export_in_parallel self.export_ds = export_ds + self.keep_stats_in_res_ds = keep_stats_in_res_ds + self.keep_hashes_in_res_ds = keep_hashes_in_res_ds self.export_stats = export_stats self.suffix = self._get_suffix(export_path) self.num_proc = num_proc @@ -98,8 +106,31 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True): :param export_stats: whether to export stats of dataset. :return: """ + if Fields.stats in dataset.features and export_stats: + # export stats of datasets into a single file. + logger.info('Exporting computed stats into a single file...') + ds_stats = dataset.select_columns(Fields.stats) + stats_file = export_path.replace('.' + suffix, '_stats.jsonl') + Exporter.to_jsonl( + ds_stats, + stats_file, + num_proc=self.num_proc if self.export_in_parallel else 1) + if self.export_ds: # fetch the corresponding export method according to the suffix + if not self.keep_stats_in_res_ds: + extra_fields = {Fields.stats} + feature_fields = set(dataset.features.keys()) + removed_fields = extra_fields.intersection(feature_fields) + dataset = dataset.remove_columns(removed_fields) + if not self.keep_hashes_in_res_ds: + extra_fields = { + HashKeys.hash, HashKeys.minhash, HashKeys.simhash, + HashKeys.imagehash + } + feature_fields = set(dataset.features.keys()) + removed_fields = extra_fields.intersection(feature_fields) + dataset = dataset.remove_columns(removed_fields) export_method = Exporter._router()[suffix] if self.export_shard_size <= 0: # export the whole dataset into one single file. @@ -154,15 +185,6 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True): pool.close() pool.join() - if Fields.stats in dataset.features and export_stats: - # export stats of datasets into a single file. - ds_stats = dataset.select_columns(Fields.stats) - stats_file = export_path.replace('.' + suffix, '_stats.jsonl') - Exporter.to_jsonl( - ds_stats, - stats_file, - num_proc=self.num_proc if self.export_in_parallel else 1) - def export(self, dataset): """ Export method for a dataset.