From b6f89a90ce86978054e3445155223425eca31efb Mon Sep 17 00:00:00 2001
From: Yilun Huang
Date: Fri, 20 Dec 2024 19:09:32 +0800
Subject: [PATCH] [Feat] OP-wise Insight Mining (#516)

* + add auto mode for analyzer: load all filters that produce stats to analyze the target dataset

* + add default mem_required for those model-based OPs

* - support wordcloud drawing for str or str list fields in stats
- support set the number of samples to be analyzed in auto mode. It's 1k in default.

* - take the minimum one of dataset length and auto num

* * update default export path

* * set version limit for wandb to avoid exception

* + add docs for auto mode

* + support t-test for Measure

* * fix some bugs

* - support analyze a dataset object
- optimize the logics of loading filters that produce stats and updating attributes of OPs

* - support analysis on tags in meta

* - support analysis with tagging OPs

* - move tags into the meta field

* - do not tell tags using their suffix
- suppress the error/exceptions in Monitor due to the termination of the main process
- exported stats file includes meta field in exporter

* - add insight mining

* * resolve the bugs when running insight mining in multiprocessing mode

* * update unittests

* * update unittests

* * update unittests

* * update readme for analyzer

* * use more detailed key

* + add reference
---
 README.md                                                  |   4 +-
 README_ZH.md                                               |   4 +-
 data_juicer/analysis/column_wise_analysis.py               |  24 +-
 data_juicer/analysis/measure.py                            | 111 +++++++++
 data_juicer/analysis/overall_analysis.py                   |  16 +-
 data_juicer/config/config.py                               |  84 ++++---
 data_juicer/core/adapter.py                                | 125 +++++++++-
 data_juicer/core/analyzer.py                               |  35 ++-
 data_juicer/core/data.py                                   |  30 ++-
 data_juicer/core/executor.py                               |   1 +
 data_juicer/core/exporter.py                               |  11 +-
 data_juicer/core/monitor.py                                |   8 +-
 data_juicer/ops/__init__.py                                |   5 +-
 data_juicer/ops/base_op.py                                 |  18 +-
 .../ops/filter/specified_field_filter.py                   |   7 +-
 .../filter/specified_numeric_field_filter.py               |   8 +-
 data_juicer/ops/filter/suffix_filter.py                    |   7 +-
 .../video_tagging_from_frames_filter.py                    |   7 +-
 .../ops/mapper/image_tagging_mapper.py                     |  10 +-
 .../mapper/video_tagging_from_audio_mapper.py              |  11 +-
 .../video_tagging_from_frames_mapper.py                    |  10 +-
 data_juicer/utils/cache_utils.py                           |  47 ++++
 data_juicer/utils/constant.py                              |   6 +-
 .../test_video_tagging_from_frames_filter.py               |   8 +-
 tests/ops/mapper/test_image_tagging_mapper.py              | 102 ++++----
 .../test_video_tagging_from_audio_mapper.py                |   5 +-
 .../test_video_tagging_from_frames_mapper.py               | 220 ++++++++++--------
 27 files changed, 680 insertions(+), 244 deletions(-)

diff --git a/README.md b/README.md
index d891ac332..586869b0a 100644
--- a/README.md
+++ b/README.md
@@ -340,7 +340,9 @@ dj-analyze --config configs/demo/analyzer.yaml
 dj-analyze --auto --dataset_path xx.jsonl [--auto_num 1000]
 ```
 
-- **Note:** Analyzer only compute stats of Filter ops. So extra Mapper or Deduplicator ops will be ignored in the analysis process.
+- **Note:** Analyzer only computes stats for Filters that produce stats and for other OPs that produce tags/categories in the meta field. All other OPs are ignored during analysis. We use the following registries to decorate OPs:
+  - `NON_STATS_FILTERS`: decorate Filters that **DO NOT** produce any stats.
+  - `TAGGING_OPS`: decorate OPs that **DO** produce tags/categories in the meta field.
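+
+  For example, a custom OP that produces tags would opt into auto analysis by registering itself in `TAGGING_OPS` in addition to `OPERATORS` (a minimal sketch with a hypothetical OP name):
+
+  ```python
+  from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
+
+  OP_NAME = 'my_tagging_mapper'  # hypothetical OP name, for illustration only
+
+  @TAGGING_OPS.register_module(OP_NAME)  # its tags in the meta field get analyzed
+  @OPERATORS.register_module(OP_NAME)
+  class MyTaggingMapper(Mapper):
+      ...
+  ```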
 
 ### Data Visualization
 
diff --git a/README_ZH.md b/README_ZH.md
index 01633731b..42612964a 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -316,7 +316,9 @@ dj-analyze --config configs/demo/analyzer.yaml
 dj-analyze --auto --dataset_path xx.jsonl [--auto_num 1000]
 ```
 
-* **注意**:Analyzer 只计算 Filter 算子的状态,其他的算子(例如 Mapper 和 Deduplicator)会在分析过程中被忽略。
+* **注意**:Analyzer 只用于能在 stats 字段里产出统计信息的 Filter 算子和能在 meta 字段里产出 tags 或类别标签的其他算子。除此之外的算子会在分析过程中被忽略。我们使用以下两种注册器来装饰相关的算子:
+  * `NON_STATS_FILTERS`:装饰那些**不会**产出任何统计信息的 Filter 算子。
+  * `TAGGING_OPS`:装饰那些能在 meta 字段中产出 tags 或类别标签的算子。
 
 ### 数据可视化
 
diff --git a/data_juicer/analysis/column_wise_analysis.py b/data_juicer/analysis/column_wise_analysis.py
index 825d9b4dd..ce5b3617d 100644
--- a/data_juicer/analysis/column_wise_analysis.py
+++ b/data_juicer/analysis/column_wise_analysis.py
@@ -6,7 +6,7 @@
 from tqdm import tqdm
 from wordcloud import WordCloud
 
-from data_juicer.utils.constant import Fields
+from data_juicer.utils.constant import DEFAULT_PREFIX, Fields
 
 from .overall_analysis import OverallAnalysis
 
@@ -70,6 +70,12 @@ def __init__(self,
             stats into one image file
         """
         self.stats = pd.DataFrame(dataset[Fields.stats])
+        self.meta = pd.DataFrame(dataset[Fields.meta])
+        # remove non-tag columns
+        meta_columns = self.meta.columns
+        for col_name in meta_columns:
+            if not col_name.startswith(DEFAULT_PREFIX):
+                self.meta = self.meta.drop(col_name, axis=1)
         self.output_path = output_path
         if not os.path.exists(self.output_path):
             os.makedirs(self.output_path)
@@ -101,8 +107,9 @@ def analyze(self, show_percentiles=False, show=False, skip_export=False):
         width_unit = 4
         height_unit = 6
 
-        columns = self.stats.columns
-        num = len(columns)
+        stats_and_meta = pd.concat([self.stats, self.meta], axis=1)
+        all_columns = stats_and_meta.columns
+        num = len(all_columns)
 
         # get the recommended "best" number of columns and rows
         rec_row, rec_col, grid_indexes = get_row_col(num, num_subcol)
@@ -115,9 +122,9 @@ def analyze(self, show_percentiles=False, show=False, skip_export=False):
             fig = plt.figure(figsize=(rec_width, rec_height),
                              layout='constrained')
             subfigs = fig.subfigures(rec_row, rec_col, wspace=0.01)
-        for i, column_name in enumerate(tqdm(columns.to_list(),
-                                             desc='Column')):
-            data = self.stats[column_name]
+        for i, column_name in enumerate(
+                tqdm(all_columns.to_list(), desc='Column')):
+            data = stats_and_meta[column_name]
             # explode data to flatten inner list
             data = data.explode().infer_objects()
             grid = grid_indexes[i]
@@ -210,10 +217,7 @@ def draw_hist(self, ax, data, save_path, percentiles=None, show=False):
         """
         # recommended number of bins
         data_num = len(data)
-        if data_num >= 100:
-            rec_bins = int(math.sqrt(len(data)))
-        else:
-            rec_bins = None
+        rec_bins = max(int(math.sqrt(data_num)), 10)
 
         # if ax is None, using plot method in pandas
         if ax is None:
diff --git a/data_juicer/analysis/measure.py b/data_juicer/analysis/measure.py
index fe54cdabd..bd97e811c 100644
--- a/data_juicer/analysis/measure.py
+++ b/data_juicer/analysis/measure.py
@@ -1,9 +1,13 @@
+import numpy as np
+
 from data_juicer.utils.lazy_loader import LazyLoader
 
 torch = LazyLoader('torch', 'torch')
 td = LazyLoader('td', 'torch.distributions')
 F = LazyLoader('F', 'torch.nn.functional')
 
+stats = LazyLoader('stats', 'scipy.stats')
+
 
 class Measure(object):
     """Base class for Measure distribution.
@@ -48,6 +52,15 @@ def _convert_to_categorical(self, p):
         else:
             return td.Categorical(torch.tensor(p))
 
+    def _convert_to_ndarray(self, p):
+        """
+        Convert input data to a numpy ndarray.
+        :param p: input data, now supports
+            [`scalar`, `list`, `tuple`, `torch binary file`, and `Categorical`].
+        :return: numpy ndarray
+        """
+        return self._convert_to_tensor(p).numpy()
+
 
 class KLDivMeasure(Measure):
     """
@@ -108,3 +121,101 @@ class EntropyMeasure(Measure):
     def measure(self, p):
         p = self._convert_to_categorical(p)
         return p.entropy()
+
+
+class RelatedTTestMeasure(Measure):
+    """
+    Measure the t-test for two related distributions on their histograms over
+    the same bins.
+
+    Ref:
+    https://en.wikipedia.org/wiki/Student%27s_t-test
+
+    For continuous features or distributions, the input could be dataset stats
+    lists.
+    For discrete features or distributions, the input could be tag or category
+    lists.
+    """
+    name = 't-test'
+
+    @staticmethod
+    def stats_to_hist(p, q):
+        p = np.array(p)
+        q = np.array(q)
+
+        # get the maximum number of data samples and the global max/min values
+        max_data_num = max(len(p), len(q))
+        min_val = min(min(p), min(q))
+        max_val = max(max(p), max(q))
+
+        # get a recommended number of bins
+        rec_bins = max(int(np.sqrt(max_data_num)), 10)
+
+        # get the common bin edges
+        common_p = np.append(p, [min_val, max_val])
+        hist_p, bin_edges = np.histogram(common_p, bins=rec_bins)
+        # restore the hist of the original p
+        hist_p[0] -= 1
+        hist_p[-1] -= 1
+        # get the hist of the original q using the common bin edges
+        hist_q, _ = np.histogram(q, bins=bin_edges)
+        return hist_p, hist_q, bin_edges
+
+    @staticmethod
+    def category_to_hist(p, q):
+
+        def flatten_list(lst):
+            res = []
+            for s in lst:
+                if isinstance(s, list):
+                    res.extend(flatten_list(s))
+                else:
+                    res.append(s)
+            return res
+
+        # flatten the (possibly nested) lists
+        p = flatten_list(p)
+        q = flatten_list(q)
+
+        # get the union of categories from both lists
+        cat_p = set(p)
+        cat_q = set(q)
+        cat_common = cat_p.union(cat_q)
+
+        # get category distributions
+        count_p = {cat: 0 for cat in cat_common}
+        count_q = {cat: 0 for cat in cat_common}
+        for cat in p:
+            count_p[cat] += 1
+        for cat in q:
+            count_q[cat] += 1
+
+        # sort the categories by their counts in p and only keep the
+        # distribution values in that order
+        sorted_cat = list(count_p.items())
+        sorted_cat.sort(key=lambda it: it[1], reverse=True)
+        sorted_cat = [it[0] for it in sorted_cat]
+        # get the value dist
+        hist_p = [count_p[cat] for cat in sorted_cat]
+        hist_q = [count_q[cat] for cat in sorted_cat]
+
+        return hist_p, hist_q, count_p, count_q, sorted_cat
+
+    def measure(self, p, q):
+        """
+        :param p: the first feature or distribution. (stats/tags/categories)
+        :param q: the second feature or distribution.
+            (stats/tags/categories)
+        :return: the t-test results object -- ([ref](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats._result_classes.TtestResult.html#scipy.stats._result_classes.TtestResult))  # noqa: E501
+        """
+        ele = p[0]
+        while isinstance(ele, list):
+            ele = ele[0]
+        if isinstance(ele, str):
+            # discrete tags or categories
+            hist_p, hist_q = self.category_to_hist(p, q)[:2]
+        else:
+            # continuous stats
+            hist_p, hist_q = self.stats_to_hist(p, q)[:2]
+
+        # compute the t-test and p-value for hist_p and hist_q
+        ttest_res = stats.ttest_rel(hist_p, hist_q)
+        return ttest_res
diff --git a/data_juicer/analysis/overall_analysis.py b/data_juicer/analysis/overall_analysis.py
index 04eefb178..696b25946 100644
--- a/data_juicer/analysis/overall_analysis.py
+++ b/data_juicer/analysis/overall_analysis.py
@@ -5,7 +5,7 @@
 from loguru import logger
 from tqdm import tqdm
 
-from data_juicer.utils.constant import Fields
+from data_juicer.utils.constant import DEFAULT_PREFIX, Fields
 
 
 def _single_column_analysis(col, *args, **kwargs):
@@ -25,6 +25,12 @@ def __init__(self, dataset, output_path):
         :param output_path: path to store the analysis results.
         """
         self.stats = pd.DataFrame(dataset[Fields.stats])
+        self.meta = pd.DataFrame(dataset[Fields.meta])
+        # remove non-tag columns
+        meta_columns = self.meta.columns
+        for col_name in meta_columns:
+            if not col_name.startswith(DEFAULT_PREFIX):
+                self.meta = self.meta.drop(col_name, axis=1)
         self.output_path = output_path
         if not os.path.exists(self.output_path):
             os.makedirs(self.output_path)
@@ -71,10 +77,14 @@ def analyze(self, percentiles=[], num_proc=1, skip_export=False):
         # merge default and customized percentiles and get overall information
         percentiles = list(set(percentiles + self.default_percentiles))
 
+        # merge stats and meta
+        stats_and_meta = pd.concat([self.stats, self.meta], axis=1)
+        all_columns = stats_and_meta.columns
+
         results = []
         pool = Pool(num_proc)
-        for col_name in self.stats.columns:
-            this_col = self.refine_single_column(self.stats[col_name])
+        for col_name in all_columns:
+            this_col = self.refine_single_column(stats_and_meta[col_name])
             res = pool.apply_async(_single_column_analysis,
                                    kwds={
                                        'col': this_col,
diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py
index c7f0aaf38..028f3cf79 100644
--- a/data_juicer/config/config.py
+++ b/data_juicer/config/config.py
@@ -290,6 +290,22 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None):
         help='Number of samples extracted by tracer to show the dataset '
         'difference before and after a op. Only available when '
         'open_tracer is true.')
+    parser.add_argument(
+        '--open_insight_mining',
+        type=bool,
+        default=False,
+        help='Whether to open insight mining to trace the OP-wise stats/tags '
+        'changes during processing. Enabling insight mining may increase '
+        'the processing time.')
+    parser.add_argument(
+        '--op_list_to_mine',
+        type=List[str],
+        default=[],
+        help='Which OPs will be applied to the dataset to mine insights '
+        'from their stats changes. Only OPs that produce stats or '
+        'meta are valid. If it\'s empty, all OPs that produce stats or '
+        'meta will be involved. Only available when '
+        'open_insight_mining is true.')
     parser.add_argument(
         '--op_fusion',
         type=bool,
@@ -513,13 +529,7 @@ def init_setup_from_cfg(cfg: Namespace):
 
     # add all filters that produce stats
     if cfg.auto:
-        import pkgutil
-
-        import data_juicer.ops.filter as djfilters
-        cfg.process = [{
-            filter_name: {}
-        } for _, filter_name, _ in pkgutil.iter_modules(djfilters.__path__)
-            if filter_name not in djfilters.NON_STATS_FILTERS]
+        cfg.process = load_ops_with_stats_meta()
 
     # Apply text_key modification during initializing configs
     # users can freely specify text_key for different ops using `text_key`
@@ -528,34 +538,48 @@ def init_setup_from_cfg(cfg: Namespace):
         text_key = cfg.text_keys[0]
     else:
         text_key = cfg.text_keys
-    for op in cfg.process:
+    op_attrs = {
+        'text_key': text_key,
+        'image_key': cfg.image_key,
+        'audio_key': cfg.audio_key,
+        'video_key': cfg.video_key,
+        'num_proc': cfg.np,
+        'turbo': cfg.turbo,
+    }
+    cfg.process = update_op_attr(cfg.process, op_attrs)
+
+    return cfg
+
+
+def load_ops_with_stats_meta():
+    import pkgutil
+
+    import data_juicer.ops.filter as djfilter
+    from data_juicer.ops import NON_STATS_FILTERS, TAGGING_OPS
+    stats_filters = [{
+        filter_name: {}
+    } for _, filter_name, _ in pkgutil.iter_modules(djfilter.__path__)
+                     if filter_name not in NON_STATS_FILTERS.modules]
+    meta_ops = [{op_name: {}} for op_name in TAGGING_OPS.modules]
+    return stats_filters + meta_ops
+
+
+def update_op_attr(op_list: list, attr_dict: dict = None):
+    if not attr_dict:
+        return op_list
+    updated_op_list = []
+    for op in op_list:
         for op_name in op:
             args = op[op_name]
             if args is None:
-                args = {
-                    'text_key': text_key,
-                    'image_key': cfg.image_key,
-                    'audio_key': cfg.audio_key,
-                    'video_key': cfg.video_key,
-                    'num_proc': cfg.np,
-                    'turbo': cfg.turbo,
-                }
+                args = attr_dict
             else:
-                if 'text_key' not in args or args['text_key'] is None:
-                    args['text_key'] = text_key
-                if 'image_key' not in args or args['image_key'] is None:
-                    args['image_key'] = cfg.image_key
-                if 'audio_key' not in args or args['audio_key'] is None:
-                    args['audio_key'] = cfg.audio_key
-                if 'video_key' not in args or args['video_key'] is None:
-                    args['video_key'] = cfg.video_key
-                if 'num_proc' not in args or args['num_proc'] is None:
-                    args['num_proc'] = cfg.np
-                if 'turbo' not in args or args['turbo'] is None:
-                    args['turbo'] = cfg.turbo
+                for key in attr_dict:
+                    if key not in args or args[key] is None:
+                        args[key] = attr_dict[key]
             op[op_name] = args
-
-    return cfg
+        updated_op_list.append(op)
+    return updated_op_list
 
 
 def _collect_config_info_from_class_docs(configurable_ops, parser):
diff --git a/data_juicer/core/adapter.py b/data_juicer/core/adapter.py
index 5ab6e6ec8..64fd622f0 100644
--- a/data_juicer/core/adapter.py
+++ b/data_juicer/core/adapter.py
@@ -1,8 +1,15 @@
-from datasets import concatenate_datasets
+import json
+import os
+from copy import deepcopy
+
+from datasets import Dataset, concatenate_datasets
 from datasets.config import DEFAULT_MAX_BATCH_SIZE
 
+from data_juicer.analysis.measure import RelatedTTestMeasure
 from data_juicer.core.monitor import Monitor
 from data_juicer.ops import UNFORKABLE
+from data_juicer.utils.cache_utils import dataset_cache_control
+from data_juicer.utils.constant import Fields
 from data_juicer.utils.process_utils import setup_mp
 
 
@@ -12,6 +19,11 @@ class Adapter:
 
     def __init__(self, cfg: dict):
         self.cfg = cfg
+
+        # insight mining related
+        self.enable_insight_mining = self.cfg.open_insight_mining
+
+        # resource probe related
         self.idle_resources = Monitor.monitor_current_resources()
 
     @staticmethod
@@ -108,25 +120,21 @@ def adapt_workloads(self, dataset, operators):
 
         return bs_per_op
 
+    @dataset_cache_control(on=True)
     def probe_small_batch(self, dataset, operators):
         """
         Perform small batch pre-execution to probe available resources,
         current load and estimated OP speed, returning load factors and speed
         ranks for each OP.
 
-        Notice: the probe should be run with cache enabled.
+        Notice: the probe should be run with cache enabled to avoid removing
+        the cache files of the input dataset.
 
         :param dataset: The dataset to pre-execute small batch on
        :param operators: The OP list to be pre-execution and probe
         :return: A list of probe results for each OP and the length of data
             batch to probe.
         """
-        # record the cache state and enable the cache
-        from datasets import (disable_caching, enable_caching,
-                              is_caching_enabled)
-        previous_state = is_caching_enabled()
-        if not previous_state:
-            enable_caching()
 
         # take a small batch
         data_batch = self.take_batch(dataset, self.cfg)
@@ -135,10 +143,6 @@ def probe_small_batch(self, dataset, operators):
         # analyze resource utilization
         analysis_res = Monitor.analyze_resource_util_list(resource_util_list)
 
-        # if the cache is disabled before, disable it again
-        if not previous_state:
-            disable_caching()
-
         return analysis_res, len(data_batch)
 
     def batch_size_strategy(self, load_analysis_res, base_bs=1, util_th=0.9):
@@ -177,3 +181,100 @@ def batch_size_strategy(self, load_analysis_res, base_bs=1, util_th=0.9):
             batch_size_per_op.append(bs_this_op)
 
         return batch_size_per_op
+
+    @dataset_cache_control(on=True)
+    def analyze_small_batch(self, dataset, current_state):
+        """
+        Perform small batch analysis to probe the current OP-wise stats/meta
+        distributions. The analyzed results will be stored in the directory
+        `{work_dir}/insight_mining`.
+
+        Notice: the probe should be run with cache enabled to avoid removing
+        the cache files of the input dataset.
+
+        :param dataset: The dataset to analyze small batch on
+        :param current_state: A string to indicate the current state of the
+            input dataset. It usually consists of the index of the OP
+            processed just now and that OP's name, e.g.
+            "1_text_length_filter".
+        """
+        # prepare analyzer config
+        new_cfg = deepcopy(self.cfg)
+        # check ops to mine
+        new_cfg.auto = True
+        new_cfg.config = None
+        if len(new_cfg.op_list_to_mine) > 0:
+            new_cfg.process = [{
+                op_name: {}
+            } for op_name in new_cfg.op_list_to_mine]
+        # update work dir
+        new_cfg.work_dir = os.path.join(new_cfg.work_dir, 'insight_mining',
+                                        current_state)
+        new_cfg.export_path = os.path.join(new_cfg.work_dir,
+                                           f'{current_state}.jsonl')
+        # close insight mining and monitor for inner analysis
+        new_cfg.open_insight_mining = False
+        new_cfg.open_monitor = False
+
+        # init the analyzer
+        from data_juicer.core.analyzer import Analyzer
+        analyzer = Analyzer(new_cfg)
+
+        # remove existing stats and meta in dataset
+        target_fields = {Fields.stats, Fields.meta}
+        target_fields = target_fields.intersection(set(dataset.features))
+        if len(target_fields) > 0:
+            dataset = dataset.remove_columns(list(target_fields))
+        analyzer.run(dataset, skip_return=True)
+
+    def insight_mining(self, pval_th=0.05):
+        """
+        Mine insights from the OP-wise analysis results. For now, we use a
+        t-test to check the significance of stats/meta changes before and
+        after each OP. If the p-value is less than a given threshold
+        (usually 0.05), we consider the stats/meta changes significant.
+        The insight mining results will be stored in the file
+        `{work_dir}/insight_mining/insight_mining.json`.
+
+        :param pval_th: the threshold of p-value.
+        """
+        work_dir = os.path.join(self.cfg.work_dir, 'insight_mining')
+        res_order = [
+            d for d in os.listdir(work_dir)
+            if os.path.isdir(os.path.join(work_dir, d))
+        ]
+        # sort numerically by the OP index prefix (e.g. '10_xx' after '2_xx')
+        res_order.sort(key=lambda d: int(d.split('_')[0]))
+
+        # collect analysis results
+        analysis_results = {}
+        for res_dir in res_order:
+            res = Dataset.from_json(
+                os.path.join(work_dir, res_dir,
+                             f'{res_dir}_stats.jsonl')).flatten()
+            analysis_results[res_dir] = res
+
+        # distribution change significance analysis
+        ttest_measure = RelatedTTestMeasure()
+
+        sig_res = {}
+        # i = 0 is the original dataset
+        for i in range(1, len(res_order)):
+            prev_res = analysis_results[res_order[i - 1]]
+            curr_res = analysis_results[res_order[i]]
+
+            # only consider common stats and meta
+            common_features = list(
+                set(prev_res.features).intersection(set(curr_res.features)))
+            curr_sig_res = {}
+            for feat in common_features:
+                ttest_res = ttest_measure(prev_res[feat], curr_res[feat])
+                curr_sig_res[feat] = {
+                    't-statistic (standardized mean difference)':
+                    ttest_res.statistic,
+                    'p-value': ttest_res.pvalue,
+                    'significant':
+                    True if ttest_res.pvalue < pval_th else False,
+                }
+            sig_res[res_order[i]] = curr_sig_res
+
+        with open(os.path.join(work_dir, 'insight_mining.json'), 'w') as out:
+            json.dump(sig_res, out)
diff --git a/data_juicer/core/analyzer.py b/data_juicer/core/analyzer.py
index 63e512d41..d9ac586e9 100644
--- a/data_juicer/core/analyzer.py
+++ b/data_juicer/core/analyzer.py
@@ -1,6 +1,7 @@
 import os
-from typing import Optional
+from typing import Optional, Union
 
+from datasets import Dataset
 from jsonargparse import Namespace
 from loguru import logger
 from pydantic import PositiveInt
@@ -8,11 +9,12 @@
 from data_juicer.analysis import ColumnWiseAnalysis, OverallAnalysis
 from data_juicer.config import init_configs
 from data_juicer.format import load_formatter
-from data_juicer.ops import Filter, load_ops
+from data_juicer.ops import NON_STATS_FILTERS, TAGGING_OPS, Filter, load_ops
 from data_juicer.ops.op_fusion import fuse_operators
 from data_juicer.utils import cache_utils
 
 from .adapter import Adapter
+from .data import NestedDataset
 from .exporter import Exporter
 
 
@@ -71,22 +73,27 @@ def __init__(self, cfg: Optional[Namespace] = None):
         self.analysis_path = os.path.join(self.cfg.work_dir, 'analysis')
 
     def run(self,
+            dataset: Union[Dataset, NestedDataset] = None,
             load_data_np: Optional[PositiveInt] = None,
             skip_export: bool = False,
             skip_return: bool = False):
         """
         Running the dataset analysis pipeline.
 
+        :param dataset: a Dataset object to be analyzed.
         :param load_data_np: number of workers when loading the dataset.
         :param skip_export: whether export the results into disk
         :param skip_return: skip return for API called.
         :return: analyzed dataset.
         """
        # 1. 
format data - logger.info('Loading dataset from data formatter...') if load_data_np is None: load_data_np = self.cfg.np - dataset = self.formatter.load_dataset(load_data_np, self.cfg) + if dataset is None: + logger.info('Loading dataset from data formatter...') + dataset = self.formatter.load_dataset(load_data_np, self.cfg) + else: + logger.info(f'Using existing dataset {dataset}') if self.cfg.auto: # if it's auto analysis, only analyze for a minor part of the input # dataset to save time and computing resource @@ -111,16 +118,26 @@ def run(self, logger.info('Computing the stats of dataset...') stats_collected = False for op in ops: - if isinstance(op, Filter): + if isinstance(op, Filter) \ + and op._name not in NON_STATS_FILTERS.modules: original_process = op.process op.process = None - dataset = dataset.process(op, work_dir=self.work_dir) + dataset = dataset.process(op, + work_dir=self.work_dir, + open_monitor=self.cfg.open_monitor) op.process = original_process stats_collected = True + elif op._name in TAGGING_OPS.modules: + dataset = dataset.process(op, + work_dir=self.work_dir, + open_monitor=self.cfg.open_monitor) + stats_collected = True if not stats_collected: - logger.warning('No stats collected. Please add some Filter ops to ' - 'the process list in configs.') - return dataset + logger.warning( + 'No stats/meta collected. Please add some Filter OPs or ' + 'Tagging OPs to the process list in configs.') + if not skip_return: + return dataset # 3. data export logger.info('Exporting dataset to disk...') diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py index 361f6e8a0..d0f8083e1 100644 --- a/data_juicer/core/data.py +++ b/data_juicer/core/data.py @@ -172,6 +172,7 @@ def process( exporter=None, checkpointer=None, tracer=None, + adapter=None, open_monitor=True, ): if operators is None: @@ -185,9 +186,19 @@ def process( if open_monitor: resource_util_list = [] + # whether to enable insight mining + enable_insight_mining = adapter.enable_insight_mining \ + if adapter else False + # record the analysis results of the original dataset + if enable_insight_mining: + logger.info('Analyze small batch for the original dataset for ' + 'insight mining...') + adapter.analyze_small_batch(self, '0_original') + dataset = self + op_num = len(operators) try: - for op in operators: + for idx, op in enumerate(operators, start=1): mp_context = ['forkserver', 'spawn'] if ( op.use_cuda() or op._name in unforkable_operators) else None @@ -211,8 +222,16 @@ def process( if open_monitor: resource_util_list.append(resource_util_per_op) end = time() - logger.info(f'OP [{op._name}] Done in {end - start:.3f}s. ' - f'Left {len(dataset)} samples.') + logger.info( + f'[{idx}/{op_num}] OP [{op._name}] Done in ' + f'{end - start:.3f}s. 
Left {len(dataset)} samples.')
+
+                # record the analysis results of the current dataset
+                if enable_insight_mining:
+                    logger.info(
+                        f'Analyze small batch for the current dataset after '
+                        f'OP [{op._name}] for insight mining...')
+                    adapter.analyze_small_batch(dataset, f'{idx}_{op._name}')
         except:  # noqa: E722
             logger.error(f'An error occurred during Op [{op._name}].')
             traceback.print_exc()
@@ -223,6 +242,7 @@ def process(
                           'last op...')
                 dataset.cleanup_cache_files()
                 checkpointer.save_ckpt(dataset)
+        # make summarization on the monitor results
         if work_dir and open_monitor:
             # get the analyzed version
             resource_util_list = Monitor.analyze_resource_util_list(
@@ -234,6 +254,10 @@ def process(
                 json.dump(resource_util_list, out)
             Monitor.draw_resource_util_graph(resource_util_list,
                                              monitor_dir)
+        # make summarization on the insight mining results
+        if work_dir and enable_insight_mining:
+            logger.info('Insight mining for each OP...')
+            adapter.insight_mining()
         return dataset
 
     def update_args(self, args, kargs, is_filter=False):
diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py
index f78059247..7f0d93a66 100644
--- a/data_juicer/core/executor.py
+++ b/data_juicer/core/executor.py
@@ -199,6 +199,7 @@ def run(self,
                 exporter=self.exporter,
                 checkpointer=self.ckpt_manager,
                 tracer=self.tracer,
+                adapter=self.adapter,
                 open_monitor=self.cfg.open_monitor,
             )
             tend = time()
diff --git a/data_juicer/core/exporter.py b/data_juicer/core/exporter.py
index 72b555d34..dbdb4fb9f 100644
--- a/data_juicer/core/exporter.py
+++ b/data_juicer/core/exporter.py
@@ -106,10 +106,15 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True):
         :param export_stats: whether to export stats of dataset.
         :return:
         """
-        if Fields.stats in dataset.features and export_stats:
+        if export_stats:
             # export stats of datasets into a single file.
             logger.info('Exporting computed stats into a single file...')
-            ds_stats = dataset.select_columns(Fields.stats)
+            export_columns = []
+            if Fields.stats in dataset.features:
+                export_columns.append(Fields.stats)
+            if Fields.meta in dataset.features:
+                export_columns.append(Fields.meta)
+            ds_stats = dataset.select_columns(export_columns)
             stats_file = export_path.replace('.' + suffix, '_stats.jsonl')
             Exporter.to_jsonl(
                 ds_stats,
@@ -119,7 +124,7 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True):
         if self.export_ds:
             # fetch the corresponding export method according to the suffix
             if not self.keep_stats_in_res_ds:
-                extra_fields = {Fields.stats}
+                extra_fields = {Fields.stats, Fields.meta}
                 feature_fields = set(dataset.features.keys())
                 removed_fields = extra_fields.intersection(feature_fields)
                 dataset = dataset.remove_columns(removed_fields)
diff --git a/data_juicer/core/monitor.py b/data_juicer/core/monitor.py
index 0210e3732..d5fdee241 100644
--- a/data_juicer/core/monitor.py
+++ b/data_juicer/core/monitor.py
@@ -15,7 +15,13 @@ def resource_monitor(mdict, interval):
     while True:
         this_states.append(Monitor.monitor_current_resources())
         time.sleep(interval)
-        if mdict['stop']:
+        try:
+            stop_sign = mdict['stop']
+        except (BrokenPipeError, FileNotFoundError):
+            # accessing mdict crashes because the main process has already
+            # terminated, which is not an error here, so just exit quietly
+            return
+        if stop_sign:
             break
     mdict['resource'] = this_states
 
diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py
index e02e10efa..2ab622266 100644
--- a/data_juicer/ops/__init__.py
+++ b/data_juicer/ops/__init__.py
@@ -1,6 +1,7 @@
-from . 
import aggregator, deduplicator, filter, grouper, mapper, selector -from .base_op import (OPERATORS, UNFORKABLE, Aggregator, Deduplicator, Filter, - Grouper, Mapper, Selector) +from .base_op import (NON_STATS_FILTERS, OPERATORS, TAGGING_OPS, UNFORKABLE, + Aggregator, Deduplicator, Filter, Grouper, Mapper, + Selector) from .load import load_ops __all__ = [ diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 2091a867e..39e23d8e9 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -14,6 +14,8 @@ OPERATORS = Registry('Operators') UNFORKABLE = Registry('Unforkable') +NON_STATS_FILTERS = Registry('Non-stats Filters') +TAGGING_OPS = Registry('Tagging Operators') def convert_list_dict_to_dict_list(samples): @@ -223,6 +225,18 @@ def run(self, dataset): from data_juicer.core.data import NestedDataset if not isinstance(dataset, NestedDataset): dataset = NestedDataset(dataset) + # add meta field for OPs that produce tags + if self._name in TAGGING_OPS.modules \ + and Fields.meta not in dataset.features: + from data_juicer.core.data import add_same_content_to_new_column + dataset = dataset.map(add_same_content_to_new_column, + fn_kwargs={ + 'new_column_name': Fields.meta, + 'initial_value': {} + }, + num_proc=self.runtime_np(), + batch_size=self.batch_size, + desc='Adding new column for meta') if self.index_key is not None: def add_index(sample, idx): @@ -404,7 +418,9 @@ def process_single(self, sample): def run(self, dataset, *, exporter=None, tracer=None, reduce=True): dataset = super(Filter, self).run(dataset) - if Fields.stats not in dataset.features: + # add stats field for Filters that produce stats + if self._name not in NON_STATS_FILTERS.modules \ + and Fields.stats not in dataset.features: from data_juicer.core.data import add_same_content_to_new_column dataset = dataset.map(add_same_content_to_new_column, fn_kwargs={ diff --git a/data_juicer/ops/filter/specified_field_filter.py b/data_juicer/ops/filter/specified_field_filter.py index 86aff2426..41addf8da 100644 --- a/data_juicer/ops/filter/specified_field_filter.py +++ b/data_juicer/ops/filter/specified_field_filter.py @@ -1,9 +1,12 @@ from typing import List -from ..base_op import OPERATORS, Filter +from ..base_op import NON_STATS_FILTERS, OPERATORS, Filter +OP_NAME = 'specified_field_filter' -@OPERATORS.register_module('specified_field_filter') + +@NON_STATS_FILTERS.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) class SpecifiedFieldFilter(Filter): """ Filter based on specified field information. diff --git a/data_juicer/ops/filter/specified_numeric_field_filter.py b/data_juicer/ops/filter/specified_numeric_field_filter.py index 693be3392..c7a1d301a 100644 --- a/data_juicer/ops/filter/specified_numeric_field_filter.py +++ b/data_juicer/ops/filter/specified_numeric_field_filter.py @@ -1,6 +1,6 @@ import sys -from ..base_op import OPERATORS, Filter +from ..base_op import NON_STATS_FILTERS, OPERATORS, Filter def is_number(s): @@ -13,7 +13,11 @@ def is_number(s): return False -@OPERATORS.register_module('specified_numeric_field_filter') +OP_NAME = 'specified_numeric_field_filter' + + +@NON_STATS_FILTERS.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) class SpecifiedNumericFieldFilter(Filter): """ Filter based on specified numeric field information. 
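Taken together, the registries above turn the "should this OP be analyzed?" decision into plain dict lookups on `Registry.modules`. A minimal sketch of the rule that `Analyzer.run` applies (the `should_analyze` helper below is illustrative and not part of this patch):

```python
from data_juicer.ops import NON_STATS_FILTERS, TAGGING_OPS


def should_analyze(op_name: str, is_filter: bool) -> bool:
    """Illustrative helper mirroring the branching in Analyzer.run."""
    if is_filter and op_name not in NON_STATS_FILTERS.modules:
        return True  # analyzed via the stats it produces
    return op_name in TAGGING_OPS.modules  # analyzed via its tags in meta


# After this patch:
#   should_analyze('suffix_filter', True)                    -> False
#   should_analyze('image_tagging_mapper', False)            -> True
#   should_analyze('video_tagging_from_frames_filter', True) -> True (via meta)
```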
diff --git a/data_juicer/ops/filter/suffix_filter.py b/data_juicer/ops/filter/suffix_filter.py index ea7868399..7aaca53a7 100644 --- a/data_juicer/ops/filter/suffix_filter.py +++ b/data_juicer/ops/filter/suffix_filter.py @@ -2,10 +2,13 @@ from data_juicer.utils.constant import Fields -from ..base_op import OPERATORS, Filter +from ..base_op import NON_STATS_FILTERS, OPERATORS, Filter +OP_NAME = 'suffix_filter' -@OPERATORS.register_module('suffix_filter') + +@NON_STATS_FILTERS.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) class SuffixFilter(Filter): """Filter to keep samples with specified suffix.""" diff --git a/data_juicer/ops/filter/video_tagging_from_frames_filter.py b/data_juicer/ops/filter/video_tagging_from_frames_filter.py index 8872aab32..2436d886c 100644 --- a/data_juicer/ops/filter/video_tagging_from_frames_filter.py +++ b/data_juicer/ops/filter/video_tagging_from_frames_filter.py @@ -5,7 +5,8 @@ from data_juicer.utils.constant import Fields -from ..base_op import OPERATORS, UNFORKABLE, Filter +from ..base_op import (NON_STATS_FILTERS, OPERATORS, TAGGING_OPS, UNFORKABLE, + Filter) from ..mapper.video_tagging_from_frames_mapper import \ VideoTaggingFromFramesMapper from ..op_fusion import LOADED_VIDEOS @@ -13,6 +14,8 @@ OP_NAME = 'video_tagging_from_frames_filter' +@NON_STATS_FILTERS.register_module(OP_NAME) +@TAGGING_OPS.register_module(OP_NAME) @UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) @LOADED_VIDEOS.register_module(OP_NAME) @@ -91,7 +94,7 @@ def compute_stats_single(self, sample, rank=None, context=False): return sample def process_single(self, sample, rank=None): - video_tags = sample[self.tag_field_name] + video_tags = sample[Fields.meta][self.tag_field_name] if len(video_tags) <= 0: return True diff --git a/data_juicer/ops/mapper/image_tagging_mapper.py b/data_juicer/ops/mapper/image_tagging_mapper.py index e3fc46f1b..dc2099b78 100644 --- a/data_juicer/ops/mapper/image_tagging_mapper.py +++ b/data_juicer/ops/mapper/image_tagging_mapper.py @@ -7,7 +7,7 @@ from data_juicer.utils.mm_utils import load_data_with_context, load_image from data_juicer.utils.model_utils import get_model, prepare_model -from ..base_op import OPERATORS, UNFORKABLE, Mapper +from ..base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper from ..op_fusion import LOADED_IMAGES torch = LazyLoader('torch', 'torch') @@ -16,6 +16,7 @@ OP_NAME = 'image_tagging_mapper' +@TAGGING_OPS.register_module(OP_NAME) @UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) @LOADED_IMAGES.register_module(OP_NAME) @@ -47,12 +48,13 @@ def __init__(self, def process_single(self, sample, rank=None, context=False): # check if it's generated already - if self.tag_field_name in sample: + if self.tag_field_name in sample[Fields.meta]: return sample # there is no image in this sample if self.image_key not in sample or not sample[self.image_key]: - sample[self.tag_field_name] = np.array([[]], dtype=np.str_) + sample[Fields.meta][self.tag_field_name] = np.array([[]], + dtype=np.str_) return sample # load images @@ -75,5 +77,5 @@ def process_single(self, sample, rank=None, context=False): sorted_word_list = [item for item, _ in word_count.most_common()] image_tags.append(np.array(sorted_word_list, dtype=np.str_)) - sample[self.tag_field_name] = image_tags + sample[Fields.meta][self.tag_field_name] = image_tags return sample diff --git a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py index 
2c32093a5..7302953f2 100644 --- a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py @@ -6,13 +6,14 @@ from data_juicer.utils.mm_utils import extract_audio_from_video from data_juicer.utils.model_utils import get_model, prepare_model -from ..base_op import OPERATORS, Mapper +from ..base_op import OPERATORS, TAGGING_OPS, Mapper torch = LazyLoader('torch', 'torch') OP_NAME = 'video_tagging_from_audio_mapper' +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class VideoTaggingFromAudioMapper(Mapper): """Mapper to generate video tags from audio streams extracted by video @@ -50,12 +51,13 @@ def __init__(self, def process_single(self, sample, rank=None): # check if it's generated already - if self.tag_field_name in sample: + if self.tag_field_name in sample[Fields.meta]: return sample # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: - sample[self.tag_field_name] = np.array([], dtype=np.str_) + sample[Fields.meta][self.tag_field_name] = np.array([], + dtype=np.str_) return sample # load video paths @@ -90,5 +92,6 @@ def process_single(self, sample, rank=None): predicted_tag_id = torch.argmax(logits, dim=-1).item() predicted_tag = model.config.id2label[predicted_tag_id] video_audio_tags.append(predicted_tag) - sample[self.tag_field_name] = np.array(video_audio_tags, dtype=np.str_) + sample[Fields.meta][self.tag_field_name] = np.array(video_audio_tags, + dtype=np.str_) return sample diff --git a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py index d4995d3f6..31927e1b2 100644 --- a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py @@ -10,7 +10,7 @@ load_data_with_context, load_video) from data_juicer.utils.model_utils import get_model, prepare_model -from ..base_op import OPERATORS, UNFORKABLE, Mapper +from ..base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper from ..op_fusion import LOADED_VIDEOS ram = LazyLoader('ram', 'ram') @@ -19,6 +19,7 @@ OP_NAME = 'video_tagging_from_frames_mapper' +@TAGGING_OPS.register_module(OP_NAME) @UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) @LOADED_VIDEOS.register_module(OP_NAME) @@ -73,12 +74,13 @@ def __init__(self, def process_single(self, sample, rank=None, context=False): # check if it's generated already - if self.tag_field_name in sample: + if self.tag_field_name in sample[Fields.meta]: return sample # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: - sample[self.tag_field_name] = np.array([[]], dtype=np.str_) + sample[Fields.meta][self.tag_field_name] = np.array([[]], + dtype=np.str_) return sample # load videos @@ -115,5 +117,5 @@ def process_single(self, sample, rank=None, context=False): for vid_key in videos: close_video(videos[vid_key]) - sample[self.tag_field_name] = video_tags + sample[Fields.meta][self.tag_field_name] = video_tags return sample diff --git a/data_juicer/utils/cache_utils.py b/data_juicer/utils/cache_utils.py index 7d815db2c..51138d7ed 100644 --- a/data_juicer/utils/cache_utils.py +++ b/data_juicer/utils/cache_utils.py @@ -1,4 +1,7 @@ import os +from functools import wraps + +from datasets import disable_caching, enable_caching, is_caching_enabled # Default cache location DEFAULT_CACHE_HOME = '~/.cache' @@ -21,3 +24,47 @@ DEFAULT_DATA_JUICER_MODELS_CACHE) CACHE_COMPRESS 
= None
+
+
+class DatasetCacheControl:
+    """Define a scope that changes the dataset cache state temporarily."""
+
+    def __init__(self, on: bool = False):
+        self.on = on
+
+    def __enter__(self):
+        """
+        Record the original cache state and turn it to the target state.
+        """
+        self.previous_state = is_caching_enabled()
+        if self.on:
+            enable_caching()
+        else:
+            disable_caching()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """
+        Restore the original cache state.
+        """
+        if self.previous_state:
+            enable_caching()
+        else:
+            disable_caching()
+
+
+def dataset_cache_control(on):
+    """
+    An easy-to-use decorator for functions that need to control the dataset
+    cache state temporarily.
+    """
+
+    def dataset_cache_decorator(func):
+
+        @wraps(func)
+        def wrapped_function(*args, **kwargs):
+            with DatasetCacheControl(on=on):
+                return func(*args, **kwargs)
+
+        return wrapped_function
+
+    return dataset_cache_decorator
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index 922d44c8b..30686693e 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -16,13 +16,17 @@ class Fields(object):
     context = DEFAULT_PREFIX + 'context__'
     suffix = DEFAULT_PREFIX + 'suffix__'
 
-    video_frames = DEFAULT_PREFIX + 'video_frames__'
+    # tags in meta
     # video_frame_tags
     video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__'
+    # video_audio_tags
     video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__'
     # image_tags
     image_tags = DEFAULT_PREFIX + 'image_tags__'
 
+    # video_frames
+    video_frames = DEFAULT_PREFIX + 'video_frames__'
+
     # the name of the original file from which this sample was derived.
     source_file = DEFAULT_PREFIX + 'source_file__'
 
diff --git a/tests/ops/filter/test_video_tagging_from_frames_filter.py b/tests/ops/filter/test_video_tagging_from_frames_filter.py
index bc4f67fb4..4018136ec 100644
--- a/tests/ops/filter/test_video_tagging_from_frames_filter.py
+++ b/tests/ops/filter/test_video_tagging_from_frames_filter.py
@@ -6,6 +6,7 @@
 from data_juicer.ops.filter.video_tagging_from_frames_filter import \
     VideoTaggingFromFramesFilter
 from data_juicer.utils.mm_utils import SpecialTokens
+from data_juicer.utils.constant import Fields
 from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
 
 class VideoTaggingFromFramesFilterTest(DataJuicerTestCaseBase):
@@ -21,8 +22,11 @@ def _run_video_tagging_from_frames_filter(self,
                                               target_list,
                                               num_proc=1):
         dataset = Dataset.from_list(source_list)
-        dataset = dataset.map(op.compute_stats)
-        dataset = dataset.filter(op.process)
+        if Fields.meta not in dataset.features:
+            dataset = dataset.add_column(name=Fields.meta,
+                                         column=[{}] * dataset.num_rows)
+        dataset = dataset.map(op.compute_stats, num_proc=num_proc)
+        dataset = dataset.filter(op.process, num_proc=num_proc)
         dataset = dataset.select_columns(column_names=['text', 'videos'])
         res_list = dataset.to_list()
         self.assertEqual(res_list, target_list)
diff --git a/tests/ops/mapper/test_image_tagging_mapper.py b/tests/ops/mapper/test_image_tagging_mapper.py
index 9ec3e4d22..d2bbddec2 100644
--- a/tests/ops/mapper/test_image_tagging_mapper.py
+++ b/tests/ops/mapper/test_image_tagging_mapper.py
@@ -24,6 +24,9 @@ def _run_image_tagging_mapper(self,
                                   target_list,
                                   num_proc=1):
         dataset = Dataset.from_list(source_list)
+        if Fields.meta not in dataset.features:
+            dataset = dataset.add_column(name=Fields.meta,
+                                         column=[{}] * dataset.num_rows)
         dataset = dataset.map(op.process, num_proc=num_proc, with_rank=True)
         res_list = dataset.to_list()
         self.assertEqual(res_list, target_list)
@@ -38,23 
+41,26 @@ def test(self): }] tgt_list = [{ 'images': [self.img1_path], - Fields.image_tags: [[ - 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', - 'chair', 'pillar', 'comfort', 'side table', 'floor', - 'hardwood floor', 'headboard', 'linen', 'mattress', - 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', - 'stool', 'white', 'window', 'wood floor']], + Fields.meta: { + Fields.image_tags: [[ + 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', + 'chair', 'pillar', 'comfort', 'side table', 'floor', + 'hardwood floor', 'headboard', 'linen', 'mattress', + 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', + 'stool', 'white', 'window', 'wood floor']]}, }, { 'images': [self.img2_path], - Fields.image_tags: [[ - 'advertisement', 'back', 'bus', 'car', 'city bus', - 'city street', 'curb', 'decker bus', 'drive', 'license plate', - 'road', 'street scene', 'tour bus', 'travel', 'white']], + Fields.meta: { + Fields.image_tags: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']]}, }, { 'images': [self.img3_path], - Fields.image_tags: [[ - 'alley', 'black', 'building', 'catch', 'person', 'pavement', - 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']], + Fields.meta: { + Fields.image_tags: [[ + 'alley', 'black', 'building', 'catch', 'person', 'pavement', + 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']]}, }] op = ImageTaggingMapper() self._run_image_tagging_mapper(op, ds_list, tgt_list) @@ -67,13 +73,15 @@ def test_no_images(self): }] tgt_list = [{ 'images': [], - Fields.image_tags: [[]], + Fields.meta: { + Fields.image_tags: [[]]}, }, { 'images': [self.img2_path], - Fields.image_tags: [[ - 'advertisement', 'back', 'bus', 'car', 'city bus', - 'city street', 'curb', 'decker bus', 'drive', 'license plate', - 'road', 'street scene', 'tour bus', 'travel', 'white']], + Fields.meta: { + Fields.image_tags: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']]}, }] op = ImageTaggingMapper() self._run_image_tagging_mapper(op, ds_list, tgt_list) @@ -90,23 +98,26 @@ def test_specified_tag_field_name(self): }] tgt_list = [{ 'images': [self.img1_path], - tag_field_name: [[ - 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', - 'chair', 'pillar', 'comfort', 'side table', 'floor', - 'hardwood floor', 'headboard', 'linen', 'mattress', - 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', - 'stool', 'white', 'window', 'wood floor']], + Fields.meta: { + tag_field_name: [[ + 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', + 'chair', 'pillar', 'comfort', 'side table', 'floor', + 'hardwood floor', 'headboard', 'linen', 'mattress', + 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', + 'stool', 'white', 'window', 'wood floor']]}, }, { 'images': [self.img2_path], - tag_field_name: [[ - 'advertisement', 'back', 'bus', 'car', 'city bus', - 'city street', 'curb', 'decker bus', 'drive', 'license plate', - 'road', 'street scene', 'tour bus', 'travel', 'white']], + Fields.meta: { + tag_field_name: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']]}, }, { 'images': [self.img3_path], - tag_field_name: [[ - 'alley', 'black', 'building', 'catch', 'person', 'pavement', - 
'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']], + Fields.meta: { + tag_field_name: [[ + 'alley', 'black', 'building', 'catch', 'person', 'pavement', + 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']]}, }] op = ImageTaggingMapper(tag_field_name=tag_field_name) self._run_image_tagging_mapper(op, ds_list, tgt_list) @@ -126,23 +137,26 @@ def test_multi_process(self): }] tgt_list = [{ 'images': [self.img1_path], - Fields.image_tags: [[ - 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', - 'chair', 'pillar', 'comfort', 'side table', 'floor', - 'hardwood floor', 'headboard', 'linen', 'mattress', - 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', - 'stool', 'white', 'window', 'wood floor']], + Fields.meta: { + Fields.image_tags: [[ + 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', + 'chair', 'pillar', 'comfort', 'side table', 'floor', + 'hardwood floor', 'headboard', 'linen', 'mattress', + 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', + 'stool', 'white', 'window', 'wood floor']]}, }, { 'images': [self.img2_path], - Fields.image_tags: [[ - 'advertisement', 'back', 'bus', 'car', 'city bus', - 'city street', 'curb', 'decker bus', 'drive', 'license plate', - 'road', 'street scene', 'tour bus', 'travel', 'white']], + Fields.meta: { + Fields.image_tags: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']]}, }, { 'images': [self.img3_path], - Fields.image_tags: [[ - 'alley', 'black', 'building', 'catch', 'person', 'pavement', - 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']], + Fields.meta: { + Fields.image_tags: [[ + 'alley', 'black', 'building', 'catch', 'person', 'pavement', + 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']]}, }] op = ImageTaggingMapper() self._run_image_tagging_mapper(op, diff --git a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py index 8bbf05933..00a376170 100644 --- a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py +++ b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py @@ -31,8 +31,11 @@ def _run_video_tagging_from_audio_mapper(self, tag_field_name=Fields.video_audio_tags, num_proc=1): dataset = Dataset.from_list(source_list) + if Fields.meta not in dataset.features: + dataset = dataset.add_column(name=Fields.meta, + column=[{}] * dataset.num_rows) dataset = dataset.map(op.process, num_proc=num_proc) - res_list = dataset.select_columns([tag_field_name])[tag_field_name] + res_list = dataset.flatten().select_columns([f'{Fields.meta}.{tag_field_name}'])[f'{Fields.meta}.{tag_field_name}'] self.assertEqual(res_list, target_list) def test(self): diff --git a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py index 4484df754..31fc04c3b 100644 --- a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py +++ b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py @@ -25,6 +25,9 @@ def _run_video_tagging_from_frames_mapper(self, target_list, num_proc=1): dataset = Dataset.from_list(source_list) + if Fields.meta not in dataset.features: + dataset = dataset.add_column(name=Fields.meta, + column=[{}] * dataset.num_rows) dataset = dataset.map(op.process, num_proc=num_proc) res_list = dataset.to_list() self.assertEqual(res_list, target_list) @@ -46,30 +49,33 @@ def test(self): 'text': f'{SpecialTokens.video} 
白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], - Fields.video_frame_tags: [[ - 'animal', 'ray', 'text', 'writing', 'yellow', 'game', - 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', - 'sky' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ]]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], - Fields.video_frame_tags: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]]} }, { 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], - Fields.video_frame_tags: [[ - 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', - 'conversation', 'round table', 'closet', 'computer', 'girl', - 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', - 'selfie', 'stand' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]]} }] op = VideoTaggingFromFramesMapper() self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) @@ -87,16 +93,18 @@ def test_no_video(self): 'text': f'白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [], - Fields.video_frame_tags: [[]] + Fields.meta: { + Fields.video_frame_tags: [[]]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], - Fields.video_frame_tags: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]]} }] op = VideoTaggingFromFramesMapper() self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) @@ -120,30 +128,33 @@ def test_specified_tag_field_name(self): 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], - tag_field_name: [[ - 'animal', 'ray', 'text', 'writing', 'yellow', 'game', - 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', - 'sky' - ]] + Fields.meta: { + tag_field_name: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ]]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], - tag_field_name: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ]] + Fields.meta: { + tag_field_name: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]]} }, { 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], - tag_field_name: [[ - 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', - 
'conversation', 'round table', 'closet', 'computer', 'girl', - 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', - 'selfie', 'stand' - ]] + Fields.meta: { + tag_field_name: [[ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]]} }] op = VideoTaggingFromFramesMapper(tag_field_name=tag_field_name) self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) @@ -165,30 +176,33 @@ def test_uniform(self): 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], - Fields.video_frame_tags: [[ - 'cartoon', 'animal', 'anime', 'game', 'screenshot', - 'video game', 'cartoon character', 'robe', 'ray', 'text', - 'writing', 'yellow', 'doll', 'tail', 'sky', 'person']] + Fields.meta: { + Fields.video_frame_tags: [[ + 'cartoon', 'animal', 'anime', 'game', 'screenshot', + 'video game', 'cartoon character', 'robe', 'ray', 'text', + 'writing', 'yellow', 'doll', 'tail', 'sky', 'person']]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], - Fields.video_frame_tags: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'hand', 'catch', 'bulletin board', 'Wii', 'cotton candy', - 'tennis racket', 'blind', 'game controller', 'remote', 'stand', - 'video game', 'Wii controller', 'play', 'baseball uniform', - 'toy', 'green']] + Fields.meta: { + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'hand', 'catch', 'bulletin board', 'Wii', 'cotton candy', + 'tennis racket', 'blind', 'game controller', 'remote', 'stand', + 'video game', 'Wii controller', 'play', 'baseball uniform', + 'toy', 'green']]} }, { 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], - Fields.video_frame_tags: [[ - 'table', 'sit', 'woman', 'bookshelf', 'conversation', 'person', - 'round table', 'computer', 'girl', 'man', 'closet', 'laptop', - 'stand', 'computer screen', 'talk', 'room', 'stool', 'hand', - 'point' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'table', 'sit', 'woman', 'bookshelf', 'conversation', 'person', + 'round table', 'computer', 'girl', 'man', 'closet', 'laptop', + 'stand', 'computer screen', 'talk', 'room', 'stool', 'hand', + 'point' + ]]} }] op = VideoTaggingFromFramesMapper(frame_sampling_method='uniform', frame_num=10) @@ -216,30 +230,33 @@ def test_multi_process(self): 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], - Fields.video_frame_tags: [[ - 'animal', 'ray', 'text', 'writing', 'yellow', 'game', - 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', - 'sky' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ]]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], - Fields.video_frame_tags: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]]} }, { 'text': f'{SpecialTokens.video} 
两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], - Fields.video_frame_tags: [[ - 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', - 'conversation', 'round table', 'closet', 'computer', 'girl', - 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', - 'selfie', 'stand' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]]} }] op = VideoTaggingFromFramesMapper() self._run_video_tagging_from_frames_mapper(op, @@ -268,44 +285,47 @@ def test_multi_chunk(self): 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', 'videos': [self.vid1_path, self.vid2_path], - Fields.video_frame_tags: - [[ - 'animal', 'ray', 'text', 'writing', 'yellow', 'game', - 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', - 'sky' - ], [ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ]] + Fields.meta: { + Fields.video_frame_tags: + [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ], [ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid2_path, self.vid3_path], - Fields.video_frame_tags: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ], [ - 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', - 'conversation', 'round table', 'closet', 'computer', 'girl', - 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', - 'selfie', 'stand' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ], [ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]]} }, { 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid1_path, self.vid3_path], - Fields.video_frame_tags: [[ - 'animal', 'ray', 'text', 'writing', 'yellow', 'game', - 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', - 'sky' - ], [ - 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', - 'conversation', 'round table', 'closet', 'computer', 'girl', - 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', - 'selfie', 'stand' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ], [ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]]} }] op = 
VideoTaggingFromFramesMapper() self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list)
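
For reviewers who want to exercise the new measure in isolation, here is a minimal usage sketch (assuming `scipy` is installed; the sample values below are made up):

```python
from data_juicer.analysis.measure import RelatedTTestMeasure

measure = RelatedTTestMeasure()

# continuous stats: both lists are histogrammed over common bin edges,
# then scipy.stats.ttest_rel is applied to the paired histograms
before = [120, 98, 342, 77, 205, 150, 88, 190, 260, 134]
after = [118, 95, 300, 70, 201, 149, 85, 180, 255, 130]
res = measure.measure(before, after)
print(res.statistic, res.pvalue)  # flagged as significant if pvalue < 0.05

# discrete tags: (possibly nested) tag lists are flattened into per-category
# counts before the same paired t-test
tags_before = [['cat', 'dog'], ['dog'], ['bird', 'cat']]
tags_after = [['dog'], ['dog', 'bird'], ['bird']]
print(measure.measure(tags_before, tags_after).pvalue)
```

This is the same object that `Adapter.insight_mining` applies feature by feature to build `insight_mining.json`.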