diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index 253b3552d..8d15b79f6 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -1,7 +1,8 @@ from . import (alphanumeric_filter, average_line_length_filter, character_repetition_filter, clip_similarity_filter, - flagged_words_filter, image_aspect_ratio_filter, - image_shape_filter, image_size_filter, language_id_score_filter, + face_area_filter, flagged_words_filter, + image_aspect_ratio_filter, image_shape_filter, + image_size_filter, language_id_score_filter, maximum_line_length_filter, perplexity_filter, special_characters_filter, specified_field_filter, specified_numeric_field_filter, stopwords_filter, suffix_filter, diff --git a/data_juicer/ops/filter/face_area_filter.py b/data_juicer/ops/filter/face_area_filter.py new file mode 100644 index 000000000..60b4eceb0 --- /dev/null +++ b/data_juicer/ops/filter/face_area_filter.py @@ -0,0 +1,146 @@ +import numpy as np +from jsonargparse.typing import ClosedUnitInterval + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import load_image, pil_to_opencv + +from ..base_op import OPERATORS, Filter +from ..op_fusion import LOADED_IMAGES + +OP_NAME = 'face_area_filter' + +with AvailabilityChecking(['dlib'], OP_NAME): + import dlib + + +def calculate_max_face_ratio(image, detections): + image_area = image.width * image.height + max_face_ratio = 0 + for detection in detections: + # 检查是否为有效的检测结果 + if detection != (-1, -1, -1, -1): + _, _, width, height = detection + face_area = width * height + face_ratio = face_area / image_area + max_face_ratio = max(max_face_ratio, face_ratio) + return max_face_ratio + + +@OPERATORS.register_module(OP_NAME) +@LOADED_IMAGES.register_module(OP_NAME) +class FaceAreaFilter(Filter): + """Filter to keep samples with face area ratio within a specific range. + """ + + def __init__(self, + min_ratio: ClosedUnitInterval = 0.0, + max_ratio: ClosedUnitInterval = 0.4, + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param min_ratio: Min ratio for the largest face area in an image. + :param max_ratio: Max ratio for the largest face area in an image. + :param any_or_all: Keep this sample with 'any' or 'all' strategy of + all images. 'any': keep this sample if any images meet the + condition. 'all': keep this sample only if all images meet the + condition. + :param args: Extra positional arguments. + :param kwargs: Extra keyword arguments. + """ + + # Extract face detector arguments from kwargs + detector_keys = ['upsample_num_times'] + self.detector_kwargs = { + key: kwargs.pop(key) + for key in detector_keys if key in kwargs + } + + super().__init__(*args, **kwargs) + self.min_ratio = min_ratio + self.max_ratio = max_ratio + + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. ' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + + # Initialize face detector + self.detector = dlib.get_frontal_face_detector() + + def compute_stats(self, sample, context=False): + # check if it's computed already + if StatsKeys.face_ratios in sample[Fields.stats]: + return sample + + # there is no image in this sample, still default ratio 0.0 + if self.image_key not in sample or not sample[self.image_key]: + sample[Fields.stats][StatsKeys.face_ratios] = [0.0] + return sample + + # load images + loaded_image_keys = sample[self.image_key] + images = {} + for loaded_image_key in loaded_image_keys: + if context and loaded_image_key in sample[Fields.context]: + # load from context + images[loaded_image_key] = sample[ + Fields.context][loaded_image_key] + else: + if loaded_image_key not in images: + # avoid load the same images + image = load_image(loaded_image_key) + images[loaded_image_key] = image + if context: + # store the image data into context + sample[Fields.context][loaded_image_key] = image + + # check if faces detected already + if StatsKeys.face_detections not in sample[Fields.stats]: + face_detections = {} + for key, image in images.items(): + img = pil_to_opencv(image) + dets = self.detector(img, **self.detector_kwargs) + dets_formatted = [[ + det.left(), + det.top(), + det.width(), + det.height() + ] for det in dets] if dets else [[0, 0, 0, 0]] + face_detections[key] = dets_formatted + sample[Fields.stats][StatsKeys.face_detections] = [ + face_detections[key] for key in loaded_image_keys + ] + + max_face_ratios = [] + for key, dets in zip(loaded_image_keys, + sample[Fields.stats][StatsKeys.face_detections]): + img_area = images[key].width * images[key].height + # Calculate the max face ratio for the current image + max_face_ratios.append( + max([w * h / img_area for _, _, w, h in dets])) + sample[Fields.stats][StatsKeys.face_ratios] = max_face_ratios + + return sample + + def process(self, sample): + if self.image_key not in sample or not sample[self.image_key]: + return True + + face_ratios = sample[Fields.stats][StatsKeys.face_ratios] + if len(face_ratios) <= 0: + return True + + keep_bools = np.array([ + self.min_ratio <= face_ratio <= self.max_ratio + for face_ratio in face_ratios + ]) + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index e4ae2b233..cd636c510 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -31,6 +31,8 @@ class StatsKeys(object): image_width = 'image_width' image_height = 'image_height' image_sizes = 'image_sizes' + face_ratios = 'face_ratios' + face_detections = 'face_detections' # multimodal clip_image_text_similarity = 'clip_image_text_similarity' diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py index 817f298bd..766c8ba25 100644 --- a/data_juicer/utils/mm_utils.py +++ b/data_juicer/utils/mm_utils.py @@ -1,3 +1,4 @@ +import numpy as np from datasets import Audio, Image from data_juicer.utils.constant import DEFAULT_PREFIX @@ -34,6 +35,15 @@ def load_audio(path, sampling_rate=None): return (aud['array'], aud['sampling_rate']) +def pil_to_opencv(pil_image): + if pil_image.mode != 'RGB': + pil_image = pil_image.convert('RGB') + numpy_image = np.array(pil_image) + # RGB to BGR + opencv_image = numpy_image[:, :, ::-1] + return opencv_image + + def get_image_size(path, ): import os return os.path.getsize(path) diff --git a/docs/Operators.md b/docs/Operators.md index 7dbdf8ea6..a277f68b0 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | | [ Mapper ]( #mapper ) | 21 | Edits and transforms samples | -| [ Filter ]( #filter ) | 20 | Filters out low-quality samples | +| [ Filter ]( #filter ) | 21 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 4 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 2 | Selects top samples based on ranking | @@ -47,7 +47,7 @@ All the specific operators are listed below, each featured with several capabili | Operator | Domain | Lang | Description | |-----------------------------------------------------|--------------------|--------|----------------------------------------------------------------------------------------------------------------| -| chinese_convert_mapper | General | zh | Convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) | +| chinese_convert_mapper | General | zh | Converts Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) | | clean_copyright_mapper | Code | en, zh | Removes copyright notice at the beginning of code files (:warning: must contain the word *copyright*) | | clean_email_mapper | General | en, zh | Removes email information | | clean_html_mapper | General | en, zh | Removes HTML tags and returns plain text of all the nodes | @@ -55,8 +55,8 @@ All the specific operators are listed below, each featured with several capabili | clean_links_mapper | General, Code | en, zh | Removes links, such as those starting with http or ftp | | expand_macro_mapper | LaTeX | en, zh | Expands macros usually defined at the top of TeX documents | | fix_unicode_mapper | General | en, zh | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | -| nlpaug_en_mapper | General | en | Simply augment texts in English based on the `nlpaug` library | -| nlpcda_zh_mapper | General | zh | Simply augment texts in Chinese based on the `nlpcda` library | +| nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library | +| nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library | | punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents | | remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents | | remove_comments_mapper | LaTeX | en, zh | Removes the comments of TeX documents | @@ -77,11 +77,12 @@ All the specific operators are listed below, each featured with several capabili | alphanumeric_filter | General | en, zh | Keeps samples with alphanumeric ratio within the specified range | | average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range | | character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range | -| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range | +| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range | +| face_area_filter | Image | - | Keeps samples contains images with face area ratios within the specified range | | flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold | -| image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range | -| image_shape_filter | Image | - | Keeps samples contains images with widths and heights within specific ranges | -| image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specific range | +| image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within the specified range | +| image_shape_filter | Image | - | Keeps samples contains images with widths and heights within specified range | +| image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specified range | | language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score | | maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range | | perplexity_filter | General | en, zh | Keeps samples with perplexity score below the specified threshold | @@ -98,12 +99,12 @@ All the specific operators are listed below, each featured with several capabili ## Deduplicator -| Operator | Domain | Lang | Description | -|-------------------------------|---------|--------|-------------------------------------------------------------| -| document_deduplicator | General | en, zh | Deduplicate samples at document-level by comparing MD5 hash | -| document_minhash_deduplicator | General | en, zh | Deduplicate samples at document-level using MinHashLSH | -| document_simhash_deduplicator | General | en, zh | Deduplicate samples at document-level using SimHash | -| image_deduplicator | Image | - | Deduplicate samples at document-level using exact matching of images between documents | +| Operator | Domain | Lang | Description | +|-------------------------------|---------|--------|--------------------------------------------------------------| +| document_deduplicator | General | en, zh | Deduplicates samples at document-level by comparing MD5 hash | +| document_minhash_deduplicator | General | en, zh | Deduplicates samples at document-level using MinHashLSH | +| document_simhash_deduplicator | General | en, zh | Deduplicates samples at document-level using SimHash | +| image_deduplicator | Image | - | Deduplicates samples at document-level using exact matching of images between documents | ## Selector diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index e7c566ef0..0784232d2 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -10,7 +10,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | | [ Mapper ]( #mapper ) | 21 | 对数据样本进行编辑和转换 | -| [ Filter ]( #filter ) | 20 | 过滤低质量样本 | +| [ Filter ]( #filter ) | 21 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 4 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 | @@ -75,6 +75,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | average_line_length_filter | Code | en, zh | 保留平均行长度在指定范围内的样本 | | character_repetition_filter | General | en, zh | 保留 char-level n-gram 重复比率在指定范围内的样本 | | clip_similarity_filter | Multimodal | - | 保留文本图像相似度在指定范围内的样本 | +| face_area_filter | Image | - | 保留样本中包含的图片的最大脸部区域在指定范围内的样本 | | flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 | | image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 | | image_shape_filter | Image | - | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 | diff --git a/environments/science_requires.txt b/environments/science_requires.txt index f13d5b740..cf01ee5e7 100644 --- a/environments/science_requires.txt +++ b/environments/science_requires.txt @@ -12,3 +12,4 @@ transformers opencc==1.1.6 imagededup torch +opencv-python diff --git a/tests/ops/data/lena-face.jpg b/tests/ops/data/lena-face.jpg new file mode 100644 index 000000000..1e4fbaa01 Binary files /dev/null and b/tests/ops/data/lena-face.jpg differ diff --git a/tests/ops/data/lena.jpg b/tests/ops/data/lena.jpg new file mode 100644 index 000000000..f06aa74a5 Binary files /dev/null and b/tests/ops/data/lena.jpg differ diff --git a/tests/ops/filter/test_face_area_filter.py b/tests/ops/filter/test_face_area_filter.py new file mode 100644 index 000000000..67b6b147a --- /dev/null +++ b/tests/ops/filter/test_face_area_filter.py @@ -0,0 +1,148 @@ +import os +import unittest + +from datasets import Dataset +# from data_juicer.core.data import NestedDataset as Dataset + +from data_juicer.ops.filter.face_area_filter import FaceAreaFilter +from data_juicer.utils.constant import Fields + + +class FaceAreaFilterTest(unittest.TestCase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), + '..', 'data') + img1_path = os.path.join(data_path, 'cat.jpg') + img2_path = os.path.join(data_path, 'lena.jpg') + img3_path = os.path.join(data_path, 'lena-face.jpg') + + def _run_face_area_filter(self, + dataset: Dataset, target_list, + op, + num_proc=1): + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=num_proc) + dataset = dataset.filter(op.process, num_proc=num_proc) + dataset = dataset.remove_columns('__dj__stats__') + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_filter_small(self): + + ds_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path] + }, { + 'images': [self.img3_path] + }] + tgt_list = [{ + 'images': [self.img3_path] + }] + dataset = Dataset.from_list(ds_list) + op = FaceAreaFilter(min_ratio=0.4, max_ratio=1.0) + self._run_face_area_filter(dataset, tgt_list, op) + + def test_filter_large(self): + + ds_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path] + }, { + 'images': [self.img3_path] + }] + tgt_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path] + }] + dataset = Dataset.from_list(ds_list) + op = FaceAreaFilter(min_ratio=0.0, max_ratio=0.4) + self._run_face_area_filter(dataset, tgt_list, op) + + def test_filter_multimodal(self): + + ds_list = [{ + 'text': 'a test sentence', 'images': [] + }, { + 'text': 'a test sentence', 'images': [self.img1_path] + }, { + 'text': 'a test sentence', 'images': [self.img2_path] + }, { + 'text': 'a test sentence', 'images': [self.img3_path] + }] + tgt_list = [{ + 'text': 'a test sentence', 'images': [] + }, { + 'text': 'a test sentence', 'images': [self.img1_path] + }, { + 'text': 'a test sentence', 'images': [self.img2_path] + }] + dataset = Dataset.from_list(ds_list) + op = FaceAreaFilter() + self._run_face_area_filter(dataset, tgt_list, op) + + def test_any(self): + + ds_list = [{ + 'images': [self.img1_path, self.img2_path] + }, { + 'images': [self.img2_path, self.img3_path] + }, { + 'images': [self.img1_path, self.img3_path] + }] + tgt_list = [{ + 'images': [self.img1_path, self.img2_path] + }, { + 'images': [self.img2_path, self.img3_path] + }, { + 'images': [self.img1_path, self.img3_path] + }] + dataset = Dataset.from_list(ds_list) + op = FaceAreaFilter(min_ratio=0.0, + max_ratio=0.4, + any_or_all='any') + self._run_face_area_filter(dataset, tgt_list, op) + + def test_all(self): + + ds_list = [{ + 'images': [self.img1_path, self.img2_path] + }, { + 'images': [self.img2_path, self.img3_path] + }, { + 'images': [self.img1_path, self.img3_path] + }] + tgt_list = [{ + 'images': [self.img1_path, self.img2_path] + }] + dataset = Dataset.from_list(ds_list) + op = FaceAreaFilter(min_ratio=0.0, + max_ratio=0.4, + any_or_all='all') + self._run_face_area_filter(dataset, tgt_list, op) + + def test_filter_multi_process(self): + + ds_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path] + }, { + 'images': [self.img3_path] + }] + tgt_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path] + }] + dataset = Dataset.from_list(ds_list) + op = FaceAreaFilter() + self._run_face_area_filter(dataset, tgt_list, op, num_proc=3) + + +if __name__ == '__main__': + unittest.main()