Skip to content

Commit

Permalink
add face_area_filter
Browse files Browse the repository at this point in the history
  • Loading branch information
null authored and null committed Nov 28, 2023
1 parent 754720a commit 61a9f73
Show file tree
Hide file tree
Showing 10 changed files with 339 additions and 17 deletions.
5 changes: 3 additions & 2 deletions data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from . import (alphanumeric_filter, average_line_length_filter,
character_repetition_filter, clip_similarity_filter,
flagged_words_filter, image_aspect_ratio_filter,
image_shape_filter, image_size_filter, language_id_score_filter,
face_area_filter, flagged_words_filter,
image_aspect_ratio_filter, image_shape_filter,
image_size_filter, language_id_score_filter,
maximum_line_length_filter, perplexity_filter,
special_characters_filter, specified_field_filter,
specified_numeric_field_filter, stopwords_filter, suffix_filter,
Expand Down
159 changes: 159 additions & 0 deletions data_juicer/ops/filter/face_area_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import numpy as np
from jsonargparse.typing import ClosedUnitInterval

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import load_image, pil_to_opencv

from ..base_op import OPERATORS, Filter
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'face_area_filter'

with AvailabilityChecking(['cv2'], OP_NAME):
import cv2 # noqa: F401

# avoid hanging of some functions in multiprocessing
cv2.setNumThreads(1)


class LazyCascadeClassifier:

def __init__(self, file_path):
self.file_path = file_path

def __getstate__(self):
# Only the file path is pickled
return self.file_path

def __setstate__(self, state):
self.file_path = state

def get_classifier(self):
# Load the classifier when needed, not when pickling
return cv2.CascadeClassifier(cv2.data.haarcascades + self.file_path)


@OPERATORS.register_module('face_area_filter')
@LOADED_IMAGES.register_module('face_area_filter')
class FaceAreaFilter(Filter):
"""Filter to keep samples with face area ratio within a specific range.
"""

def __init__(self,
min_ratio: ClosedUnitInterval = 0.0,
max_ratio: ClosedUnitInterval = 0.4,
any_or_all: str = 'any',
*args,
**kwargs):
"""
Initialization method.
:param min_ratio: Min ratio for the largest face area in an image.
:param max_ratio: Max ratio for the largest face area in an image.
:param any_or_all: Keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
condition.
:param args: Extra positional arguments.
:param kwargs: Extra keyword arguments.
"""
super().__init__(*args, **kwargs)
self.min_ratio = min_ratio
self.max_ratio = max_ratio

if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

# Extract face detector arguments from kwargs
detector_keys = [
'scaleFactor', 'minNeighbors', 'flags', 'minSize', 'maxSize'
]
self.detector_kwargs = {
key: kwargs.pop(key)
for key in detector_keys if key in kwargs
}
# Initialize face detector
# prepare_detector()
# self.classifier_conf = 'haarcascade_frontalface_default.xml'
self.pickable_detector = LazyCascadeClassifier(
'haarcascade_frontalface_default.xml')

def compute_stats(self, sample, context=False):
# check if it's computed already
if StatsKeys.face_ratios in sample[Fields.stats]:
return sample

# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
sample[Fields.stats][StatsKeys.face_ratios] = np.array(
[], dtype=np.float64)
return sample

# load images
loaded_image_keys = sample[self.image_key]
images = {}
for loaded_image_key in loaded_image_keys:
if context and loaded_image_key in sample[Fields.context]:
# load from context
images[loaded_image_key] = sample[
Fields.context][loaded_image_key]
else:
if loaded_image_key not in images:
# avoid load the same images
image = load_image(loaded_image_key)
images[loaded_image_key] = image
if context:
# store the image data into context
sample[Fields.context][loaded_image_key] = image

# check if faces detected already
if StatsKeys.face_detections not in sample[Fields.stats]:
detector = self.pickable_detector.get_classifier()
face_detections = {}
for key, image in images.items():
# convert into grayscale opencv format
opencv_img = pil_to_opencv(image, grayscale=True)
# detect faces
detected_faces = detector.detectMultiScale(
opencv_img, **self.detector_kwargs)
# the rectangles may be partially outside the original image
# right-closed and right-open
face_detections[key] = [[
max(0, min(x, image.width - 1)),
max(0, min(y, image.height - 1)),
max(1, min(x + w, image.width)),
max(1, min(y + h, image.height))
] for (x, y, w, h) in detected_faces]

sample[Fields.stats][StatsKeys.face_detections] = [
face_detections[key] for key in loaded_image_keys
]

sample[Fields.stats][StatsKeys.face_ratios] = [
max([((x2 - x1) * (y2 - y1)) /
(images[key].width * images[key].height)
for x1, y1, x2, y2 in dets],
default=0) for key, dets in zip(
loaded_image_keys, sample[Fields.stats][
StatsKeys.face_detections])
]
return sample

def process(self, sample):
face_ratios = sample[Fields.stats][StatsKeys.face_ratios]
if len(face_ratios) <= 0:
return True

keep_bools = np.array([
self.min_ratio <= face_ratio <= self.max_ratio
for face_ratio in face_ratios
])

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class StatsKeys(object):
image_width = 'image_width'
image_height = 'image_height'
image_sizes = 'image_sizes'
face_ratios = 'face_ratios'
face_detections = 'face_detections'

# multimodal
clip_image_text_similarity = 'clip_image_text_similarity'
Expand Down
13 changes: 13 additions & 0 deletions data_juicer/utils/mm_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from datasets import Audio, Image

from data_juicer.utils.constant import DEFAULT_PREFIX
Expand Down Expand Up @@ -34,6 +35,18 @@ def load_audio(path, sampling_rate=None):
return (aud['array'], aud['sampling_rate'])


def pil_to_opencv(pil_image, grayscale=False):
mode = 'L' if grayscale else 'RGB'
if pil_image.mode != mode:
pil_image = pil_image.convert(mode)
numpy_image = np.array(pil_image)
# Note: cv2.cvtColor with num_proc > 1 can cause a deadlock,
# manual RGB to BGR
if mode == 'RGB':
numpy_image = numpy_image[:, :, -1]
return numpy_image


def get_image_size(path, ):
import os
return os.path.getsize(path)
Expand Down
29 changes: 15 additions & 14 deletions docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 21 | Edits and transforms samples |
| [ Filter ]( #filter ) | 20 | Filters out low-quality samples |
| [ Filter ]( #filter ) | 21 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 4 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 2 | Selects top samples based on ranking |

Expand Down Expand Up @@ -47,16 +47,16 @@ All the specific operators are listed below, each featured with several capabili

| Operator | Domain | Lang | Description |
|-----------------------------------------------------|--------------------|--------|----------------------------------------------------------------------------------------------------------------|
| chinese_convert_mapper | General | zh | Convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) |
| chinese_convert_mapper | General | zh | Converts Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) |
| clean_copyright_mapper | Code | en, zh | Removes copyright notice at the beginning of code files (:warning: must contain the word *copyright*) |
| clean_email_mapper | General | en, zh | Removes email information |
| clean_html_mapper | General | en, zh | Removes HTML tags and returns plain text of all the nodes |
| clean_ip_mapper | General | en, zh | Removes IP addresses |
| clean_links_mapper | General, Code | en, zh | Removes links, such as those starting with http or ftp |
| expand_macro_mapper | LaTeX | en, zh | Expands macros usually defined at the top of TeX documents |
| fix_unicode_mapper | General | en, zh | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) |
| nlpaug_en_mapper | General | en | Simply augment texts in English based on the `nlpaug` library |
| nlpcda_zh_mapper | General | zh | Simply augment texts in Chinese based on the `nlpcda` library |
| nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library |
| nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library |
| punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents |
| remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents |
| remove_comments_mapper | LaTeX | en, zh | Removes the comments of TeX documents |
Expand All @@ -77,11 +77,12 @@ All the specific operators are listed below, each featured with several capabili
| alphanumeric_filter | General | en, zh | Keeps samples with alphanumeric ratio within the specified range |
| average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range |
| character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range |
| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
| face_area_filter | Image | - | Keeps samples contains images with face area ratios within the specified range |
| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
| image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range |
| image_shape_filter | Image | - | Keeps samples contains images with widths and heights within specific ranges |
| image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specific range |
| image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within the specified range |
| image_shape_filter | Image | - | Keeps samples contains images with widths and heights within specified range |
| image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specified range |
| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
| maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range |
| perplexity_filter | General | en, zh | Keeps samples with perplexity score below the specified threshold |
Expand All @@ -98,12 +99,12 @@ All the specific operators are listed below, each featured with several capabili

