diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py
index 253b3552d..8d15b79f6 100644
--- a/data_juicer/ops/filter/__init__.py
+++ b/data_juicer/ops/filter/__init__.py
@@ -1,7 +1,8 @@
from . import (alphanumeric_filter, average_line_length_filter,
character_repetition_filter, clip_similarity_filter,
- flagged_words_filter, image_aspect_ratio_filter,
- image_shape_filter, image_size_filter, language_id_score_filter,
+ face_area_filter, flagged_words_filter,
+ image_aspect_ratio_filter, image_shape_filter,
+ image_size_filter, language_id_score_filter,
maximum_line_length_filter, perplexity_filter,
special_characters_filter, specified_field_filter,
specified_numeric_field_filter, stopwords_filter, suffix_filter,
diff --git a/data_juicer/ops/filter/face_area_filter.py b/data_juicer/ops/filter/face_area_filter.py
new file mode 100644
index 000000000..60b4eceb0
--- /dev/null
+++ b/data_juicer/ops/filter/face_area_filter.py
@@ -0,0 +1,146 @@
+import numpy as np
+from jsonargparse.typing import ClosedUnitInterval
+
+from data_juicer.utils.availability_utils import AvailabilityChecking
+from data_juicer.utils.constant import Fields, StatsKeys
+from data_juicer.utils.mm_utils import load_image, pil_to_opencv
+
+from ..base_op import OPERATORS, Filter
+from ..op_fusion import LOADED_IMAGES
+
+# Name of this operator; it is the key passed to the `register_module`
+# decorators on the filter class and to the dependency check below.
+OP_NAME = 'face_area_filter'
+
+# dlib is an optional third-party dependency, so it is imported inside the
+# AvailabilityChecking context.
+# NOTE(review): presumably AvailabilityChecking turns a failed import into a
+# message naming the operator that needs the package -- confirm in
+# data_juicer.utils.availability_utils.
+with AvailabilityChecking(['dlib'], OP_NAME):
+ import dlib
+
+
+def calculate_max_face_ratio(image, detections):
+    """Return the largest face-area / image-area ratio among detections.
+
+    :param image: object exposing ``width`` and ``height`` attributes.
+    :param detections: iterable of 4-element boxes laid out as
+        ``(left, top, width, height)``. A box equal to
+        ``(-1, -1, -1, -1)`` marks "no valid detection" and is skipped,
+        whether it is given as a tuple or as a list.
+    :return: the maximum ratio found, or 0 when there is no valid
+        detection or the image has no area.
+    """
+    image_area = image.width * image.height
+    # A degenerate image cannot contain a face; avoid dividing by zero.
+    if image_area <= 0:
+        return 0
+
+    invalid_detection = (-1, -1, -1, -1)
+    max_face_ratio = 0
+    for detection in detections:
+        # Normalize to a tuple first: a list never compares equal to a
+        # tuple, so list-shaped sentinels would otherwise be counted as
+        # faces.
+        if tuple(detection) == invalid_detection:
+            continue
+        _, _, width, height = detection
+        max_face_ratio = max(max_face_ratio, width * height / image_area)
+    return max_face_ratio
+
+
+@OPERATORS.register_module(OP_NAME)
+@LOADED_IMAGES.register_module(OP_NAME)
+class FaceAreaFilter(Filter):
+    """Filter to keep samples with face area ratio within a specific range.
+
+    For each image of a sample, the ratio is the area of the largest
+    detected face box divided by the area of the whole image.
+    """
+
+    def __init__(self,
+                 min_ratio: ClosedUnitInterval = 0.0,
+                 max_ratio: ClosedUnitInterval = 0.4,
+                 any_or_all: str = 'any',
+                 *args,
+                 **kwargs):
+        """
+        Initialization method.
+
+        :param min_ratio: Min ratio for the largest face area in an image.
+        :param max_ratio: Max ratio for the largest face area in an image.
+        :param any_or_all: Keep this sample with 'any' or 'all' strategy of
+            all images. 'any': keep this sample if any images meet the
+            condition. 'all': keep this sample only if all images meet the
+            condition.
+        :param args: Extra positional arguments.
+        :param kwargs: Extra keyword arguments. The key
+            'upsample_num_times', if present, is removed from kwargs and
+            forwarded to the face detector on every call instead.
+        :raises ValueError: if any_or_all is neither 'any' nor 'all'.
+        """
+
+        # Extract face detector arguments from kwargs so they are not
+        # forwarded to the base class constructor.
+        detector_keys = ['upsample_num_times']
+        self.detector_kwargs = {
+            key: kwargs.pop(key)
+            for key in detector_keys if key in kwargs
+        }
+
+        super().__init__(*args, **kwargs)
+        self.min_ratio = min_ratio
+        self.max_ratio = max_ratio
+
+        if any_or_all not in ['any', 'all']:
+            raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
+                             f'Can only be one of ["any", "all"].')
+        self.any = (any_or_all == 'any')
+
+        # Initialize face detector
+        self.detector = dlib.get_frontal_face_detector()
+
+    def compute_stats(self, sample, context=False):
+        """Compute the largest-face area ratio of every image in a sample.
+
+        Writes ``StatsKeys.face_detections`` (one list of
+        ``[left, top, width, height]`` boxes per image, unless already
+        present) and ``StatsKeys.face_ratios`` (one float per image) into
+        ``sample[Fields.stats]``.
+
+        :param sample: the sample to analyse; must contain Fields.stats.
+        :param context: whether to read/store loaded images from/into
+            ``sample[Fields.context]``.
+        :return: the same sample with the stats filled in.
+        """
+        # check if it's computed already
+        if StatsKeys.face_ratios in sample[Fields.stats]:
+            return sample
+
+        # there is no image in this sample, still default ratio 0.0
+        if self.image_key not in sample or not sample[self.image_key]:
+            sample[Fields.stats][StatsKeys.face_ratios] = [0.0]
+            return sample
+
+        # load images, each distinct key only once
+        loaded_image_keys = sample[self.image_key]
+        images = {}
+        for loaded_image_key in loaded_image_keys:
+            if loaded_image_key in images:
+                continue
+            if context and loaded_image_key in sample[Fields.context]:
+                # load from context
+                images[loaded_image_key] = sample[
+                    Fields.context][loaded_image_key]
+            else:
+                image = load_image(loaded_image_key)
+                images[loaded_image_key] = image
+                if context:
+                    # store the image data into context
+                    sample[Fields.context][loaded_image_key] = image
+
+        # check if faces detected already
+        if StatsKeys.face_detections not in sample[Fields.stats]:
+            face_detections = {}
+            for key, image in images.items():
+                img = pil_to_opencv(image)
+                dets = self.detector(img, **self.detector_kwargs)
+                # An image without faces gets one all-zero box, which
+                # yields a ratio of 0.0 below.
+                face_detections[key] = [[
+                    det.left(),
+                    det.top(),
+                    det.width(),
+                    det.height()
+                ] for det in dets] if dets else [[0, 0, 0, 0]]
+            sample[Fields.stats][StatsKeys.face_detections] = [
+                face_detections[key] for key in loaded_image_keys
+            ]
+
+        max_face_ratios = []
+        for key, dets in zip(loaded_image_keys,
+                             sample[Fields.stats][StatsKeys.face_detections]):
+            img_area = images[key].width * images[key].height
+            if img_area <= 0:
+                # degenerate image: no area, hence no face ratio
+                max_face_ratios.append(0.0)
+                continue
+            # Max face ratio for the current image. `default` covers
+            # pre-computed detections that hold an empty list, for which
+            # a bare max() would raise ValueError.
+            max_face_ratios.append(
+                max((w * h / img_area for _, _, w, h in dets), default=0.0))
+        sample[Fields.stats][StatsKeys.face_ratios] = max_face_ratios
+
+        return sample
+
+    def process(self, sample):
+        """Decide whether to keep a sample based on its face ratios.
+
+        :param sample: a sample whose stats already hold
+            ``StatsKeys.face_ratios``.
+        :return: True if the sample has no images or no ratios; otherwise
+            whether any/all (per the configured strategy) of its ratios
+            lie within ``[min_ratio, max_ratio]``.
+        """
+        if self.image_key not in sample or not sample[self.image_key]:
+            return True
+
+        face_ratios = sample[Fields.stats][StatsKeys.face_ratios]
+        if len(face_ratios) == 0:
+            return True
+
+        keep_bools = np.array([
+            self.min_ratio <= face_ratio <= self.max_ratio
+            for face_ratio in face_ratios
+        ])
+
+        # different strategies
+        if self.any:
+            return keep_bools.any()
+        else:
+            return keep_bools.all()
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index e4ae2b233..cd636c510 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -31,6 +31,8 @@ class StatsKeys(object):
image_width = 'image_width'
image_height = 'image_height'
image_sizes = 'image_sizes'
+ face_ratios = 'face_ratios'
+ face_detections = 'face_detections'
# multimodal
clip_image_text_similarity = 'clip_image_text_similarity'
diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py
index 817f298bd..766c8ba25 100644
--- a/data_juicer/utils/mm_utils.py
+++ b/data_juicer/utils/mm_utils.py
@@ -1,3 +1,4 @@
+import numpy as np
from datasets import Audio, Image
from data_juicer.utils.constant import DEFAULT_PREFIX
@@ -34,6 +35,15 @@ def load_audio(path, sampling_rate=None):
return (aud['array'], aud['sampling_rate'])
+def pil_to_opencv(pil_image):
+    """Convert a PIL image into an OpenCV-style BGR numpy array.
+
+    :param pil_image: a PIL image; it is converted to mode 'RGB' first
+        when it is in any other mode.
+    :return: a C-contiguous array with the channel axis reversed
+        (RGB -> BGR).
+    """
+    if pil_image.mode != 'RGB':
+        pil_image = pil_image.convert('RGB')
+    numpy_image = np.array(pil_image)
+    # RGB to BGR. The reversed slice by itself is only a negative-stride
+    # view; native consumers (e.g. OpenCV in-place drawing functions)
+    # need a C-contiguous buffer, so materialize it.
+    opencv_image = np.ascontiguousarray(numpy_image[:, :, ::-1])
+    return opencv_image
+
+
def get_image_size(path, ):
import os
return os.path.getsize(path)
diff --git a/docs/Operators.md b/docs/Operators.md
index 7dbdf8ea6..a277f68b0 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 21 | Edits and transforms samples |
-| [ Filter ]( #filter ) | 20 | Filters out low-quality samples |
+| [ Filter ]( #filter ) | 21 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 4 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 2 | Selects top samples based on ranking |
@@ -47,7 +47,7 @@ All the specific operators are listed below, each featured with several capabili
| Operator | Domain | Lang | Description |
|-----------------------------------------------------|--------------------|--------|----------------------------------------------------------------------------------------------------------------|
-| chinese_convert_mapper | General | zh | Convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) |
+| chinese_convert_mapper | General | zh | Converts Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) |
| clean_copyright_mapper | Code | en, zh | Removes copyright notice at the beginning of code files (:warning: must contain the word *copyright*) |
| clean_email_mapper | General | en, zh | Removes email information |
| clean_html_mapper | General | en, zh | Removes HTML tags and returns plain text of all the nodes |
@@ -55,8 +55,8 @@ All the specific operators are listed below, each featured with several capabili
| clean_links_mapper | General, Code | en, zh | Removes links, such as those starting with http or ftp |
| expand_macro_mapper | LaTeX | en, zh | Expands macros usually defined at the top of TeX documents |
| fix_unicode_mapper | General | en, zh | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) |
-| nlpaug_en_mapper | General | en | Simply augment texts in English based on the `nlpaug` library |
-| nlpcda_zh_mapper | General | zh | Simply augment texts in Chinese based on the `nlpcda` library |
+| nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library |
+| nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library |
| punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents |
| remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents |
| remove_comments_mapper | LaTeX | en, zh | Removes the comments of TeX documents |
@@ -77,11 +77,12 @@ All the specific operators are listed below, each featured with several capabili
| alphanumeric_filter | General | en, zh | Keeps samples with alphanumeric ratio within the specified range |
| average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range |
| character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range |
-| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
+| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
+| face_area_filter | Image | - | Keeps samples containing images with face area ratios within the specified range |
| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
-| image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range |
-| image_shape_filter | Image | - | Keeps samples contains images with widths and heights within specific ranges |
-| image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specific range |
+| image_aspect_ratio_filter | Image | - | Keeps samples containing images with aspect ratios within the specified range |
+| image_shape_filter | Image | - | Keeps samples containing images with widths and heights within the specified ranges |
+| image_size_filter | Image | - | Keeps samples containing images whose sizes in bytes are within the specified range |
| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
| maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range |
| perplexity_filter | General | en, zh | Keeps samples with perplexity score below the specified threshold |
@@ -98,12 +99,12 @@ All the specific operators are listed below, each featured with several capabili
## Deduplicator
-| Operator | Domain | Lang | Description |
-|-------------------------------|---------|--------|-------------------------------------------------------------|
-| document_deduplicator | General | en, zh | Deduplicate samples at document-level by comparing MD5 hash |
-| document_minhash_deduplicator | General | en, zh | Deduplicate samples at document-level using MinHashLSH |
-| document_simhash_deduplicator | General | en, zh | Deduplicate samples at document-level using SimHash |
-| image_deduplicator | Image | - | Deduplicate samples at document-level using exact matching of images between documents |
+| Operator | Domain | Lang | Description |
+|-------------------------------|---------|--------|--------------------------------------------------------------|
+| document_deduplicator | General | en, zh | Deduplicates samples at document-level by comparing MD5 hash |
+| document_minhash_deduplicator | General | en, zh | Deduplicates samples at document-level using MinHashLSH |
+| document_simhash_deduplicator | General | en, zh | Deduplicates samples at document-level using SimHash |
+| image_deduplicator | Image | - | Deduplicates samples at document-level using exact matching of images between documents |
## Selector
diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md
index e7c566ef0..0784232d2 100644
--- a/docs/Operators_ZH.md
+++ b/docs/Operators_ZH.md
@@ -10,7 +10,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 21 | 对数据样本进行编辑和转换 |
-| [ Filter ]( #filter ) | 20 | 过滤低质量样本 |
+| [ Filter ]( #filter ) | 21 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 4 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 |
@@ -75,6 +75,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| average_line_length_filter | Code | en, zh | 保留平均行长度在指定范围内的样本 |
| character_repetition_filter | General | en, zh | 保留 char-level n-gram 重复比率在指定范围内的样本 |
| clip_similarity_filter | Multimodal | - | 保留文本图像相似度在指定范围内的样本 |
+| face_area_filter | Image | - | 保留样本中包含的图片的最大脸部区域在指定范围内的样本 |
| flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 |
| image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 |
| image_shape_filter | Image | - | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 |
diff --git a/environments/science_requires.txt b/environments/science_requires.txt
index f13d5b740..cf01ee5e7 100644
--- a/environments/science_requires.txt
+++ b/environments/science_requires.txt
@@ -12,3 +12,4 @@ transformers
opencc==1.1.6
imagededup
torch
+opencv-python
diff --git a/tests/ops/data/lena-face.jpg b/tests/ops/data/lena-face.jpg
new file mode 100644
index 000000000..1e4fbaa01
Binary files /dev/null and b/tests/ops/data/lena-face.jpg differ
diff --git a/tests/ops/data/lena.jpg b/tests/ops/data/lena.jpg
new file mode 100644
index 000000000..f06aa74a5
Binary files /dev/null and b/tests/ops/data/lena.jpg differ
diff --git a/tests/ops/filter/test_face_area_filter.py b/tests/ops/filter/test_face_area_filter.py
new file mode 100644
index 000000000..67b6b147a
--- /dev/null
+++ b/tests/ops/filter/test_face_area_filter.py
@@ -0,0 +1,148 @@
+import os
+import unittest
+
+from datasets import Dataset
+# from data_juicer.core.data import NestedDataset as Dataset
+
+from data_juicer.ops.filter.face_area_filter import FaceAreaFilter
+from data_juicer.utils.constant import Fields
+
+
+class FaceAreaFilterTest(unittest.TestCase):
+    """Unit tests for FaceAreaFilter.
+
+    The expectations below treat cat.jpg and lena.jpg as having a
+    largest-face ratio within [0.0, 0.4] and lena-face.jpg as having one
+    within [0.4, 1.0].
+    """
+
+    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                             '..', 'data')
+    img1_path = os.path.join(data_path, 'cat.jpg')
+    img2_path = os.path.join(data_path, 'lena.jpg')
+    img3_path = os.path.join(data_path, 'lena-face.jpg')
+
+    def _run_face_area_filter(self,
+                              dataset: Dataset,
+                              target_list,
+                              op,
+                              num_proc=1):
+        """Run `op` over `dataset` and compare the survivors.
+
+        Adds the stats column when missing, computes stats, filters,
+        drops the stats column again and asserts that the remaining
+        samples equal `target_list`.
+        """
+        if Fields.stats not in dataset.features:
+            dataset = dataset.add_column(name=Fields.stats,
+                                         column=[{}] * dataset.num_rows)
+        dataset = dataset.map(op.compute_stats, num_proc=num_proc)
+        dataset = dataset.filter(op.process, num_proc=num_proc)
+        # Remove the column through the same constant it was added with,
+        # rather than a hard-coded copy of its value.
+        dataset = dataset.remove_columns(Fields.stats)
+        res_list = dataset.to_list()
+        self.assertEqual(res_list, target_list)
+
+    def test_filter_small(self):
+        """Only images with a large enough face ratio are kept."""
+
+        ds_list = [{
+            'images': [self.img1_path]
+        }, {
+            'images': [self.img2_path]
+        }, {
+            'images': [self.img3_path]
+        }]
+        tgt_list = [{
+            'images': [self.img3_path]
+        }]
+        dataset = Dataset.from_list(ds_list)
+        op = FaceAreaFilter(min_ratio=0.4, max_ratio=1.0)
+        self._run_face_area_filter(dataset, tgt_list, op)
+
+    def test_filter_large(self):
+        """Images with a too-large face ratio are dropped."""
+
+        ds_list = [{
+            'images': [self.img1_path]
+        }, {
+            'images': [self.img2_path]
+        }, {
+            'images': [self.img3_path]
+        }]
+        tgt_list = [{
+            'images': [self.img1_path]
+        }, {
+            'images': [self.img2_path]
+        }]
+        dataset = Dataset.from_list(ds_list)
+        op = FaceAreaFilter(min_ratio=0.0, max_ratio=0.4)
+        self._run_face_area_filter(dataset, tgt_list, op)
+
+    def test_filter_multimodal(self):
+        """Text fields pass through; a sample without images is kept."""
+
+        ds_list = [{
+            'text': 'a test sentence', 'images': []
+        }, {
+            'text': 'a test sentence', 'images': [self.img1_path]
+        }, {
+            'text': 'a test sentence', 'images': [self.img2_path]
+        }, {
+            'text': 'a test sentence', 'images': [self.img3_path]
+        }]
+        tgt_list = [{
+            'text': 'a test sentence', 'images': []
+        }, {
+            'text': 'a test sentence', 'images': [self.img1_path]
+        }, {
+            'text': 'a test sentence', 'images': [self.img2_path]
+        }]
+        dataset = Dataset.from_list(ds_list)
+        op = FaceAreaFilter()
+        self._run_face_area_filter(dataset, tgt_list, op)
+
+    def test_any(self):
+        """'any' keeps a sample when at least one image is in range."""
+
+        ds_list = [{
+            'images': [self.img1_path, self.img2_path]
+        }, {
+            'images': [self.img2_path, self.img3_path]
+        }, {
+            'images': [self.img1_path, self.img3_path]
+        }]
+        tgt_list = [{
+            'images': [self.img1_path, self.img2_path]
+        }, {
+            'images': [self.img2_path, self.img3_path]
+        }, {
+            'images': [self.img1_path, self.img3_path]
+        }]
+        dataset = Dataset.from_list(ds_list)
+        op = FaceAreaFilter(min_ratio=0.0,
+                            max_ratio=0.4,
+                            any_or_all='any')
+        self._run_face_area_filter(dataset, tgt_list, op)
+
+    def test_all(self):
+        """'all' keeps a sample only when every image is in range."""
+
+        ds_list = [{
+            'images': [self.img1_path, self.img2_path]
+        }, {
+            'images': [self.img2_path, self.img3_path]
+        }, {
+            'images': [self.img1_path, self.img3_path]
+        }]
+        tgt_list = [{
+            'images': [self.img1_path, self.img2_path]
+        }]
+        dataset = Dataset.from_list(ds_list)
+        op = FaceAreaFilter(min_ratio=0.0,
+                            max_ratio=0.4,
+                            any_or_all='all')
+        self._run_face_area_filter(dataset, tgt_list, op)
+
+    def test_filter_multi_process(self):
+        """Default thresholds give the same result with 3 processes."""
+
+        ds_list = [{
+            'images': [self.img1_path]
+        }, {
+            'images': [self.img2_path]
+        }, {
+            'images': [self.img3_path]
+        }]
+        tgt_list = [{
+            'images': [self.img1_path]
+        }, {
+            'images': [self.img2_path]
+        }]
+        dataset = Dataset.from_list(ds_list)
+        op = FaceAreaFilter()
+        self._run_face_area_filter(dataset, tgt_list, op, num_proc=3)
+
+
+if __name__ == '__main__':
+ unittest.main()