diff --git a/data_juicer/ops/filter/image_text_matching_filter.py b/data_juicer/ops/filter/image_text_matching_filter.py index 5e779087b..1fc458cea 100644 --- a/data_juicer/ops/filter/image_text_matching_filter.py +++ b/data_juicer/ops/filter/image_text_matching_filter.py @@ -1,9 +1,11 @@ import numpy as np from jsonargparse.typing import ClosedUnitInterval +from PIL import ImageOps from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys -from data_juicer.utils.mm_utils import SpecialTokens, load_image +from data_juicer.utils.mm_utils import (SpecialTokens, load_image, + remove_special_tokens) from data_juicer.utils.model_utils import get_model, prepare_model from ..base_op import OPERATORS, Filter @@ -11,10 +13,9 @@ OP_NAME = 'image_text_matching_filter' -with AvailabilityChecking(['torch'], OP_NAME): +with AvailabilityChecking(['torch', 'transformers'], OP_NAME): import torch import transformers # noqa: F401 - from PIL import ImageOps # avoid hanging when calling blip in multiprocessing torch.set_num_threads(1) @@ -102,18 +103,7 @@ def compute_stats(self, sample, context=False): sample[Fields.context][loaded_image_key] = image text = sample[self.text_key] - special_token_dict = { - key: value - for key, value in SpecialTokens.__dict__.items() - if not key.startswith('__') - } offset = 0 - - def remove_special_token(text): - for value in special_token_dict.values(): - text = text.replace(value, '') - return text - matching_scores = [] model, processor = get_model(self.model_key) @@ -124,7 +114,7 @@ def remove_special_token(text): if count == 0 or len(chunk) == 0: continue else: - text_chunk = remove_special_token(chunk) + text_chunk = remove_special_tokens(chunk) image_chunk = [] for image_key in loaded_image_keys[offset:offset + count]: image = images[image_key] @@ -172,7 +162,6 @@ def process(self, sample): # different strategies if self.any: - return keep_bools.any() else: return 
keep_bools.all() diff --git a/data_juicer/ops/filter/image_text_similarity_filter.py b/data_juicer/ops/filter/image_text_similarity_filter.py index 285f50cf1..50e0548e8 100644 --- a/data_juicer/ops/filter/image_text_similarity_filter.py +++ b/data_juicer/ops/filter/image_text_similarity_filter.py @@ -1,9 +1,11 @@ import numpy as np from jsonargparse.typing import ClosedUnitInterval +from PIL import ImageOps from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys -from data_juicer.utils.mm_utils import SpecialTokens, load_image +from data_juicer.utils.mm_utils import (SpecialTokens, load_image, + remove_special_tokens) from data_juicer.utils.model_utils import get_model, prepare_model from ..base_op import OPERATORS, Filter @@ -11,11 +13,10 @@ OP_NAME = 'image_text_similarity_filter' -with AvailabilityChecking(['torch'], OP_NAME): +with AvailabilityChecking(['torch', 'transformers'], OP_NAME): import torch import transformers # noqa: F401 - from PIL import ImageOps # avoid hanging when calling clip in multiprocessing torch.set_num_threads(1) @@ -102,18 +103,7 @@ def compute_stats(self, sample, context=False): sample[Fields.context][loaded_image_key] = image text = sample[self.text_key] - special_token_dict = { - key: value - for key, value in SpecialTokens.__dict__.items() - if not key.startswith('__') - } offset = 0 - - def remove_special_token(text): - for value in special_token_dict.values(): - text = text.replace(value, '') - return text - similarity = [] model, processor = get_model(self.model_key) @@ -124,7 +114,7 @@ def remove_special_token(text): if count == 0 or len(chunk) == 0: continue else: - text_chunk = remove_special_token(chunk) + text_chunk = remove_special_tokens(chunk) image_chunk = [] for image_key in loaded_image_keys[offset:offset + count]: image = images[image_key] diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py index 817f298bd..dc216f031 100644 
--- a/data_juicer/utils/mm_utils.py +++ b/data_juicer/utils/mm_utils.py @@ -14,6 +14,21 @@ class SpecialTokens(object): eoc = f'<|{DEFAULT_PREFIX}eoc|>' +def get_special_tokens(): + special_token_dict = { + key: value + for key, value in SpecialTokens.__dict__.items() + if not key.startswith('__') + } + return special_token_dict + + +def remove_special_tokens(text): + for value in get_special_tokens().values(): + text = text.replace(value, '') + return text + + def load_images(paths): return [load_image(path) for path in paths] diff --git a/docs/Operators.md b/docs/Operators.md index f5500f702..01676141c 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -81,8 +81,8 @@ All the specific operators are listed below, each featured with several capabili | image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range | | image_shape_filter | Image | - | Keeps samples contains images with widths and heights within specific ranges | | image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specific range | -| image_text_matching_filter | Multimodal | - | Keeps samples with matching score between text and images within the specified range | -| image_text_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range | +| image_text_matching_filter | Multimodal | - | Keeps samples whose image-text classification matching score, computed by a BLIP model, is within the specified range | +| image_text_similarity_filter | Multimodal | - | Keeps samples whose image-text feature cosine similarity, computed by a CLIP model, is within the specified range | | language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score | | maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range | | perplexity_filter | General | en, zh | Keeps samples with 
perplexity score below the specified threshold | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 379607324..071f206f3 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -78,8 +78,8 @@ Data-Juicer 中的算子分为以下 5 种类型。 | image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 | | image_shape_filter | Image | - | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 | | image_size_filter | Image | - | 保留样本中包含的图片的大小(bytes)在指定范围内的样本 | -| image_text_matching_filter | Multimodal | - | 保留文本图像匹配度在指定范围内的样本 | -| image_text_similarity_filter | Multimodal | - | 保留文本图像相似度在指定范围内的样本 | +| image_text_matching_filter | Multimodal | - | 保留图像-文本的分类匹配分(基于BLIP模型)在指定范围内的样本 | +| image_text_similarity_filter | Multimodal | - | 保留图像-文本的特征余弦相似度(基于CLIP模型)在指定范围内的样本 | | language_id_score_filter | General | en, zh | 保留特定语言的样本,通过预测的置信度得分来判断 | | maximum_line_length_filter | Code | en, zh | 保留最大行长度在指定范围内的样本 | | perplexity_filter | General | en, zh | 保留困惑度低于指定阈值的样本 |