From 74e72364dd8c0d4d08c8a859cfab6b71c62b7387 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=81=93=E8=BE=95?=
Date: Mon, 18 Dec 2023 14:50:49 +0800
Subject: [PATCH 1/5] - implemented an auto-hpo tool for data recipes based on
 the 3-sigma rule

- added a helper class to build the mapping from OPs' op_name to their used
 stats_key
- added a config exporter
---
 data_juicer/config/config.py                 | 45 +++++++++++-
 data_juicer/utils/constant.py                | 71 ++++++++++++++++++-
 tools/hpo/README.md                          | 29 ++++++--
 tools/hpo/README_ZH.md                       | 33 +++++++--
 tools/hpo/execute_hpo_3sigma.py              | 69 ++++++++++++++++++
 .../{execute_hpo.py => execute_hpo_wandb.py} |  0
 6 files changed, 233 insertions(+), 14 deletions(-)
 create mode 100644 tools/hpo/execute_hpo_3sigma.py
 rename tools/hpo/{execute_hpo.py => execute_hpo_wandb.py} (100%)

diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py
index a3611c56f..4835e2d4d 100644
--- a/data_juicer/config/config.py
+++ b/data_juicer/config/config.py
@@ -13,6 +13,9 @@
 from data_juicer.utils.logger_utils import setup_logger
 from data_juicer.utils.mm_utils import SpecialTokens
 
+global_cfg = None
+global_parser = None
+
 
 def init_configs(args=None):
     """
@@ -37,6 +40,11 @@ def init_configs(args=None):
         type=str,
         help='Path to a configuration file when using auto-HPO tool.',
         required=False)
+    parser.add_argument(
+        '--path_3sigma_recipe',
+        type=str,
+        help='Path to save a configuration file when using the 3-sigma tool.',
+        required=False)
 
     # basic global paras with extended type hints
     # e.g., files can be mode include flags
@@ -294,6 +302,10 @@ def init_configs(args=None):
         # show the final config tables before the process started
         display_config(cfg)
 
+        global global_cfg, global_parser
+        global_cfg = cfg
+        global_parser = parser
+
         return cfg
     except ArgumentError:
         logger.error('Config initialization failed')
@@ -371,7 +383,7 @@ def init_setup_from_cfg(cfg):
                        f'variable HF_DATASETS_CACHE.')
         config.HF_DATASETS_CACHE = cfg.ds_cache_dir
     else:
-        cfg.ds_cache_dir = config.HF_DATASETS_CACHE
+        cfg.ds_cache_dir = str(config.HF_DATASETS_CACHE)
 
     # if there is suffix_filter op, turn on the add_suffix flag
     cfg.add_suffix = False
@@ -478,6 +490,37 @@ def display_config(cfg):
     print(table)
 
 
+def export_config(cfg, path, format='yaml', skip_none=True, skip_check=True,
+                  overwrite=False, multifile=True, branch=None):
+    """
+    Save the config object; some params are from jsonargparse.
+    :param cfg: cfg object to save (Namespace type)
+    :param path: the save path
+    :param format: 'yaml', 'json', 'json_indented', 'parser_mode'
+    :param skip_none: Whether to exclude entries whose value is None.
+    :param skip_check: Whether to skip parser checking.
+    :param overwrite: Whether to overwrite existing files.
+    :param multifile: Whether to save multiple config files
+        by using the __path__ metas.
+    :param branch:
+    :return:
+    """
+    # remove ops outside the process list for better displaying
+    cfg_to_export = cfg.clone()
+    for op in OPERATORS.modules.keys():
+        _ = cfg_to_export.pop(op)
+
+    global global_parser
+    if not global_parser:
+        init_configs()  # enable the customized type parser
+    global_parser.save(
+        cfg=cfg_to_export, path=path, format=format, skip_none=skip_none,
+        skip_check=skip_check, overwrite=overwrite, multifile=multifile,
+        branch=branch)
+
+    logger.info(f'Saved the configuration in {path}')
+
+
 def merge_config(ori_cfg, new_cfg: Dict):
     """
     Merge configuration from new_cfg into ori_cfg
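For orientation, the new exporter is meant to round-trip with `init_configs`. A minimal sketch, assuming `export_config` is re-exported from `data_juicer.config` (as the import in the new tool later in this patch suggests) and with invented paths:

```python
# Minimal round trip for the new exporter (paths invented).
from data_juicer.config import export_config, init_configs

cfg = init_configs(args=['--config', 'configs/demo-process.yaml'])
export_config(cfg, 'outputs/exported-recipe.yaml')
# the saved YAML can be reused directly: --config outputs/exported-recipe.yaml
```
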
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index 1391968de..c4f43fae7 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -1,3 +1,9 @@
+import copy
+import inspect
+import os
+
+from loguru import logger
+
 DEFAULT_PREFIX = '__dj__'
 
 
@@ -8,7 +14,66 @@ class Fields(object):
     suffix = DEFAULT_PREFIX + 'suffix__'
 
 
-class StatsKeys(object):
+class StatsKeysMeta(type):
+    """
+    A helper class to track the mapping from each OP's name to its used
+    stats_keys.
+
+    e.g., once the AlphanumericFilter's compute_stats method has been called:
+    res = StatsKeys.get_access_log()
+    print(res)  # {"AlphanumericFilter": ["alnum_ratio", "alpha_token_ratio"]}
+    """
+    _accessed_by = {}
+
+    def __getattr__(cls, attr):
+        caller_class = inspect.currentframe().f_back.f_globals['__name__']
+        # no need to track the parent classes
+        caller_class = caller_class.split('.')[-1]
+        stat_key = getattr(cls._constants_class, attr)
+        if caller_class not in cls._accessed_by:
+            cls._accessed_by[caller_class] = set()
+        if stat_key not in cls._accessed_by[caller_class]:
+            cls._accessed_by[caller_class].add(stat_key)
+        return stat_key
+
+    def get_access_log(cls, dj_cfg=None):
+        if cls._accessed_by:
+            return cls._accessed_by
+        elif dj_cfg:
+            tmp_dj_cfg = copy.deepcopy(dj_cfg)
+            # the access has been skipped due to the use of cache;
+            # we will use a temp data sample to get the access log
+            if os.path.exists(dj_cfg.dataset_path) and \
+                    'jsonl' in dj_cfg.dataset_path:
+                logger.info(
+                    'Begin to track the usage of ops with a dummy data sample')
+
+                # load the first line as tmp_data
+                tmp_f_name = dj_cfg.dataset_path.\
+                    replace('.jsonl', '.tmp.jsonl')
+                with open(dj_cfg.dataset_path, 'r') as orig_file, \
+                        open(tmp_f_name, 'w') as tmp_file:
+                    first_line = orig_file.readline()
+                    tmp_file.write(first_line)
+
+                tmp_dj_cfg.dataset_path = tmp_f_name
+                tmp_dj_cfg.use_cache = False
+                tmp_dj_cfg.use_checkpoint = False
+
+                from data_juicer.core import Analyser
+                tmp_analyzer = Analyser(tmp_dj_cfg)
+                tmp_analyzer.run()
+
+                os.remove(tmp_f_name)
+            else:
+                raise NotImplementedError(
+                    f'For now, the dummy data is only supported for the jsonl '
+                    f'type. Please check your config: {dj_cfg.dataset_path} '
+                    f'either does not exist or is not a jsonl file.')
+
+            return cls._accessed_by
+
+
+class StatsKeysConstant(object):
     # text
     alpha_token_ratio = 'alpha_token_ratio'
     alnum_ratio = 'alnum_ratio'
@@ -41,6 +106,10 @@ class StatsKeys(object):
     image_text_matching_score = 'image_text_matching_score'
 
 
+class StatsKeys(object, metaclass=StatsKeysMeta):
+    _constants_class = StatsKeysConstant
+
+
 class HashKeys(object):
     hash = DEFAULT_PREFIX + 'hash'
     minhash = DEFAULT_PREFIX + 'minhash'
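The metaclass bookkeeping above is easier to see in isolation. A self-contained toy version of the same pattern (hypothetical names, not the actual Data-Juicer classes):

```python
import inspect


class _Constants:
    alnum_ratio = 'alnum_ratio'


class TrackingMeta(type):
    _accessed_by = {}

    def __getattr__(cls, attr):
        # record which module pulled which constant off the class
        caller = inspect.currentframe().f_back.f_globals['__name__']
        value = getattr(cls._constants_class, attr)
        cls._accessed_by.setdefault(caller.split('.')[-1], set()).add(value)
        return value


class Keys(metaclass=TrackingMeta):
    _constants_class = _Constants


_ = Keys.alnum_ratio
print(TrackingMeta._accessed_by)  # {'__main__': {'alnum_ratio'}}
```

Because `__getattr__` is only invoked on failed lookups, `_constants_class` (found in the class dict) resolves normally and the tracking never recurses.
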
diff --git a/tools/hpo/README.md b/tools/hpo/README.md
index 01a3735f5..d59b81f8e 100644
--- a/tools/hpo/README.md
+++ b/tools/hpo/README.md
@@ -1,6 +1,23 @@
 # Hyper-parameter Optimization for Data Recipe
 
-## Auto-HPO
+## Auto-HPO based on 3-Sigma principles
+A simple automatic hyper-parameter optimization method for data recipes is to assume that outlier data is harmful to training.
+We can thus introduce the 3-sigma principle to automatically determine the hyper-parameters and filter the data.
+
+Specifically, assuming that a certain analysis dimension of the original data obeys a normal distribution and has random errors, we can set the upper and lower bounds of the filtering OP in this dimension to three times the standard deviation away from the mean, based on the statistics produced by Data-Juicer's Analyser.
+
+$$P(|x-\mu| > 3\sigma) \leq 0.003$$
+
+To automate this process, we provide a tool that can be used as follows:
+```shell
+# cd tools/hpo
+python execute_hpo_3sigma.py --config <recipe-cfg-file>
+
+#e.g.,
+python execute_hpo_3sigma.py --config configs/process.yaml
+```
+
+## Auto-HPO with WandB
 We incorporate an automated HPO tool, WandB [Sweep](https://docs.wandb.ai/guides/sweeps),
 into Data-Juicer to streamline the finding of good data processing
 hyper-parameters. With this tool, users can investigate correlations and
 importance scores of
@@ -11,7 +28,7 @@
 a large room to explore. Feel free to provide more suggestions, discussion,
 and contribution via new PRs!
 
-## Prerequisite
+### Prerequisite
 You need to install data-juicer first.
 Besides, the tool leverages WandB, install it via `pip install wandb`.
 Before using this tool, you need to run
 wandb login --host <your-wandb-instance-url>
 
-## Usage and Customization
+### Usage and Customization
 
 Given a data recipe, characterized by a specified configuration file
-`<recipe-cfg-file>`, you can use `execute_hpo.py` to search the
+`<recipe-cfg-file>`, you can use `execute_hpo_wandb.py` to search the
 hyper-parameter space defined by `<hpo-cfg-file>`.
 
 ```shell
 # cd tools/hpo
-python execute_hpo.py --config <recipe-cfg-file> --hpo_config <hpo-cfg-file>
+python execute_hpo_wandb.py --config <recipe-cfg-file> --hpo_config <hpo-cfg-file>
 
 # e.g.,
-python execute_hpo.py --config configs/process.yaml --hpo_config configs/quality_score_hpo.yaml
+python execute_hpo_wandb.py --config configs/process.yaml --hpo_config configs/quality_score_hpo.yaml
 ```
 
 For the configuration for data recipe, i.e., `<recipe-cfg-file>`,
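For readers unfamiliar with Sweep, a rough sketch of the mechanics the WandB-based tool builds on; the metric name, parameter, count, and project here are invented, and the shipped `configs/quality_score_hpo.yaml` wiring may differ:

```python
import wandb

# Invented sweep: method/metric/parameters follow WandB Sweep grammar.
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'quality_score', 'goal': 'maximize'},
    'parameters': {'min_ratio': {'min': 0.1, 'max': 0.9}},
}


def objective():
    run = wandb.init()
    # ... run the data recipe with wandb.config.min_ratio and score it ...
    run.log({'quality_score': 0.5})  # placeholder score


sweep_id = wandb.sweep(sweep=sweep_config, project='dj-hpo')
wandb.agent(sweep_id, function=objective, count=10)
```
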
diff --git a/tools/hpo/README_ZH.md b/tools/hpo/README_ZH.md
index b2437302b..de7a46d59 100644
--- a/tools/hpo/README_ZH.md
+++ b/tools/hpo/README_ZH.md
@@ -1,6 +1,27 @@
 # 数据菜谱的自动化超参优化
 
-## Auto-HPO
+## 基于3-Sigma原则进行Auto-HPO
+一种简单的数据菜谱自动调参方法是假设outlier数据对训练有害,那么我们可以引入3-sigma原则来自动确定超参,过滤数据。具体来说,假设原始数据的某个分析维度服从正态分布且存在随机误差,我们可以基于Analyser产出的stats,在该维度上
+把算子过滤的上下界设为距离均值三倍标准差处。
+
+$$P(|x-\mu| > 3\sigma) \leq 0.003$$
+
+为了自动化该过程,我们提供了相应工具:
+```shell
+# cd tools/hpo
+# usage 1: do not save the refined recipe
+python execute_hpo_3sigma.py --config <recipe-cfg-file>
+# usage 2: save the refined recipe at the given path
+python execute_hpo_3sigma.py --config <recipe-cfg-file> --path_3sigma_recipe <refined-recipe-save-path>
+
+# e.g., usage 1
+python execute_hpo_3sigma.py --config configs/process.yaml
+# e.g., usage 2
+python execute_hpo_3sigma.py --config configs/process.yaml --path_3sigma_recipe configs/process_3sigma.yaml
+```
+
+
+## 基于WandB进行Auto-HPO
 我们将自动化 HPO (hyper-parameters optimization) 工具 WandB
 [Sweep](https://docs.wandb.ai/guides/sweeps) 结合到
 Data-Juicer 中,以简化改良数据处理超参数的过程。
@@ -11,7 +32,7 @@ Data-Juicer 中,以简化改良数据处理超参数的过程。
 并通过新的 PR 做出贡献!
 
-## 前置条件
+### 前置条件
 您需要先安装 data-juicer。
 此外,该工具利用了 WandB,通过`pip install wandb`安装它。
 在使用此工具之前,您需要运行`wandb login`并输入您的 WandB
 API 密钥:
 wandb login --host <your-wandb-instance-url>
 
-## 使用和定制化
+### 使用和定制化
 
-给定一个数据配方,以指定的配置文件所定义`<recipe-cfg-file>`,您可以使用 `execute_hpo.py` 来搜索
+给定一个数据配方,以指定的配置文件所定义`<recipe-cfg-file>`,您可以使用 `execute_hpo_wandb.py` 来搜索
 由`<hpo-cfg-file>`定义的超参数空间。
 
 ```shell
 # cd tools/hpo
-python execute_hpo.py --config <recipe-cfg-file> --hpo_config <hpo-cfg-file>
+python execute_hpo_wandb.py --config <recipe-cfg-file> --hpo_config <hpo-cfg-file>
 
 # e.g.,
-python execute_hpo.py --config configs/process.yaml --hpo_config configs/quality_score_hpo.yaml
+python execute_hpo_wandb.py --config configs/process.yaml --hpo_config configs/quality_score_hpo.yaml
 ```
 
 对于数据菜谱的配置,即`<recipe-cfg-file>`,
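The new tool below consumes the Analyser's describe-style `overall_result` table. A toy illustration of pulling out the mean/std rows the same way the script does (numbers invented):

```python
import pandas as pd

# Toy stand-in for analyser.overall_result (values invented).
df = pd.DataFrame({'alnum_ratio': [0.82, 0.05],
                   'avg_line_length': [57.0, 12.3]},
                  index=['mean', 'std'])

stats_key_to_mean = df[df.index == 'mean'].iloc[0, :].to_dict()
stats_key_to_std = df[df.index == 'std'].iloc[0, :].to_dict()
print(stats_key_to_mean)  # {'alnum_ratio': 0.82, 'avg_line_length': 57.0}
print(stats_key_to_std)   # {'alnum_ratio': 0.05, 'avg_line_length': 12.3}
```
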
diff --git a/tools/hpo/execute_hpo_3sigma.py b/tools/hpo/execute_hpo_3sigma.py
new file mode 100644
index 000000000..5427027db
--- /dev/null
+++ b/tools/hpo/execute_hpo_3sigma.py
@@ -0,0 +1,69 @@
+import copy
+import sys
+
+from loguru import logger
+
+from data_juicer.config import export_config, init_configs
+from data_juicer.core import Analyser, Executor
+from data_juicer.utils.constant import StatsKeys
+
+
+@logger.catch
+def main():
+
+    path_3sigma_recipe = None
+    for i in range(len(sys.argv) - 1):
+        if sys.argv[i] == '--path_3sigma_recipe':
+            path_3sigma_recipe = sys.argv[i + 1]
+
+    # 1. analyze using the given initial recipe
+    cfg = init_configs()
+    logger.info('Begin to analyze data using the given initial recipe')
+
+    analyser = Analyser(cfg)
+    analyser.run()
+    df = analyser.overall_result
+    # get the mapping from op_name to their mu and sigma
+    mean_series = df[df.index == 'mean']
+    stats_key_to_mean = mean_series.iloc[0, :].to_dict()
+    std_series = df[df.index == 'std']
+    stats_key_to_std = std_series.iloc[0, :].to_dict()
+
+    # 2. adjust the hyper-parameters of the given recipe with the 3-sigma rule
+    logger.info('Begin to modify the recipe with the 3-sigma rule')
+    op_name_to_stats_key = StatsKeys.get_access_log(dj_cfg=cfg)
+    for process in cfg.process:
+        op_name, args = list(process.items())[0]
+        temp_args = copy.deepcopy(args)
+        stats_keys = op_name_to_stats_key[op_name]
+        for stats_key in stats_keys:
+            if stats_key in stats_key_to_mean:
+                for arg_name in temp_args.keys():
+                    new_val = None
+                    if 'min' in arg_name:
+                        new_val = stats_key_to_mean[stats_key] - \
+                                  3 * stats_key_to_std[stats_key]
+                    if 'max' in arg_name:
+                        new_val = stats_key_to_mean[stats_key] + \
+                                  3 * stats_key_to_std[stats_key]
+                    if new_val is not None:
+                        logger.info(f'Using 3-sigma rule, changed para '
+                                    f'{arg_name}={args[arg_name]} into '
+                                    f'{arg_name}={new_val}')
+                        args[arg_name] = new_val
+
+    if path_3sigma_recipe:
+        export_config(cfg, path_3sigma_recipe)
+
+    # 3. process the data using the refined recipe
+    logger.info('Begin to process the data with refined recipe')
+    if cfg.executor_type == 'default':
+        executor = Executor(cfg)
+    elif cfg.executor_type == 'ray':
+        from data_juicer.core.ray_executor import RayExecutor
+        executor = RayExecutor(cfg)
+    executor.run()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/hpo/execute_hpo.py b/tools/hpo/execute_hpo_wandb.py
similarity index 100%
rename from tools/hpo/execute_hpo.py
rename to tools/hpo/execute_hpo_wandb.py
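Patch 2 below extends the dummy-sample trick in `get_access_log` to zstd-compressed jsonl. As a standalone sketch, reading the first line of either format can look like this (assumes the `zstandard` package; patch 3 converges on the same `TextIOWrapper` idea):

```python
import io

import zstandard as zstd


def read_first_line(path):
    """Read one line from a plain or zstd-compressed jsonl file."""
    if path.endswith('.jsonl'):
        with open(path, 'r') as f:
            return f.readline()
    if path.endswith('.jsonl.zst'):
        with open(path, 'rb') as f:
            # stream_reader yields a binary stream; TextIOWrapper makes it
            # line-readable without decompressing the whole file
            reader = zstd.ZstdDecompressor().stream_reader(f)
            return io.TextIOWrapper(reader, encoding='utf-8').readline()
    raise ValueError(f'unsupported dataset type: {path}')
```
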
From 325a8559bfaa0fe269f5db58e9ddef00df2fc376 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=81=93=E8=BE=95?=
Date: Tue, 19 Dec 2023 14:31:02 +0800
Subject: [PATCH 2/5] - added support for jsonl.zst type when calling
 get_access_log()

---
 data_juicer/utils/constant.py | 35 +++++++++++++++++++++++++++++------
 1 file changed, 29 insertions(+), 6 deletions(-)

diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index c4f43fae7..a58bc3b34 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -2,6 +2,7 @@
 import inspect
 import os
 
+import zstandard as zstd
 from loguru import logger
 
 DEFAULT_PREFIX = '__dj__'
@@ -43,16 +44,38 @@ def get_access_log(cls, dj_cfg=None):
             # the access has been skipped due to the use of cache;
             # we will use a temp data sample to get the access log
             if os.path.exists(dj_cfg.dataset_path) and \
-                    'jsonl' in dj_cfg.dataset_path:
+                    ('jsonl' in dj_cfg.dataset_path or
+                     'jsonl.zst' in dj_cfg.dataset_path):
                 logger.info(
                     'Begin to track the usage of ops with a dummy data sample')
 
                 # load the first line as tmp_data
-                tmp_f_name = dj_cfg.dataset_path.\
-                    replace('.jsonl', '.tmp.jsonl')
-                with open(dj_cfg.dataset_path, 'r') as orig_file, \
-                        open(tmp_f_name, 'w') as tmp_file:
-                    first_line = orig_file.readline()
+                tmp_f_name = None
+                first_line = None
+                if 'jsonl.zst' in dj_cfg.dataset_path:
+                    tmp_f_name = dj_cfg.dataset_path. \
+                        replace('.jsonl.zst', '.tmp.jsonl')
+                    # Open the file in binary mode and
+                    # create a Zstandard decompression context
+                    with open(dj_cfg.dataset_path, 'rb') as compressed_file:
+                        dctx = zstd.ZstdDecompressor()
+                        # Create a stream reader for the file and decode the
+                        # first line
+                        with dctx.stream_reader(compressed_file) as reader:
+                            first_line_bytes = reader.readline()
+                            # Assuming the file is encoded in UTF-8
+                            first_line = first_line_bytes.decode('utf-8')
+                elif 'jsonl' in dj_cfg.dataset_path:
+                    tmp_f_name = dj_cfg.dataset_path. \
+                        replace('.jsonl', '.tmp.jsonl')
+                    with open(dj_cfg.dataset_path, 'r') as orig_file:
+                        first_line = orig_file.readline()
+
+                assert tmp_f_name is not None and first_line is not None, \
+                    'error when loading the first line, with ' \
+                    f'dj_cfg.dataset_path={dj_cfg.dataset_path}'
+
+                with open(tmp_f_name, 'w') as tmp_file:
                     tmp_file.write(first_line)
 
                 tmp_dj_cfg.dataset_path = tmp_f_name

From 32f40dd354ccb18824f9153655e45216265733f3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=81=93=E8=BE=95?=
Date: Tue, 19 Dec 2023 15:26:20 +0800
Subject: [PATCH 3/5] - added support for jsonl.zst type when calling
 get_access_log()

---
 data_juicer/utils/constant.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index a58bc3b34..1eeccf998 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -1,5 +1,6 @@
 import copy
 import inspect
+import io
 import os
 
 import zstandard as zstd
@@ -62,9 +63,9 @@ def get_access_log(cls, dj_cfg=None):
                         # Create a stream reader for the file and decode the
                         # first line
                         with dctx.stream_reader(compressed_file) as reader:
-                            first_line_bytes = reader.readline()
-                            # Assuming the file is encoded in UTF-8
-                            first_line = first_line_bytes.decode('utf-8')
+                            text_stream = io.TextIOWrapper(
+                                reader, encoding='utf-8')
+                            first_line = text_stream.readline()
                 elif 'jsonl' in dj_cfg.dataset_path:
                     tmp_f_name = dj_cfg.dataset_path. \
                         replace('.jsonl', '.tmp.jsonl')

From c2b3b3879557f5af45068e1caa428a8c5c211c66 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=81=93=E8=BE=95?=
Date: Tue, 19 Dec 2023 15:59:10 +0800
Subject: [PATCH 4/5] - skipped the ops that track no stats when refining the
 recipe in execute_hpo_3sigma.py

---
 tools/hpo/execute_hpo_3sigma.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tools/hpo/execute_hpo_3sigma.py b/tools/hpo/execute_hpo_3sigma.py
index 5427027db..bb4df51d6 100644
--- a/tools/hpo/execute_hpo_3sigma.py
+++ b/tools/hpo/execute_hpo_3sigma.py
@@ -35,6 +35,9 @@ def main():
     for process in cfg.process:
         op_name, args = list(process.items())[0]
         temp_args = copy.deepcopy(args)
+        if op_name not in op_name_to_stats_key:
+            # skip ops that produce no stats, such as `clean_email_mapper`
+            continue
         stats_keys = op_name_to_stats_key[op_name]
         for stats_key in stats_keys:
             if stats_key in stats_key_to_mean:
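Patch 4 above guards against OPs that never touch `StatsKeys` and therefore do not appear in the access log. The guard in miniature (toy mapping and recipe, names invented):

```python
# Toy mapping as returned by StatsKeys.get_access_log() (names invented).
op_name_to_stats_key = {'alphanumeric_filter': {'alnum_ratio'}}
process = [
    {'alphanumeric_filter': {'min_ratio': 0.25}},
    {'clean_email_mapper': {}},  # mappers compute no stats
]

for item in process:
    op_name, args = list(item.items())[0]
    if op_name not in op_name_to_stats_key:
        continue  # the guard added by patch 4
    print(op_name, '->', op_name_to_stats_key[op_name])
# only `alphanumeric_filter` is printed; the mapper is skipped
```
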
From fde71acf06e6e4ef9bef1990142dd1d7a416238f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=81=93=E8=BE=95?=
Date: Tue, 19 Dec 2023 19:48:56 +0800
Subject: [PATCH 5/5] fix according to yilun's comments

---
 data_juicer/analysis/column_wise_analysis.py | 16 ++++++++++------
 data_juicer/analysis/overall_analysis.py     |  8 +++++---
 data_juicer/config/config.py                 |  7 +++----
 data_juicer/core/analyser.py                 | 13 +++++++++----
 data_juicer/utils/constant.py                |  3 ++-
 tools/hpo/README.md                          | 13 +++++++++----
 tools/hpo/execute_hpo_3sigma.py              |  5 +++--
 7 files changed, 41 insertions(+), 24 deletions(-)

diff --git a/data_juicer/analysis/column_wise_analysis.py b/data_juicer/analysis/column_wise_analysis.py
index 5fe732cbc..194bca503 100644
--- a/data_juicer/analysis/column_wise_analysis.py
+++ b/data_juicer/analysis/column_wise_analysis.py
@@ -58,7 +58,8 @@ def __init__(self,
                  dataset,
                  output_path,
                  overall_result=None,
-                 save_stats_in_one_file=True):
+                 save_stats_in_one_file=True,
+                 ):
         """
         Initialization method
         :param dataset: the dataset to be analysed
@@ -80,7 +81,7 @@ def __init__(self,
 
         self.save_stats_in_one_file = save_stats_in_one_file
 
-    def analyse(self, show_percentiles=False, show=False):
+    def analyse(self, show_percentiles=False, show=False, skip_export=False):
         """
         Apply analysis and draw the analysis figure for stats.
 
@@ -88,6 +89,7 @@ def analyse(self, show_percentiles=False, show=False):
             each sub-figure. If it's true, there will be several red lines to
             indicate the quantiles of the stats distributions
         :param show: whether to show in a single window after drawing
+        :param skip_export: whether to skip saving the results to disk
         :return:
         """
         # number of sub-figures for each stat. There are histogram and box plot
@@ -164,9 +166,10 @@ def analyse(self, show_percentiles=False, show=False):
             else:
                 axes = None
 
-            self.draw_hist(
-                axes, data,
-                os.path.join(self.output_path, f'{column_name}-hist.png'))
+            if not skip_export:
+                self.draw_hist(
+                    axes, data, os.path.join(
+                        self.output_path, f'{column_name}-hist.png'))
 
             # add a title to the figure of this stat
             if self.save_stats_in_one_file:
@@ -176,7 +179,8 @@ def analyse(self, show_percentiles=False, show=False):
 
         if self.save_stats_in_one_file:
             fig = plt.gcf()
-            fig.savefig(os.path.join(self.output_path, 'all-stats.png'))
+            if not skip_export:
+                fig.savefig(os.path.join(self.output_path, 'all-stats.png'))
             if show:
                 plt.show()
             else:
diff --git a/data_juicer/analysis/overall_analysis.py b/data_juicer/analysis/overall_analysis.py
index acf8539ff..b68b4551d 100644
--- a/data_juicer/analysis/overall_analysis.py
+++ b/data_juicer/analysis/overall_analysis.py
@@ -58,13 +58,14 @@ def refine_single_column(self, col):
             col = col.explode().infer_objects()
         return col
 
-    def analyse(self, percentiles=[], num_proc=1):
+    def analyse(self, percentiles=[], num_proc=1, skip_export=False):
         """
         Apply overall analysis on the whole dataset based on the describe
         method of pandas.
 
         :param percentiles: percentiles to analyse
         :param num_proc: number of processes to analyse the dataset
+        :param skip_export: whether to skip exporting the results to disk
         :return: the overall analysis result.
         """
         # merge default and customized percentiles and get overall information
@@ -87,7 +88,8 @@ def analyse(self, percentiles=[], num_proc=1):
         overall = pd.DataFrame(result_cols).T
 
         # export to result report file
-        overall.to_csv(os.path.join(self.output_path, 'overall.csv'))
-        overall.to_markdown(os.path.join(self.output_path, 'overall.md'))
+        if not skip_export:
+            overall.to_csv(os.path.join(self.output_path, 'overall.csv'))
+            overall.to_markdown(os.path.join(self.output_path, 'overall.md'))
 
         return overall
diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py
index 4835e2d4d..d23857a4a 100644
--- a/data_juicer/config/config.py
+++ b/data_juicer/config/config.py
@@ -491,7 +491,7 @@ def display_config(cfg):
 
 
 def export_config(cfg, path, format='yaml', skip_none=True, skip_check=True,
-                  overwrite=False, multifile=True, branch=None):
+                  overwrite=False, multifile=True):
     """
     Save the config object; some params are from jsonargparse.
     :param cfg: cfg object to save (Namespace type)
     :param path: the save path
     :param format: 'yaml', 'json', 'json_indented', 'parser_mode'
     :param skip_none: Whether to exclude entries whose value is None.
     :param skip_check: Whether to skip parser checking.
     :param overwrite: Whether to overwrite existing files.
     :param multifile: Whether to save multiple config files
         by using the __path__ metas.
-    :param branch:
+
     :return:
     """
     # remove ops outside the process list for better displaying
     _ = cfg_to_export.pop(op)
@@ -515,8 +515,7 @@ def export_config(cfg, path, format='yaml', skip_none=True, skip_check=True,
     global_parser.save(
         cfg=cfg_to_export, path=path, format=format, skip_none=skip_none,
-        skip_check=skip_check, overwrite=overwrite, multifile=multifile,
-        branch=branch)
+        skip_check=skip_check, overwrite=overwrite, multifile=multifile)
 
     logger.info(f'Saved the configuration in {path}')
diff --git a/data_juicer/core/analyser.py b/data_juicer/core/analyser.py
index 8636a2a53..903a3d8e3 100644
--- a/data_juicer/core/analyser.py
+++ b/data_juicer/core/analyser.py
@@ -63,11 +63,12 @@ def __init__(self, cfg=None):
         self.overall_single_plot_path = None
         self.analysis_path = os.path.join(self.cfg.work_dir, 'analysis')
 
-    def run(self, load_data_np=None):
+    def run(self, load_data_np=None, skip_export=False):
         """
         Running the dataset analysis pipeline.
 
         :param load_data_np: number of workers when loading the dataset.
+        :param skip_export: whether to skip exporting the results to disk
         :return: analysed dataset.
         """
         # 1. format data
@@ -120,14 +121,18 @@ def run(self, load_data_np=None):
         logger.info('Applying overall analysis on stats...')
         overall_analysis = OverallAnalysis(dataset, self.analysis_path)
-        self.overall_result = overall_analysis.analyse(num_proc=self.cfg.np)
+        self.overall_result = overall_analysis.analyse(
+            num_proc=self.cfg.np, skip_export=skip_export)
+
+        logger.info(f'The overall analysis results are: {self.overall_result}')
 
         logger.info('Applying column-wise analysis on stats...')
         column_wise_analysis = ColumnWiseAnalysis(
             dataset,
             self.analysis_path,
             overall_result=self.overall_result,
-            save_stats_in_one_file=self.cfg.save_stats_in_one_file)
-        column_wise_analysis.analyse()
+            save_stats_in_one_file=self.cfg.save_stats_in_one_file,
+        )
+        column_wise_analysis.analyse(skip_export=skip_export)
 
         return dataset
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index 1eeccf998..6fa190263 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -85,7 +85,8 @@ def get_access_log(cls, dj_cfg=None):
 
                 from data_juicer.core import Analyser
                 tmp_analyzer = Analyser(tmp_dj_cfg)
-                tmp_analyzer.run()
+                # do not overwrite the true analysis results
+                tmp_analyzer.run(skip_export=True)
 
                 os.remove(tmp_f_name)
             else:
diff --git a/tools/hpo/README.md b/tools/hpo/README.md
index d59b81f8e..b840636f1 100644
--- a/tools/hpo/README.md
+++ b/tools/hpo/README.md
@@ -11,10 +11,15 @@ $$P(|x-\mu| > 3\sigma) \leq 0.003$$
 To automate this process, we provide a tool that can be used as follows:
 ```shell
 # cd tools/hpo
-python execute_hpo_3sigma.py --config <recipe-cfg-file>
-
-#e.g.,
-python execute_hpo_3sigma.py --config configs/process.yaml
+# usage 1: do not save the refined recipe
+python execute_hpo_3sigma.py --config <recipe-cfg-file>
+# usage 2: save the refined recipe at the given path
+python execute_hpo_3sigma.py --config <recipe-cfg-file> --path_3sigma_recipe <refined-recipe-save-path>
+
+# e.g., usage 1
+python execute_hpo_3sigma.py --config configs/process.yaml
+# e.g., usage 2
+python execute_hpo_3sigma.py --config configs/process.yaml --path_3sigma_recipe configs/process_3sigma.yaml
 ```
 
 ## Auto-HPO with WandB
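The final hunk below adds a guard against `nan` bounds: a stats column with no numeric values yields `nan` mean/std, which must not overwrite recipe parameters. In miniature:

```python
import math

# A stats column with no numeric values yields nan mean/std.
mean, std = float('nan'), float('nan')
new_val = mean - 3 * std  # still nan

print(new_val is not None)      # True -> the old guard let nan through
print(str(new_val) != 'nan')    # False -> the new guard skips it
print(not math.isnan(new_val))  # False -> an equivalent, explicit check
```
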
diff --git a/tools/hpo/execute_hpo_3sigma.py b/tools/hpo/execute_hpo_3sigma.py
index bb4df51d6..72851b5cc 100644
--- a/tools/hpo/execute_hpo_3sigma.py
+++ b/tools/hpo/execute_hpo_3sigma.py
@@ -49,8 +49,9 @@ def main():
                     if 'max' in arg_name:
                         new_val = stats_key_to_mean[stats_key] + \
                                   3 * stats_key_to_std[stats_key]
-                    if new_val is not None:
-                        logger.info(f'Using 3-sigma rule, changed para '
+                    if new_val is not None and str(new_val) != 'nan':
+                        logger.info(f'Using 3-sigma rule, for op {op_name}, '
+                                    f'changed its para '
                                     f'{arg_name}={args[arg_name]} into '
                                     f'{arg_name}={new_val}')
                         args[arg_name] = new_val
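Putting the pieces together, the 3-sigma rewrite of a single OP's bounds amounts to the following miniature (invented numbers, mirroring the loop in `tools/hpo/execute_hpo_3sigma.py`):

```python
# One op's bounds, rewritten with the 3-sigma rule (numbers invented).
stats_key_to_mean = {'alnum_ratio': 0.82}
stats_key_to_std = {'alnum_ratio': 0.05}
args = {'min_ratio': 0.0, 'max_ratio': 1.0}

for arg_name in args:
    if 'min' in arg_name:
        args[arg_name] = stats_key_to_mean['alnum_ratio'] - \
            3 * stats_key_to_std['alnum_ratio']
    if 'max' in arg_name:
        args[arg_name] = stats_key_to_mean['alnum_ratio'] + \
            3 * stats_key_to_std['alnum_ratio']

print(args)  # {'min_ratio': 0.67, 'max_ratio': 0.97} (up to float rounding)
```

Samples whose `alnum_ratio` falls outside these bounds are exactly the outliers the 3-sigma principle flags as likely harmful to training.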