Skip to content

Commit

Permalink
[Feature] add auto mode for analyzer (#512)
Browse files Browse the repository at this point in the history
* + add auto mode for analyzer: load all filters that produce stats to analyze the target dataset

* + add default mem_required for those model-based OPs

* - support wordcloud drawing for str or str list fields in stats
- support setting the number of samples to be analyzed in auto mode. It's 1k by default.

* - take the minimum one of dataset length and auto num

* * update default export path

* * set version limit for wandb to avoid exception

* + add docs for auto mode
  • Loading branch information
HYLcool authored Dec 20, 2024
1 parent 8583333 commit 2fdf484
Show file tree
Hide file tree
Showing 30 changed files with 143 additions and 29 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ python tools/analyze_data.py --config configs/demo/analyzer.yaml
# use command line tool
dj-analyze --config configs/demo/analyzer.yaml
# you can also use auto mode to avoid writing a recipe. It will analyze a small
# part (e.g. 1000 samples, specified by argument `auto_num`) of your dataset
# with all Filters that produce stats.
dj-analyze --auto --dataset_path xx.jsonl [--auto_num 1000]
```
- **Note:** The Analyzer only computes stats of Filter ops, so any extra Mapper or Deduplicator ops will be ignored in the analysis process.
Expand Down
4 changes: 4 additions & 0 deletions README_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@ python tools/analyze_data.py --config configs/demo/analyzer.yaml
# 使用命令行工具
dj-analyze --config configs/demo/analyzer.yaml
# 你也可以使用"自动"模式来避免写一个新的数据菜谱。它会使用全部可产出统计信息的 Filter 来分析
# 你的数据集的一小部分(如1000条样本,可通过 `auto_num` 参数指定)
dj-analyze --auto --dataset_path xx.jsonl [--auto_num 1000]
```

* **注意**:Analyzer 只计算 Filter 算子的状态,其他的算子(例如 Mapper 和 Deduplicator)会在分析过程中被忽略。
Expand Down
2 changes: 1 addition & 1 deletion configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ process:
vertical_flip: false # flip frame image vertically (top to bottom).
reduce_mode: avg # reduce mode when one text corresponds to multiple videos in a chunk, must be one of ['avg','max', 'min'].
any_or_all: any # keep this sample when any/all videos meet the filter condition
mem_required: '1GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched
mem_required: '1500MB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- video_motion_score_filter: # Keep samples with video motion scores within a specific range.
min_score: 0.25 # the minimum motion score to keep samples
max_score: 10000.0 # the maximum motion score to keep samples
Expand Down
69 changes: 53 additions & 16 deletions data_juicer/analysis/column_wise_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from wordcloud import WordCloud

from data_juicer.utils.constant import Fields

Expand Down Expand Up @@ -145,33 +146,39 @@ def analyze(self, show_percentiles=False, show=False, skip_export=False):
else:
axes = [None] * num_subcol

# draw histogram
self.draw_hist(axes[0],
data,
os.path.join(self.output_path,
f'{column_name}-hist.png'),
percentiles=percentiles)

# draw box
self.draw_box(axes[1],
data,
os.path.join(self.output_path,
f'{column_name}-box.png'),
percentiles=percentiles)
if not skip_export:
# draw histogram
self.draw_hist(axes[0],
data,
os.path.join(self.output_path,
f'{column_name}-hist.png'),
percentiles=percentiles)

# draw box
self.draw_box(axes[1],
data,
os.path.join(self.output_path,
f'{column_name}-box.png'),
percentiles=percentiles)
else:
# object (string) or string list -- only draw histogram for
# this stat
if self.save_stats_in_one_file:
axes = subfig.subplots(1, 1)
axes = subfig.subplots(1, num_subcol)
else:
axes = None
axes = [None] * num_subcol

if not skip_export:
self.draw_hist(
axes, data,
axes[0], data,
os.path.join(self.output_path,
f'{column_name}-hist.png'))

self.draw_wordcloud(
axes[1], data,
os.path.join(self.output_path,
f'{column_name}-wordcloud.png'))

# add a title to the figure of this stat
if self.save_stats_in_one_file:
subfig.suptitle(f'{data.name}',
Expand Down Expand Up @@ -297,3 +304,33 @@ def draw_box(self, ax, data, save_path, percentiles=None, show=False):
# accumulated overlapped figures in different draw_xxx function
# calling
ax.clear()

def draw_wordcloud(self, ax, data, save_path, show=False):
word_list = data.tolist()
word_nums = {}
for w in word_list:
if w in word_nums:
word_nums[w] += 1
else:
word_nums[w] = 1

wc = WordCloud(width=400, height=320)
wc.generate_from_frequencies(word_nums)

if ax is None:
ax = plt.figure(figsize=(20, 16))
else:
ax.imshow(wc, interpolation='bilinear')
ax.axis('off')

if not self.save_stats_in_one_file:
# save into file
wc.to_file(save_path)

if show:
plt.show()
else:
# if no showing, we need to clear this axes to avoid
# accumulated overlapped figures in different draw_xxx function
# calling
ax.clear()
52 changes: 45 additions & 7 deletions data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
global_parser = None


def init_configs(args: Optional[List[str]] = None):
def init_configs(args: Optional[List[str]] = None, which_entry: object = None):
"""
initialize the jsonargparse parser and parse configs from one of:
1. POSIX-style commands line args;
Expand All @@ -32,14 +32,29 @@ def init_configs(args: Optional[List[str]] = None):
4. hard-coded defaults
:param args: list of params, e.g., ['--conifg', 'cfg.yaml'], defaut None.
:param which_entry: which entry to init configs (executor/analyzer)
:return: a global cfg object used by the Executor or Analyzer
"""
parser = ArgumentParser(default_env=True, default_config_files=None)

parser.add_argument('--config',
action=ActionConfigFile,
help='Path to a dj basic configuration file.',
required=True)
# required but mutually exclusive args group
required_group = parser.add_mutually_exclusive_group(required=True)
required_group.add_argument('--config',
action=ActionConfigFile,
help='Path to a dj basic configuration file.')
required_group.add_argument('--auto',
action='store_true',
help='Weather to use an auto analyzing '
'strategy instead of a specific data '
'recipe. If a specific config file is '
'given by --config arg, this arg is '
'disabled. Only available for Analyzer.')

parser.add_argument('--auto_num',
type=PositiveInt,
default=1000,
help='The number of samples to be analyzed '
'automatically. It\'s 1000 in default.')

parser.add_argument(
'--hpo_config',
Expand Down Expand Up @@ -97,7 +112,7 @@ def init_configs(args: Optional[List[str]] = None):
parser.add_argument(
'--export_path',
type=str,
default='./outputs/hello_world.jsonl',
default='./outputs/hello_world/hello_world.jsonl',
help='Path to export and save the output processed dataset. The '
'directory to store the processed dataset will be the work '
'directory of this process.')
Expand Down Expand Up @@ -339,6 +354,14 @@ def init_configs(args: Optional[List[str]] = None):

try:
cfg = parser.parse_args(args=args)

# check the entry
from data_juicer.core.analyzer import Analyzer
if not isinstance(which_entry, Analyzer) and cfg.auto:
err_msg = '--auto argument can only be used for analyzer!'
logger.error(err_msg)
raise NotImplementedError(err_msg)

cfg = init_setup_from_cfg(cfg)
cfg = update_op_process(cfg, parser)

Expand Down Expand Up @@ -488,6 +511,16 @@ def init_setup_from_cfg(cfg: Namespace):
SpecialTokens.image = cfg.image_special_token
SpecialTokens.eoc = cfg.eoc_special_token

# add all filters that produce stats
if cfg.auto:
import pkgutil

import data_juicer.ops.filter as djfilters
cfg.process = [{
filter_name: {}
} for _, filter_name, _ in pkgutil.iter_modules(djfilters.__path__)
if filter_name not in djfilters.NON_STATS_FILTERS]

# Apply text_key modification during initializing configs
# users can freely specify text_key for different ops using `text_key`
# otherwise, set arg text_key of each op to text_keys
Expand Down Expand Up @@ -636,7 +669,10 @@ def update_op_process(cfg, parser):
temp_args = namespace_to_arg_list(temp_cfg,
includes=recognized_args,
excludes=['config'])
temp_args = ['--config', temp_cfg.config[0].absolute] + temp_args
if temp_cfg.config:
temp_args = ['--config', temp_cfg.config[0].absolute] + temp_args
else:
temp_args = ['--auto'] + temp_args
temp_parser.parse_args(temp_args)
return cfg

Expand All @@ -662,6 +698,8 @@ def namespace_to_arg_list(namespace, prefix='', includes=None, excludes=None):


def config_backup(cfg: Namespace):
if not cfg.config:
return
cfg_path = cfg.config[0].absolute
work_dir = cfg.work_dir
target_path = os.path.join(work_dir, os.path.basename(cfg_path))
Expand Down
6 changes: 5 additions & 1 deletion data_juicer/core/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, cfg: Optional[Namespace] = None):
:param cfg: optional jsonargparse Namespace dict.
"""
self.cfg = init_configs() if cfg is None else cfg
self.cfg = init_configs(which_entry=self) if cfg is None else cfg

self.work_dir = self.cfg.work_dir

Expand Down Expand Up @@ -87,6 +87,10 @@ def run(self,
if load_data_np is None:
load_data_np = self.cfg.np
dataset = self.formatter.load_dataset(load_data_np, self.cfg)
if self.cfg.auto:
# if it's auto analysis, only analyze for a minor part of the input
# dataset to save time and computing resource
dataset = dataset.take(min(len(dataset), self.cfg.auto_num))

# extract processes
logger.info('Preparing process operators...')
Expand Down
7 changes: 7 additions & 0 deletions data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,10 @@
'VideoTaggingFromFramesFilter', 'VideoWatermarkFilter',
'WordRepetitionFilter', 'WordsNumFilter'
]

# Module names of filters that are left out of the process list built
# automatically when `cfg.auto` is set (auto mode of the analyzer). That
# list is meant to contain only the filters that produce stats.
NON_STATS_FILTERS = [
'specified_field_filter',
'specified_numeric_field_filter',
'suffix_filter',
'video_tagging_from_frames_filter',
]
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/image_aesthetics_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self,
:param args: Extra positional arguments.
:param kwargs: Extra keyword arguments.
"""

kwargs.setdefault('mem_required', '1500MB')
super().__init__(*args, **kwargs)
if hf_scorer_model == '':
hf_scorer_model = \
Expand Down
1 change: 1 addition & 0 deletions data_juicer/ops/filter/image_nsfw_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '1GB')
super().__init__(*args, **kwargs)
self.score_threshold = score_threshold
if any_or_all not in ['any', 'all']:
Expand Down
1 change: 1 addition & 0 deletions data_juicer/ops/filter/image_text_matching_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '1500MB')
super().__init__(*args, **kwargs)
self.min_score = min_score
self.max_score = max_score
Expand Down
1 change: 1 addition & 0 deletions data_juicer/ops/filter/image_text_similarity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '1500MB')
super().__init__(*args, **kwargs)
self.min_score = min_score
self.max_score = max_score
Expand Down
1 change: 1 addition & 0 deletions data_juicer/ops/filter/image_watermark_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '500MB')
super().__init__(*args, **kwargs)
self.prob_threshold = prob_threshold
if any_or_all not in ['any', 'all']:
Expand Down
1 change: 1 addition & 0 deletions data_juicer/ops/filter/phrase_grounding_recall_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '1GB')
super().__init__(*args, **kwargs)
self.min_recall = min_recall
self.max_recall = max_recall
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/video_aesthetics_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self,
:param args: Extra positional arguments.
:param kwargs: Extra keyword arguments.
"""

kwargs.setdefault('mem_required', '1500MB')
super().__init__(*args, **kwargs)
if hf_scorer_model == '':
hf_scorer_model = \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '1500MB')
super().__init__(*args, **kwargs)
self.min_score = min_score
self.max_score = max_score
Expand Down
1 change: 1 addition & 0 deletions data_juicer/ops/filter/video_nsfw_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '1GB')
super().__init__(*args, **kwargs)
self.score_threshold = score_threshold
if frame_sampling_method not in ['all_keyframes', 'uniform']:
Expand Down
1 change: 1 addition & 0 deletions data_juicer/ops/filter/video_tagging_from_frames_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '9GB')
super().__init__(*args, **kwargs)
if contain not in ['any', 'all']:
raise ValueError(f'the containing type [{contain}] is not '
Expand Down
1 change: 1 addition & 0 deletions data_juicer/ops/filter/video_watermark_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '500MB')
super().__init__(*args, **kwargs)
self.prob_threshold = prob_threshold
if frame_sampling_method not in ['all_keyframes', 'uniform']:
Expand Down
2 changes: 2 additions & 0 deletions data_juicer/ops/mapper/image_captioning_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '16GB')

super().__init__(*args, **kwargs)

if keep_candidate_mode not in [
Expand Down
1 change: 1 addition & 0 deletions data_juicer/ops/mapper/image_diffusion_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(self,
:param hf_img2seq: model name on huggingface to generate caption if
caption_key is None.
"""
kwargs.setdefault('mem_required', '8GB')
super().__init__(*args, **kwargs)
self._init_parameters = self.remove_extra_parameters(locals())
self.strength = strength
Expand Down
1 change: 1 addition & 0 deletions data_juicer/ops/mapper/image_tagging_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '9GB')
super().__init__(*args, **kwargs)
self.model_key = prepare_model(
model_type='recognizeAnything',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, keep_original_sample: bool = True, *args, **kwargs):
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '30GB')
super().__init__(*args, **kwargs)
AUTOINSTALL.check([
'transformers', 'transformers_stream_generator', 'einops',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '20GB')
super().__init__(*args, **kwargs)

if keep_candidate_mode not in [
Expand Down
Loading

0 comments on commit 2fdf484

Please sign in to comment.