add an auto-hpo tool for data-recipes based on 3-sigma rule #140

Merged · 5 commits · Dec 20, 2023
16 changes: 10 additions & 6 deletions data_juicer/analysis/column_wise_analysis.py
@@ -58,7 +58,8 @@ def __init__(self,
dataset,
output_path,
overall_result=None,
save_stats_in_one_file=True):
save_stats_in_one_file=True,
):
"""
Initialization method
:param dataset: the dataset to be analysed
@@ -80,14 +81,15 @@ def __init__(self,

self.save_stats_in_one_file = save_stats_in_one_file

def analyse(self, show_percentiles=False, show=False):
def analyse(self, show_percentiles=False, show=False, skip_export=False):
"""
Apply analysis and draw the analysis figure for stats.

:param show_percentiles: whether to show the percentile line in
each sub-figure. If it's true, there will be several red
lines to indicate the quantiles of the stats distributions
:param show: whether to show in a single window after drawing
:param skip_export: whether to skip exporting the results to disk
:return:
"""
# number of sub-figures for each stat. There are histogram and box plot
@@ -164,9 +166,10 @@ def analyse(self, show_percentiles=False, show=False):
else:
axes = None

self.draw_hist(
axes, data,
os.path.join(self.output_path, f'{column_name}-hist.png'))
if not skip_export:
self.draw_hist(
axes, data, os.path.join(
self.output_path, f'{column_name}-hist.png'))

# add a title to the figure of this stat
if self.save_stats_in_one_file:
@@ -176,7 +179,8 @@

if self.save_stats_in_one_file:
fig = plt.gcf()
fig.savefig(os.path.join(self.output_path, 'all-stats.png'))
if not skip_export:
fig.savefig(os.path.join(self.output_path, 'all-stats.png'))
if show:
plt.show()
else:
8 changes: 5 additions & 3 deletions data_juicer/analysis/overall_analysis.py
@@ -58,13 +58,14 @@ def refine_single_column(self, col):
col = col.explode().infer_objects()
return col

def analyse(self, percentiles=[], num_proc=1):
def analyse(self, percentiles=[], num_proc=1, skip_export=False):
"""
Apply overall analysis on the whole dataset based on the describe
method of pandas.

:param percentiles: percentiles to analyse
:param num_proc: number of processes to analyse the dataset
:param skip_export: whether to skip exporting the results to disk
:return: the overall analysis result.
"""
# merge default and customized percentiles and get overall information
Expand All @@ -87,7 +88,8 @@ def analyse(self, percentiles=[], num_proc=1):
overall = pd.DataFrame(result_cols).T

# export to result report file
overall.to_csv(os.path.join(self.output_path, 'overall.csv'))
overall.to_markdown(os.path.join(self.output_path, 'overall.md'))
if not skip_export:
overall.to_csv(os.path.join(self.output_path, 'overall.csv'))
overall.to_markdown(os.path.join(self.output_path, 'overall.md'))

return overall
44 changes: 43 additions & 1 deletion data_juicer/config/config.py
@@ -13,6 +13,9 @@
from data_juicer.utils.logger_utils import setup_logger
from data_juicer.utils.mm_utils import SpecialTokens

global_cfg = None
global_parser = None


def init_configs(args=None):
"""
@@ -37,6 +40,11 @@ def init_configs(args=None):
type=str,
help='Path to a configuration file when using auto-HPO tool.',
required=False)
parser.add_argument(
'--path_3sigma_recipe',
type=str,
help='Path to save a configuration file when using the 3-sigma tool.',
required=False)

# basic global paras with extended type hints
# e.g., files can be mode include flags
@@ -294,6 +302,10 @@
# show the final config tables before the process started
display_config(cfg)

global global_cfg, global_parser
global_cfg = cfg
global_parser = parser

return cfg
except ArgumentError:
logger.error('Config initialization failed')
@@ -371,7 +383,7 @@ def init_setup_from_cfg(cfg):
f'variable HF_DATASETS_CACHE.')
config.HF_DATASETS_CACHE = cfg.ds_cache_dir
else:
cfg.ds_cache_dir = config.HF_DATASETS_CACHE
cfg.ds_cache_dir = str(config.HF_DATASETS_CACHE)

# if there is suffix_filter op, turn on the add_suffix flag
cfg.add_suffix = False
@@ -478,6 +490,36 @@ def display_config(cfg):
print(table)


def export_config(cfg, path, format='yaml', skip_none=True, skip_check=True,
overwrite=False, multifile=True):
"""
Save the config object; some params are passed through to jsonargparse's save method.
:param cfg: cfg object to save (Namespace type)
:param path: the save path
:param format: 'yaml', 'json', 'json_indented', 'parser_mode'
:param skip_none: Whether to exclude entries whose value is None.
:param skip_check: Whether to skip parser checking.
:param overwrite: Whether to overwrite existing files.
:param multifile: Whether to save multiple config files
by using the __path__ metas.

:return:
"""
# remove ops outside the process list for better displaying
cfg_to_export = cfg.clone()
for op in OPERATORS.modules.keys():
_ = cfg_to_export.pop(op)

global global_parser
if not global_parser:
init_configs() # enable the customized type parser
global_parser.save(
cfg=cfg_to_export, path=path, format=format, skip_none=skip_none,
skip_check=skip_check, overwrite=overwrite, multifile=multifile)

logger.info(f'Saved the configuration in {path}')
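
# Usage sketch (hypothetical output path; assumes `cfg` is the Namespace
# returned by init_configs()):
#   cfg = init_configs(['--config', 'configs/process.yaml'])
#   export_config(cfg, 'outputs/process_refined.yaml', format='yaml')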


def merge_config(ori_cfg, new_cfg: Dict):
"""
Merge configuration from new_cfg into ori_cfg
13 changes: 9 additions & 4 deletions data_juicer/core/analyser.py
@@ -63,11 +63,12 @@ def __init__(self, cfg=None):
self.overall_single_plot_path = None
self.analysis_path = os.path.join(self.cfg.work_dir, 'analysis')

def run(self, load_data_np=None):
def run(self, load_data_np=None, skip_export=False):
"""
Running the dataset analysis pipeline.

:param load_data_np: number of workers when loading the dataset.
:param skip_export: whether to skip exporting the results to disk
:return: analysed dataset.
"""
# 1. format data
@@ -120,14 +121,18 @@ def run(self, load_data_np=None):

logger.info('Applying overall analysis on stats...')
overall_analysis = OverallAnalysis(dataset, self.analysis_path)
self.overall_result = overall_analysis.analyse(num_proc=self.cfg.np)
self.overall_result = overall_analysis.analyse(
num_proc=self.cfg.np, skip_export=skip_export)

logger.info(f'The overall analysis results are: {self.overall_result}')

logger.info('Applying column-wise analysis on stats...')
column_wise_analysis = ColumnWiseAnalysis(
dataset,
self.analysis_path,
overall_result=self.overall_result,
save_stats_in_one_file=self.cfg.save_stats_in_one_file)
column_wise_analysis.analyse()
save_stats_in_one_file=self.cfg.save_stats_in_one_file,
)
column_wise_analysis.analyse(skip_export=skip_export)

return dataset
96 changes: 95 additions & 1 deletion data_juicer/utils/constant.py
@@ -1,3 +1,11 @@
import copy
import inspect
import io
import os

import zstandard as zstd
from loguru import logger

DEFAULT_PREFIX = '__dj__'


@@ -8,7 +16,89 @@ class Fields(object):
suffix = DEFAULT_PREFIX + 'suffix__'


class StatsKeys(object):
class StatsKeysMeta(type):
"""
a helper metaclass to track the mapping from each OP's name to the stats keys it uses

e.g., once AlphanumericFilter's compute_stats method has been called:
res = StatsKeys.get_access_log()
print(res)  # {'alphanumeric_filter': {'alnum_ratio', 'alpha_token_ratio'}}
"""
_accessed_by = {}

def __getattr__(cls, attr):
caller_class = inspect.currentframe().f_back.f_globals['__name__']
# no need to track the parent classes
caller_class = caller_class.split('.')[-1]
stat_key = getattr(cls._constants_class, attr)
if caller_class not in cls._accessed_by:
cls._accessed_by[caller_class] = set()
if stat_key not in cls._accessed_by[caller_class]:
cls._accessed_by[caller_class].add(stat_key)
return stat_key

def get_access_log(cls, dj_cfg=None):
if cls._accessed_by:
return cls._accessed_by
elif dj_cfg:
tmp_dj_cfg = copy.deepcopy(dj_cfg)
# the access may have been skipped due to the use of cache;
# we will use a temp data sample to get the access log
if os.path.exists(dj_cfg.dataset_path) and \
('jsonl' in dj_cfg.dataset_path or
'jsonl.zst' in dj_cfg.dataset_path):
logger.info(
'Begin to track the usage of ops with a dummy data sample')

# load the first line as tmp_data
tmp_f_name = None
first_line = None
if 'jsonl.zst' in dj_cfg.dataset_path:
tmp_f_name = dj_cfg.dataset_path. \
replace('.jsonl.zst', '.tmp.jsonl')
# Open the file in binary mode and
# create a Zstandard decompression context
with open(dj_cfg.dataset_path, 'rb') as compressed_file:
dctx = zstd.ZstdDecompressor()
# Create a stream reader for the file and decode the
# first line
with dctx.stream_reader(compressed_file) as reader:
text_stream = io.TextIOWrapper(
reader, encoding='utf-8')
first_line = text_stream.readline()
elif 'jsonl' in dj_cfg.dataset_path:
tmp_f_name = dj_cfg.dataset_path. \
replace('.jsonl', '.tmp.jsonl')
with open(dj_cfg.dataset_path, 'r') as orig_file:
first_line = orig_file.readline()

assert tmp_f_name is not None and first_line is not None, \
'error when loading the first line, when ' \
f'dj_cfg.dataset_path={dj_cfg.dataset_path}'

with open(tmp_f_name, 'w') as tmp_file:
tmp_file.write(first_line)

tmp_dj_cfg.dataset_path = tmp_f_name
tmp_dj_cfg.use_cache = False
tmp_dj_cfg.use_checkpoint = False

from data_juicer.core import Analyser
tmp_analyzer = Analyser(tmp_dj_cfg)
# do not overwrite the true analysis results
tmp_analyzer.run(skip_export=True)

os.remove(tmp_f_name)
else:
raise NotImplementedError(
f'For now, dummy-sample tracking only supports jsonl '
f'files. Please check your config: {dj_cfg.dataset_path} '
f'either does not exist or is not a jsonl file.')

return cls._accessed_by


class StatsKeysConstant(object):
# text
alpha_token_ratio = 'alpha_token_ratio'
alnum_ratio = 'alnum_ratio'
Expand Down Expand Up @@ -41,6 +131,10 @@ class StatsKeys(object):
image_text_matching_score = 'image_text_matching_score'


class StatsKeys(object, metaclass=StatsKeysMeta):
_constants_class = StatsKeysConstant
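
# Usage sketch: if op accesses were skipped (e.g. due to dataset cache), a
# Data-Juicer config can be passed so the log is rebuilt from a dummy sample:
#   access_log = StatsKeys.get_access_log(dj_cfg=cfg)
#   # e.g. {'alphanumeric_filter': {'alnum_ratio', 'alpha_token_ratio'}}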


class HashKeys(object):
hash = DEFAULT_PREFIX + 'hash'
minhash = DEFAULT_PREFIX + 'minhash'
34 changes: 28 additions & 6 deletions tools/hpo/README.md
@@ -1,6 +1,28 @@
# Hyper-parameter Optimization for Data Recipe

## Auto-HPO
## Auto-HPO based on the 3-Sigma Principle
A simple automatic hyper-parameter optimization (HPO) method for data recipes is to assume that outlier data is harmful to training.
We can thus introduce the 3-sigma principle to determine the hyper-parameters and filter the data automatically.

Specifically, assuming that a given analysis dimension of the original data follows a normal distribution with random errors, we can set the upper and lower bounds of the filtering OP in that dimension to three standard deviations from the mean, based on the statistics produced by Data-Juicer's Analyser:

$$P(|x-\mu| > 3\sigma) \leq 0.003$$
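
In code, the bound derivation amounts to the following minimal sketch (illustrative only, not the tool's actual implementation; the CSV path is hypothetical, and the table layout is assumed to match the exported `overall.csv`, with `mean`/`std` rows and one column per stat):

```python
import pandas as pd

def three_sigma_bounds(overall: pd.DataFrame, stat_key: str):
    """Derive (min, max) filter bounds for one stat from the overall table."""
    mu = float(overall.loc['mean', stat_key])
    sigma = float(overall.loc['std', stat_key])
    # keep samples whose stat value lies within mu +/- 3*sigma
    return mu - 3 * sigma, mu + 3 * sigma

# e.g., bounds that could refine an alphanumeric-ratio filter
overall = pd.read_csv('analysis/overall.csv', index_col=0)
lo, hi = three_sigma_bounds(overall, 'alnum_ratio')
```

These bounds can then replace the corresponding `min_*`/`max_*` arguments of each filtering OP in the refined recipe.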

To automate this process, we provide a tool that can be used as follows:
```shell
# cd tools/hpo
# usage 1: do not save the refined recipe
python execute_hpo_3sigma.py --config <data-process-cfg-file-path>
# usage 2: save the refined recipe at the given path
python execute_hpo_3sigma.py --config <data-process-cfg-file-path> --path_3sigma_recipe <data-process-cfg-file-after-refined-path>

# e.g., usage 1
python execute_hpo_3sigma.py --config configs/process.yaml
# e.g., usage 2
python execute_hpo_3sigma.py --config configs/process.yaml --path_3sigma_recipe configs/process_3sigma.yaml
```

## Auto-HPO with WandB

We incorporate an automated HPO tool, WandB [Sweep](https://docs.wandb.ai/guides/sweeps), into Data-Juicer to streamline the search for good data-processing hyper-parameters.
With this tool, users can investigate correlations and importance scores of
@@ -11,7 +33,7 @@
a large room to explore. Feel free to provide more suggestions, discussion,
and contribution via new PRs!


## Prerequisite
### Prerequisite
You need to install data-juicer first.
Besides, the tool leverages WandB; install it via `pip install wandb`.
Before using this tool, you need to run
@@ -26,17 +48,17 @@
wandb login --host <URL of your wandb instance>



## Usage and Customization
### Usage and Customization

Given a data recipe, characterized by a specified configuration file
`<data-process-cfg-file-path>`, you can use `execute_hpo.py` to search the
`<data-process-cfg-file-path>`, you can use `execute_hpo_wandb.py` to search the
hyper-parameter space defined by `<hpo-cfg-file-path>`.
```shell
# cd tools/hpo
python execute_hpo.py --config <data-process-cfg-file-path> --hpo_config <hpo-cfg-file-path>
python execute_hpo_wandb.py --config <data-process-cfg-file-path> --hpo_config <hpo-cfg-file-path>

# e.g.,
python execute_hpo.py --config configs/process.yaml --hpo_config configs/quality_score_hpo.yaml
python execute_hpo_wandb.py --config configs/process.yaml --hpo_config configs/quality_score_hpo.yaml
```

For the configuration of the data recipe, i.e., `<data-process-cfg-file-path>`,