diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py index 5ffa6eb90..9ec12a35e 100644 --- a/data_juicer/core/executor.py +++ b/data_juicer/core/executor.py @@ -209,6 +209,9 @@ def run(self, load_data_np=None): desc=op_name + '_compute_stats') if self.cfg.use_checkpoint: prev = dataset + if op.stats_export_path is not None: + self.exporter.export_compute_stats( + dataset, op.stats_export_path) tmp = dataset.filter(op.process, num_proc=self.cfg.np, desc=op_name + '_process') diff --git a/data_juicer/core/exporter.py b/data_juicer/core/exporter.py index a8c7c35f9..fe74aacc5 100644 --- a/data_juicer/core/exporter.py +++ b/data_juicer/core/exporter.py @@ -195,6 +195,18 @@ def export(self, dataset): self._export_impl(dataset, self.export_path, self.suffix, self.export_stats) + def export_compute_stats(self, dataset, export_path): + """ + Export method for saving compute status in filters + """ + keep_stats_in_res_ds = self.keep_stats_in_res_ds + self.keep_stats_in_res_ds = True + self._export_impl(dataset, + export_path, + self.suffix, + export_stats=False) + self.keep_stats_in_res_ds = keep_stats_in_res_ds + @staticmethod def to_jsonl(dataset, export_path, num_proc=1, **kwargs): """ diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index a9ef0b908..d42d72f95 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -188,6 +188,9 @@ def process_batch_arrow(table: pa.Table) -> pa.Table: else: dataset = dataset.map(op.compute_stats, num_gpus=num_gpus) + if op.stats_export_path is not None: + dataset.write_json(op.stats_export_path, + force_ascii=False) dataset = dataset.filter(op.process) else: logger.error( diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index e5464619a..b5b2e79d9 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -153,6 +153,7 @@ def __init__(self, *args, **kwargs): from data_juicer.core.data import wrap_func_with_nested_access self.compute_stats = wrap_func_with_nested_access(self.compute_stats) + self.stats_export_path = kwargs.get('stats_export_path', None) def compute_stats(self, sample, context=False): """ diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index f9411ee26..741d7d9a1 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -64,6 +64,7 @@ def test_yaml_cfg_file(self): 'video_key': 'videos', 'accelerator': 'cpu', 'spec_numprocs': 0, + 'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, 'use_actor': False, @@ -130,6 +131,7 @@ def test_mixture_cfg(self): 'video_key': 'videos', 'accelerator': 'cpu', 'spec_numprocs': 0, + 'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, 'use_actor': False, @@ -146,6 +148,7 @@ def test_mixture_cfg(self): 'video_key': 'videos', 'accelerator': 'cpu', 'spec_numprocs': 0, + 'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, 'use_actor': False, @@ -162,6 +165,7 @@ def test_mixture_cfg(self): 'video_key': 'videos', 'accelerator': 'cpu', 'spec_numprocs': 0, + 'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, 'use_actor': False, @@ -178,6 +182,7 @@ def test_mixture_cfg(self): 'video_key': 'videos', 'accelerator': 'cpu', 'spec_numprocs': 0, + 'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, 'use_actor': False, @@ -194,6 +199,7 @@ def test_mixture_cfg(self): 'video_key': 'videos', 'accelerator': 'cpu', 'spec_numprocs': 0, + 'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, 'use_actor': False,