## Deduplicator <a name="deduplicator"/>

| Operator | Domain | Lang | Description |
|-------------------------------|---------|--------|-------------------------------------------------------------|
| document_deduplicator | General | en, zh | Deduplicate samples at document-level by comparing MD5 hash |
| document_minhash_deduplicator | General | en, zh | Deduplicate samples at document-level using MinHashLSH |
| document_simhash_deduplicator | General | en, zh | Deduplicate samples at document-level using SimHash |
| image_deduplicator | Image | - | Deduplicate samples at document-level using exact matching of images between documents |
| Operator | Domain | Lang | Description |
|-------------------------------|---------|--------|--------------------------------------------------------------|
| document_deduplicator | General | en, zh | Deduplicates samples at document-level by comparing MD5 hash |
| document_minhash_deduplicator | General | en, zh | Deduplicates samples at document-level using MinHashLSH |
| document_simhash_deduplicator | General | en, zh | Deduplicates samples at document-level using SimHash |
| image_deduplicator | Image | - | Deduplicates samples at document-level using exact matching of images between documents |


## Selector <a name="selector"/>
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 21 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 20 | 过滤低质量样本 |
| [ Filter ]( #filter ) | 21 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 4 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 |

Expand Down Expand Up @@ -75,6 +75,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| average_line_length_filter | Code | en, zh | 保留平均行长度在指定范围内的样本 |
| character_repetition_filter | General | en, zh | 保留 char-level n-gram 重复比率在指定范围内的样本 |
| clip_similarity_filter | Multimodal | - | 保留文本图像相似度在指定范围内的样本 |
| face_area_filter | Image | - | 保留样本中包含的图片的最大脸部区域在指定范围内的样本 |
| flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 |
| image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 |
| image_shape_filter | Image | - | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 |
Expand Down
1 change: 1 addition & 0 deletions environments/science_requires.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ transformers
opencc==1.1.6
imagededup
torch
opencv-python
Binary file added tests/ops/data/lena-face.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/ops/data/lena.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 61a9f73

Please sign in to comment.