-
Notifications
You must be signed in to change notification settings - Fork 191
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
256 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
from . import (alphanumeric_filter, average_line_length_filter, | ||
character_repetition_filter, flagged_words_filter, | ||
image_aspect_ratio_filter, language_id_score_filter, | ||
maximum_line_length_filter, perplexity_filter, | ||
special_characters_filter, specified_field_filter, | ||
specified_numeric_field_filter, stopwords_filter, suffix_filter, | ||
text_length_filter, token_num_filter, word_num_filter, | ||
word_repetition_filter) | ||
image_aspect_ratio_filter, image_shape_filter, | ||
language_id_score_filter, maximum_line_length_filter, | ||
perplexity_filter, special_characters_filter, | ||
specified_field_filter, specified_numeric_field_filter, | ||
stopwords_filter, suffix_filter, text_length_filter, | ||
token_num_filter, word_num_filter, word_repetition_filter) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import sys | ||
|
||
import numpy as np | ||
from jsonargparse.typing import PositiveInt | ||
|
||
from data_juicer.utils.constant import Fields, StatsKeys | ||
from data_juicer.utils.mm_utils import load_image | ||
|
||
from ..base_op import OPERATORS, Filter | ||
from ..op_fusion import LOADED_IMAGES | ||
|
||
|
||
@OPERATORS.register_module('image_shape_filter') | ||
@LOADED_IMAGES.register_module('image_shape_filter') | ||
class ImageShapeFilter(Filter): | ||
"""Filter to keep samples with image shape (w, h) within specific ranges. | ||
""" | ||
|
||
def __init__(self, | ||
min_width: PositiveInt = 1, | ||
max_width: PositiveInt = sys.maxsize, | ||
min_height: PositiveInt = 1, | ||
max_height: PositiveInt = sys.maxsize, | ||
any_or_all: str = 'any', | ||
*args, | ||
**kwargs): | ||
""" | ||
Initialization method. | ||
:param min_width: The min width to keep samples. | ||
:param max_width: The max width to keep samples. | ||
:param min_height: The min height to keep samples. | ||
:param max_height: The max height to keep samples. | ||
:param any_or_all: keep this sample with 'any' or 'all' strategy of | ||
all images. 'any': keep this sample if any images meet the | ||
condition. 'all': keep this sample only if all images meet the | ||
condition. | ||
:param args: extra args | ||
:param kwargs: extra args | ||
""" | ||
super().__init__(*args, **kwargs) | ||
self.min_width = min_width | ||
self.max_width = max_width | ||
self.min_height = min_height | ||
self.max_height = max_height | ||
if any_or_all not in ['any', 'all']: | ||
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. ' | ||
f'Can only be one of ["any", "all"].') | ||
self.any = (any_or_all == 'any') | ||
|
||
def compute_stats(self, sample, context=False): | ||
# check if it's computed already | ||
if StatsKeys.image_width in sample[Fields.stats] \ | ||
and StatsKeys.image_height in sample[Fields.stats]: | ||
return sample | ||
|
||
# there is no image in this sample | ||
if self.image_key not in sample or not sample[self.image_key]: | ||
sample[Fields.stats][StatsKeys.image_width] = np.array( | ||
[], dtype=np.int64) | ||
sample[Fields.stats][StatsKeys.image_height] = np.array( | ||
[], dtype=np.int64) | ||
return sample | ||
|
||
# load images | ||
loaded_image_keys = sample[self.image_key] | ||
images = {} | ||
for loaded_image_key in loaded_image_keys: | ||
if context and loaded_image_key in sample[Fields.context]: | ||
# load from context | ||
images[loaded_image_key] = sample[ | ||
Fields.context][loaded_image_key] | ||
else: | ||
if loaded_image_key not in images: | ||
# avoid load the same images | ||
image = load_image(loaded_image_key) | ||
images[loaded_image_key] = image | ||
if context: | ||
# store the image data into context | ||
sample[Fields.context][loaded_image_key] = image | ||
|
||
# get width and height for each image | ||
whs = {key: (images[key].width, images[key].height) for key in images} | ||
sample[Fields.stats][StatsKeys.image_width] = [ | ||
whs[key][0] for key in loaded_image_keys | ||
] | ||
sample[Fields.stats][StatsKeys.image_height] = [ | ||
whs[key][1] for key in loaded_image_keys | ||
] | ||
return sample | ||
|
||
def process(self, sample): | ||
ws = sample[Fields.stats][StatsKeys.image_width] | ||
hs = sample[Fields.stats][StatsKeys.image_height] | ||
if len(ws) <= 0: | ||
return True | ||
keep_bools = np.array([ | ||
self.min_width <= w <= self.max_width | ||
and self.min_height <= h <= self.max_height | ||
for w, h in zip(ws, hs) | ||
]) | ||
|
||
# different strategies | ||
if self.any: | ||
return keep_bools.any() | ||
else: | ||
return keep_bools.all() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import os | ||
import unittest | ||
import numpy as np | ||
import PIL.Image | ||
|
||
from datasets import Dataset, Image | ||
|
||
from data_juicer.ops.filter.image_shape_filter import ImageShapeFilter | ||
from data_juicer.utils.constant import Fields | ||
|
||
|
||
class ImageShapeFilterTest(unittest.TestCase): | ||
|
||
data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), | ||
'..', 'data') | ||
img1_path = os.path.join(data_path, 'img1.png') | ||
img2_path = os.path.join(data_path, 'img2.jpg') | ||
img3_path = os.path.join(data_path, 'img3.jpg') | ||
|
||
def _run_image_shape_filter(self, | ||
dataset: Dataset, | ||
target_list, | ||
op): | ||
if Fields.stats not in dataset.features: | ||
dataset = dataset.add_column(name=Fields.stats, | ||
column=[{}] * dataset.num_rows) | ||
dataset = dataset.map(op.compute_stats) | ||
dataset = dataset.filter(op.process) | ||
dataset = dataset.select_columns(column_names=[op.image_key]) | ||
res_list = dataset.to_list() | ||
self.assertEqual(res_list, target_list) | ||
|
||
def test_filter1(self): | ||
|
||
ds_list = [{ | ||
'images': [self.img1_path] | ||
}, { | ||
'images': [self.img2_path] | ||
}, { | ||
'images': [self.img3_path] | ||
}] | ||
tgt_list = [{ | ||
'images': [self.img2_path] | ||
}] | ||
dataset = Dataset.from_list(ds_list) | ||
op = ImageShapeFilter(min_width=400, | ||
min_height=400) | ||
self._run_image_shape_filter(dataset, tgt_list, op) | ||
|
||
def test_filter2(self): | ||
|
||
ds_list = [{ | ||
'images': [self.img1_path] | ||
}, { | ||
'images': [self.img2_path] | ||
}, { | ||
'images': [self.img3_path] | ||
}] | ||
tgt_list = [{ | ||
'images': [self.img1_path] | ||
}, { | ||
'images': [self.img3_path] | ||
}] | ||
dataset = Dataset.from_list(ds_list) | ||
op = ImageShapeFilter(max_width=500, | ||
max_height=500) | ||
self._run_image_shape_filter(dataset, tgt_list, op) | ||
|
||
def test_filter3(self): | ||
|
||
ds_list = [{ | ||
'images': [self.img1_path] | ||
}, { | ||
'images': [self.img2_path] | ||
}, { | ||
'images': [self.img3_path] | ||
}] | ||
tgt_list = [{ | ||
'images': [self.img1_path] | ||
}, { | ||
'images': [self.img2_path] | ||
}, { | ||
'images': [self.img3_path] | ||
}] | ||
dataset = Dataset.from_list(ds_list) | ||
op = ImageShapeFilter() | ||
self._run_image_shape_filter(dataset, tgt_list, op) | ||
|
||
def test_any(self): | ||
|
||
ds_list = [{ | ||
'images': [self.img1_path, self.img2_path] | ||
}, { | ||
'images': [self.img2_path, self.img3_path] | ||
}, { | ||
'images': [self.img1_path, self.img3_path] | ||
}] | ||
tgt_list = [{ | ||
'images': [self.img1_path, self.img2_path] | ||
}, { | ||
'images': [self.img2_path, self.img3_path] | ||
}] | ||
dataset = Dataset.from_list(ds_list) | ||
op = ImageShapeFilter(min_width=400, | ||
min_height=400, | ||
any_or_all='any') | ||
self._run_image_shape_filter(dataset, tgt_list, op) | ||
|
||
def test_all(self): | ||
|
||
ds_list = [{ | ||
'images': [self.img1_path, self.img2_path] | ||
}, { | ||
'images': [self.img2_path, self.img3_path] | ||
}, { | ||
'images': [self.img1_path, self.img3_path] | ||
}] | ||
tgt_list = [] | ||
dataset = Dataset.from_list(ds_list) | ||
op = ImageShapeFilter(min_width=400, | ||
min_height=400, | ||
any_or_all='all') | ||
self._run_image_shape_filter(dataset, tgt_list, op) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |