From 103562ff5c79bc244dca66c68c98782d6f2376a6 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Tue, 14 Nov 2023 17:49:55 +0800 Subject: [PATCH] + Add new OP: image_shape_filter --- configs/config_all.yaml | 8 +- data_juicer/ops/filter/__init__.py | 12 +- data_juicer/ops/filter/image_shape_filter.py | 107 +++++++++++++++ data_juicer/utils/constant.py | 2 + demos/overview_scan/app.py | 3 +- docs/Operators.md | 3 +- docs/Operators_ZH.md | 3 +- .../filter/test_image_aspect_ratio_filter.py | 4 +- tests/ops/filter/test_image_shape_filter.py | 127 ++++++++++++++++++ 9 files changed, 256 insertions(+), 13 deletions(-) create mode 100644 data_juicer/ops/filter/image_shape_filter.py create mode 100644 tests/ops/filter/test_image_shape_filter.py diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 708db5c22..d2c108306 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -1,6 +1,6 @@ # Process config example including: # - all global arguments -# - all ops and their default arguments +# - all ops and their arguments # global parameters project_name: 'all' # project name for distinguish your configs @@ -120,6 +120,12 @@ process: min_ratio: 0.333 # the min aspect ratio of filter range max_ratio: 3.0 # the max aspect ratio of filter range any_or_all: any # keep this sample when any/all images meet the filter condition + - image_shape_filter: # filter samples according to the widths and heights of images in them + min_width: 200 # the min width of width filter range + max_width: 5000 # the max width of width filter range + min_height: 200 # the min height of height filter range + max_height: 5000 # the max height of height filter range + any_or_all: any # keep this sample when any/all images meet the filter condition - language_id_score_filter: # filter text in specific language with language scores larger than a specific max value lang: en # keep text in what language min_score: 0.8 # the min language scores to filter text diff --git 
@OPERATORS.register_module('image_shape_filter')
@LOADED_IMAGES.register_module('image_shape_filter')
class ImageShapeFilter(Filter):
    """Filter to keep samples with image shape (w, h) within specific ranges.

    ``compute_stats`` records the width and height of every image referenced
    under ``self.image_key``; ``process`` then decides whether the sample is
    kept, based on the configured bounds and the 'any'/'all' strategy.
    """

    def __init__(self,
                 min_width: PositiveInt = 1,
                 max_width: PositiveInt = sys.maxsize,
                 min_height: PositiveInt = 1,
                 max_height: PositiveInt = sys.maxsize,
                 any_or_all: str = 'any',
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param min_width: The min width to keep samples.
        :param max_width: The max width to keep samples.
        :param min_height: The min height to keep samples.
        :param max_height: The max height to keep samples.
        :param any_or_all: keep this sample with 'any' or 'all' strategy of
            all images. 'any': keep this sample if any images meet the
            condition. 'all': keep this sample only if all images meet the
            condition.
        :param args: extra args
        :param kwargs: extra args
        :raises ValueError: if ``any_or_all`` is neither 'any' nor 'all'.
        """
        super().__init__(*args, **kwargs)
        self.min_width = min_width
        self.max_width = max_width
        self.min_height = min_height
        self.max_height = max_height
        if any_or_all not in ['any', 'all']:
            raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                             f'Can only be one of ["any", "all"].')
        # True -> 'any' strategy, False -> 'all' strategy
        self.any = (any_or_all == 'any')

    def compute_stats(self, sample, context=False):
        """
        Record the width and height of each image of the sample.

        The results are written to ``sample[Fields.stats]`` under
        ``StatsKeys.image_width`` and ``StatsKeys.image_height``, one entry
        per item of ``sample[self.image_key]`` and in the same order.

        :param sample: the sample to compute stats for; it must already
            contain a ``Fields.stats`` mapping.
        :param context: whether to reuse images found in
            ``sample[Fields.context]`` and to store newly loaded ones there.
        :return: the same sample, with both stats filled in.
        """
        stats = sample[Fields.stats]

        # check if it's computed already
        if StatsKeys.image_width in stats and StatsKeys.image_height in stats:
            return sample

        # there is no image in this sample
        if self.image_key not in sample or not sample[self.image_key]:
            # NOTE(review): explicit int64 dtype presumably keeps the empty
            # column typed for downstream storage -- confirm before changing
            stats[StatsKeys.image_width] = np.array([], dtype=np.int64)
            stats[StatsKeys.image_height] = np.array([], dtype=np.int64)
            return sample

        # load images, each distinct key at most once
        loaded_image_keys = sample[self.image_key]
        images = {}
        for loaded_image_key in loaded_image_keys:
            if context and loaded_image_key in sample[Fields.context]:
                # load from context
                images[loaded_image_key] = sample[
                    Fields.context][loaded_image_key]
            elif loaded_image_key not in images:
                # avoid loading the same image twice
                image = load_image(loaded_image_key)
                images[loaded_image_key] = image
                if context:
                    # store the image data into context
                    sample[Fields.context][loaded_image_key] = image

        # get width and height for each image, duplicates included
        stats[StatsKeys.image_width] = [
            images[key].width for key in loaded_image_keys
        ]
        stats[StatsKeys.image_height] = [
            images[key].height for key in loaded_image_keys
        ]
        return sample

    def process(self, sample):
        """
        Decide whether the sample is kept, based on the computed stats.

        :param sample: a sample whose ``Fields.stats`` already holds the
            widths and heights written by ``compute_stats``.
        :return: a plain ``bool``. True if the sample has no recorded image,
            or if any (strategy 'any') / every (strategy 'all') image has
            its width within [min_width, max_width] and its height within
            [min_height, max_height]; False otherwise.
        """
        ws = sample[Fields.stats][StatsKeys.image_width]
        hs = sample[Fields.stats][StatsKeys.image_height]
        # ``len`` rather than truthiness: ``ws`` may be a NumPy array
        if len(ws) == 0:
            return True

        # lazy generator: the built-in any/all short-circuit and return a
        # native bool instead of numpy.bool_
        keep_flags = (self.min_width <= w <= self.max_width
                      and self.min_height <= h <= self.max_height
                      for w, h in zip(ws, hs))

        # different strategies
        return any(keep_flags) if self.any else all(keep_flags)
language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score | | maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range | | perplexity_filter | General | en, zh | Keeps samples with perplexity score below the specified threshold | diff --git a/docs/Operators.md b/docs/Operators.md index 78abeb495..9c4bb373f 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | | [ Mapper ]( #mapper ) | 21 | Edits and transforms samples | -| [ Filter ]( #filter ) | 17 | Filters out low-quality samples | +| [ Filter ]( #filter ) | 18 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 3 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 2 | Selects top samples based on ranking | @@ -77,6 +77,7 @@ All the specific operators are listed below, each featured with several capabili | character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range | | flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold | | image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range | +| image_shape_filter | Image | - | Keeps samples containing images with widths and heights within specific ranges | | language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score | | maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range | | perplexity_filter | General | en, zh | Keeps samples with perplexity score below
the specified threshold | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index cf3421d94..5ca477fdc 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -10,7 +10,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | | [ Mapper ]( #mapper ) | 21 | 对数据样本进行编辑和转换 | -| [ Filter ]( #filter ) | 17 | 过滤低质量样本 | +| [ Filter ]( #filter ) | 18 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 3 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 | @@ -73,6 +73,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | character_repetition_filter | General | en, zh | 保留 char-level n-gram 重复比率在指定范围内的样本 | | flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 | | image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 | +| image_shape_filter | Image | - | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 | | language_id_score_filter | General | en, zh | 保留特定语言的样本,通过预测的置信度得分来判断 | | maximum_line_length_filter | Code | en, zh | 保留最大行长度在指定范围内的样本 | | perplexity_filter | General | en, zh | 保留困惑度低于指定阈值的样本 | diff --git a/tests/ops/filter/test_image_aspect_ratio_filter.py b/tests/ops/filter/test_image_aspect_ratio_filter.py index 3d5ea6cf4..a328d934a 100644 --- a/tests/ops/filter/test_image_aspect_ratio_filter.py +++ b/tests/ops/filter/test_image_aspect_ratio_filter.py @@ -1,9 +1,7 @@ import os import unittest -import numpy as np -import PIL.Image -from datasets import Dataset, Image +from datasets import Dataset from data_juicer.ops.filter.image_aspect_ratio_filter import \ ImageAspectRatioFilter diff --git a/tests/ops/filter/test_image_shape_filter.py b/tests/ops/filter/test_image_shape_filter.py new file mode 100644 index 000000000..a00020a18 --- /dev/null +++ b/tests/ops/filter/test_image_shape_filter.py @@ -0,0 +1,127 @@ +import os +import unittest +import numpy as np +import PIL.Image + +from datasets import Dataset, Image + +from data_juicer.ops.filter.image_shape_filter import 
class ImageShapeFilterTest(unittest.TestCase):
    """Unit tests for ImageShapeFilter on three local fixture images."""

    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             '..', 'data')
    img1_path = os.path.join(data_path, 'img1.png')
    img2_path = os.path.join(data_path, 'img2.jpg')
    img3_path = os.path.join(data_path, 'img3.jpg')

    def _single_image_samples(self):
        """Return three samples, each holding exactly one fixture image."""
        return [
            {'images': [path]}
            for path in (self.img1_path, self.img2_path, self.img3_path)
        ]

    def _paired_image_samples(self):
        """Return three samples, each holding a distinct pair of images."""
        pairs = (
            (self.img1_path, self.img2_path),
            (self.img2_path, self.img3_path),
            (self.img1_path, self.img3_path),
        )
        return [{'images': list(pair)} for pair in pairs]

    def _run_image_shape_filter(self,
                                dataset: Dataset,
                                target_list,
                                op):
        """Run stats computation and filtering, then compare image columns.

        A stats column of empty dicts is added first when the dataset does
        not have one, since the op writes its results into that column.
        """
        if Fields.stats not in dataset.features:
            empty_stats = [{}] * dataset.num_rows
            dataset = dataset.add_column(name=Fields.stats,
                                         column=empty_stats)
        with_stats = dataset.map(op.compute_stats)
        kept = with_stats.filter(op.process)
        images_only = kept.select_columns(column_names=[op.image_key])
        self.assertEqual(images_only.to_list(), target_list)

    def test_filter1(self):
        # lower bounds only: img2 is the single expected survivor
        expected = [{'images': [self.img2_path]}]
        dataset = Dataset.from_list(self._single_image_samples())
        op = ImageShapeFilter(min_width=400, min_height=400)
        self._run_image_shape_filter(dataset, expected, op)

    def test_filter2(self):
        # upper bounds only: img1 and img3 are expected to survive
        expected = [{'images': [self.img1_path]},
                    {'images': [self.img3_path]}]
        dataset = Dataset.from_list(self._single_image_samples())
        op = ImageShapeFilter(max_width=500, max_height=500)
        self._run_image_shape_filter(dataset, expected, op)

    def test_filter3(self):
        # default bounds: every sample is expected to survive
        expected = self._single_image_samples()
        dataset = Dataset.from_list(self._single_image_samples())
        op = ImageShapeFilter()
        self._run_image_shape_filter(dataset, expected, op)

    def test_any(self):
        # 'any' strategy: the two pairs containing img2 are expected
        expected = [{'images': [self.img1_path, self.img2_path]},
                    {'images': [self.img2_path, self.img3_path]}]
        dataset = Dataset.from_list(self._paired_image_samples())
        op = ImageShapeFilter(min_width=400,
                              min_height=400,
                              any_or_all='any')
        self._run_image_shape_filter(dataset, expected, op)

    def test_all(self):
        # 'all' strategy: no pair is expected to survive
        expected = []
        dataset = Dataset.from_list(self._paired_image_samples())
        op = ImageShapeFilter(min_width=400,
                              min_height=400,
                              any_or_all='all')
        self._run_image_shape_filter(dataset, expected, op)


if __name__ == '__main__':
    unittest.main()