diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 708db5c22..0ed0b134f 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -108,6 +108,12 @@ process:
       rep_len: 10                              # repetition length for char-level n-gram
       min_ratio: 0.0                           # the min ratio of filter range
       max_ratio: 0.5                           # the max ratio of filter range
+  - clip_similarity_filter:                    # filter samples according to the similarity between text and images.
+      hf_clip: openai/clip-vit-base-patch32    # name of the Hugging Face CLIP model to use
+      min_ratio: 0.24                          # the min similarity of filter range
+      max_ratio: 1.0                           # the max similarity of filter range
+      reduce_mode: avg                         # reduce mode when one text corresponds to multiple images in a chunk, must be one of ['avg', 'max', 'min']
+      any_or_all: any                          # keep this sample when any/all images meet the filter condition
   - flagged_words_filter:                      # filter text with the flagged-word ratio larger than a specific max value
       lang: en                                 # consider flagged words in what language
       tokenization: false                      # whether to use model to tokenize documents
diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py
index c9332eea0..40b1e9327 100644
--- a/data_juicer/ops/filter/__init__.py
+++ b/data_juicer/ops/filter/__init__.py
@@ -1,8 +1,8 @@
 from . import (alphanumeric_filter, average_line_length_filter,
-               character_repetition_filter, flagged_words_filter,
-               image_aspect_ratio_filter, language_id_score_filter,
-               maximum_line_length_filter, perplexity_filter,
-               special_characters_filter, specified_field_filter,
-               specified_numeric_field_filter, stopwords_filter, suffix_filter,
-               text_length_filter, token_num_filter, word_num_filter,
-               word_repetition_filter)
+               character_repetition_filter, clip_similarity_filter,
+               flagged_words_filter, image_aspect_ratio_filter,
+               language_id_score_filter, maximum_line_length_filter,
+               perplexity_filter, special_characters_filter,
+               specified_field_filter, specified_numeric_field_filter,
+               stopwords_filter, suffix_filter, text_length_filter,
+               token_num_filter, word_num_filter, word_repetition_filter)
diff --git a/data_juicer/ops/filter/clip_similarity_filter.py b/data_juicer/ops/filter/clip_similarity_filter.py
new file mode 100644
index 000000000..a62d3b320
--- /dev/null
+++ b/data_juicer/ops/filter/clip_similarity_filter.py
@@ -0,0 +1,152 @@
+import numpy as np
+from jsonargparse.typing import PositiveFloat
+
+from data_juicer.utils.constant import Fields, StatsKeys
+from data_juicer.utils.mm_utils import SpecialTokens, load_image
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from ..base_op import OPERATORS, Filter
+from ..op_fusion import LOADED_IMAGES
+
+
+@OPERATORS.register_module('clip_similarity_filter')
+@LOADED_IMAGES.register_module('clip_similarity_filter')
+class ClipSimilarityFilter(Filter):
+    """Filter to keep samples whose similarity between image and text
+    is within a specific range."""
+
+    def __init__(self,
+                 hf_clip='openai/clip-vit-base-patch32',
+                 min_ratio: PositiveFloat = 0.1,
+                 max_ratio: PositiveFloat = 1.0,
+                 any_or_all: str = 'any',
+                 reduce_mode: str = 'avg',
+                 *args,
+                 **kwargs):
+        """
+        Initialization method.
+
+        :param hf_clip: CLIP model name on Hugging Face used to compute
+            the similarity between image and text.
+        :param min_ratio: The min similarity to keep samples.
+        :param max_ratio: The max similarity to keep samples.
+        :param any_or_all: keep this sample with 'any' or 'all' strategy of
+            all images. 'any': keep this sample if any images meet the
+            condition. 'all': keep this sample only if all images meet the
+            condition.
+        :param reduce_mode: reduce mode when one text corresponds to
+            multiple images in a chunk.
+            'avg': take the average of multiple values
+            'max': take the max of multiple values
+            'min': take the min of multiple values
+        :param args: extra args
+        :param kwargs: extra args
+        """
+        super().__init__(*args, **kwargs)
+        self.image_key = 'images'
+        self.min_ratio = min_ratio
+        self.max_ratio = max_ratio
+        if reduce_mode not in ['avg', 'max', 'min']:
+            raise ValueError(f'Reduce mode [{reduce_mode}] is not supported. '
+                             f'Can only be one of ["avg", "max", "min"].')
+        if any_or_all not in ['any', 'all']:
+            raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
+                             f'Can only be one of ["any", "all"].')
+        self.any = (any_or_all == 'any')
+        self.model_key = prepare_model(model_type='hf_clip', model_key=hf_clip)
+        self.reduce_mode = reduce_mode
+
+    def compute_stats(self, sample, context=False):
+        # check if it's computed already
+        if StatsKeys.clip_image_text_similarity in sample[Fields.stats]:
+            return sample
+
+        # there is no image in this sample
+        if self.image_key not in sample or not sample[self.image_key]:
+            sample[Fields.stats][
+                StatsKeys.clip_image_text_similarity] = np.array(
+                    [], dtype=np.float64)
+            return sample
+
+        # load images
+        loaded_image_keys = sample[self.image_key]
+        images = {}
+        for loaded_image_key in loaded_image_keys:
+            if context and loaded_image_key in sample[Fields.context]:
+                # load from context
+                images[loaded_image_key] = sample[
+                    Fields.context][loaded_image_key]
+            else:
+                if loaded_image_key not in images:
+                    # avoid loading the same image multiple times
+                    image = load_image(loaded_image_key)
+                    images[loaded_image_key] = image
+                    if context:
+                        # store the image data into context
+                        sample[Fields.context][loaded_image_key] = image
+
+        text = sample[self.text_key]
+        special_token_dict = {
+            key: value
+            for key, value in SpecialTokens.__dict__.items()
+            if not key.startswith('__')
+        }
+        offset = 0
+
+        def remove_special_token(text):
+            for key, value in special_token_dict.items():
+                text = text.replace(value, '')
+            return text
+
+        similarity = []
+        model, processor = get_model(self.model_key)
+
+        for chunk in text.split(SpecialTokens.eoc):
+            count = chunk.count(SpecialTokens.image)
+
+            # no image or no text
+            if count == 0 or len(chunk) == 0:
+                continue
+            else:
+                text_chunk = remove_special_token(chunk)
+                image_chunk = [
+                    images[image_key]
+                    for image_key in loaded_image_keys[offset:offset + count]
+                ]
+
+                inputs = processor(text=text_chunk,
+                                   images=image_chunk,
+                                   return_tensors='pt',
+                                   padding=True)
+
+                outputs = model(**inputs)
+                chunk_logits = outputs.logits_per_text.detach().cpu() / 100.0
+                if self.reduce_mode == 'avg':
+                    chunk_similarity = chunk_logits.mean()
+                elif self.reduce_mode == 'max':
+                    chunk_similarity = chunk_logits.max()
+                else:
+                    chunk_similarity = chunk_logits.min()
+
+                similarity.append(float(chunk_similarity))
+                offset += count
+
+        sample[Fields.stats][StatsKeys.clip_image_text_similarity] = similarity
+
+        return sample
+
+    def process(self, sample):
+        similarity = sample[Fields.stats][StatsKeys.clip_image_text_similarity]
+        if len(similarity) <= 0:
+            return True
+
+        keep_bools = np.array([
+            self.min_ratio <= sim_value <= self.max_ratio
+            for sim_value in similarity
+        ])
+
+        # different strategies
+        if self.any:
+            return keep_bools.any()
+        else:
+            return keep_bools.all()
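A note on the `/ 100.0` above: CLIP's `logits_per_text` is the text-image cosine similarity multiplied by the model's learned logit scale, which is roughly 100 for the released `openai` checkpoints, so dividing by 100 approximately recovers the raw cosine similarity. A minimal standalone sketch of the per-chunk computation the op performs (the image path is illustrative; assumes `transformers` and `Pillow` are installed):

```python
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Same calls the op makes, outside the Data-Juicer pipeline.
model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

image = Image.open('tests/ops/data/cat.jpg')  # illustrative path
inputs = processor(text='a photo of a cat',
                   images=[image],
                   return_tensors='pt',
                   padding=True)
outputs = model(**inputs)

# logits_per_text has shape (num_texts, num_images); dividing by the
# ~100 logit scale leaves values close to cosine similarity in [-1, 1].
similarity = outputs.logits_per_text.detach().cpu() / 100.0
print(float(similarity.mean()))  # typically ~0.2-0.3 for a matching pair
```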
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index b2f3a362f..401d408de 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -1,6 +1,5 @@
 DEFAULT_PREFIX = '__dj__'
-
 
 class Fields(object):
     stats = DEFAULT_PREFIX + 'stats__'
     meta = DEFAULT_PREFIX + 'meta__'
@@ -29,6 +28,9 @@ class StatsKeys(object):
     # image
     aspect_ratios = 'aspect_ratios'
 
+    # multimodal
+    clip_image_text_similarity = 'clip_image_text_similarity'
+
 
 class HashKeys(object):
     hash = DEFAULT_PREFIX + 'hash'
diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py
index 3ac0d7973..e67b416f6 100644
--- a/data_juicer/utils/model_utils.py
+++ b/data_juicer/utils/model_utils.py
@@ -169,6 +169,21 @@ def prepare_huggingface_tokenizer(tokenizer_name):
                                               trust_remote_code=True)
     return tokenizer
 
+def prepare_huggingface_clip(clip_name):
+    """
+    Prepare and load a CLIP model and processor from HuggingFace.
+
+    :param clip_name: input CLIP model name
+    :return: a (model, processor) tuple.
+    """
+    from transformers import CLIPModel, CLIPProcessor
+
+    logger.info('Loading CLIP model and processor from HuggingFace...')
+    model = CLIPModel.from_pretrained(clip_name)
+    processor = CLIPProcessor.from_pretrained(clip_name)
+
+    return (model, processor)
+
 
 def prepare_diversity_model(model_name, lang):
     """
@@ -222,6 +237,7 @@ def prepare_model(lang='en', model_type='sentencepiece', model_key=None):
         'kenlm': ('%s.arpa.bin', prepare_kenlm_model),
         'nltk': ('punkt.%s.pickle', prepare_nltk_model),
         'huggingface': ('%s', prepare_huggingface_tokenizer),
+        'hf_clip': ('%s', prepare_huggingface_clip),
         'spacy': ('%s_core_web_md-3.5.0', prepare_diversity_model),
     }
     assert model_type in type_to_name.keys(
@@ -236,6 +252,11 @@
         MODEL_ZOO[model_key] = model_func(model_name)
     elif model_type == 'huggingface':
         MODEL_ZOO[model_key] = model_func(model_key)
+    elif model_type == 'hf_clip':
+        new_model_key = model_type + model_key
+        if new_model_key not in MODEL_ZOO.keys():
+            MODEL_ZOO[new_model_key] = model_func(model_key)
+        model_key = new_model_key
     else:
         MODEL_ZOO[model_key] = model_func(model_name, lang)
     return model_key
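The `hf_clip` branch above prefixes the cache key with the model type, so a CLIP entry in `MODEL_ZOO` cannot collide with a plain `huggingface` tokenizer cached under the same name, and repeated `prepare_model` calls reuse the already-loaded pair. A hedged sketch of the round trip (the model id shown is the op's documented default):

```python
from data_juicer.utils.model_utils import get_model, prepare_model

# prepare_model loads the CLIP pair once and returns the namespaced key,
# i.e. 'hf_clip' + 'openai/clip-vit-base-patch32'.
model_key = prepare_model(model_type='hf_clip',
                          model_key='openai/clip-vit-base-patch32')

# get_model looks the key up in MODEL_ZOO and returns (model, processor).
model, processor = get_model(model_key)
```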
diff --git a/demos/overview_scan/app.py b/demos/overview_scan/app.py
index 378b8f502..0c9383f0d 100644
--- a/demos/overview_scan/app.py
+++ b/demos/overview_scan/app.py
@@ -89,7 +89,7 @@
 |-----------------------------------|:------:|-------------------------------------------------|
 | Formatter                         |   7    | Discovers, loads, and canonicalizes source data |
 | Mapper                            |   21   | Edits and transforms samples                    |
-| Filter                            |   17   | Filters out low-quality samples                 |
+| Filter                            |   18   | Filters out low-quality samples                 |
 | Deduplicator                      |   3    | Detects and removes duplicate samples           |
 | Selector                          |   2    | Selects top samples based on ranking            |
 '''
@@ -140,6 +140,7 @@
 | alphanumeric_filter | General | en, zh | Keeps samples with alphanumeric ratio within the specified range |
 | average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range |
 | character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range |
+| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
 | flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
 | image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range |
 | language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
diff --git a/docs/Operators.md b/docs/Operators.md
index 78abeb495..6f4b56f23 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
 |-----------------------------------|:------:|-------------------------------------------------|
 | [ Formatter ]( #formatter )       |   7    | Discovers, loads, and canonicalizes source data |
 | [ Mapper ]( #mapper )             |   21   | Edits and transforms samples                    |
-| [ Filter ]( #filter )             |   17   | Filters out low-quality samples                 |
+| [ Filter ]( #filter )             |   18   | Filters out low-quality samples                 |
 | [ Deduplicator ]( #deduplicator ) |   3    | Detects and removes duplicate samples           |
 | [ Selector ]( #selector )         |   2    | Selects top samples based on ranking            |
 
@@ -23,6 +23,8 @@ All the specific operators are listed below, each featured with several capabilities.
   - LaTeX: specific to LaTeX source files
   - Code: specific to programming codes
   - Financial: closely related to financial sector
+  - Image: specific to images or image-related multimodal data
+  - Multimodal: specific to multimodal data
 * Language Tags
   - en: English
   - zh: Chinese
@@ -75,6 +77,7 @@ All the specific operators are listed below, each featured with several capabilities.
 | alphanumeric_filter | General | en, zh | Keeps samples with alphanumeric ratio within the specified range |
 | average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range |
 | character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range |
+| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
 | flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
 | image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range |
 | language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md
index cf3421d94..e59040fd2 100644
--- a/docs/Operators_ZH.md
+++ b/docs/Operators_ZH.md
@@ -21,6 +21,9 @@ Data-Juicer 中的算子分为以下 5 种类型。
   - LaTeX: 专用于 LaTeX 源文件
   - Code: 专用于编程代码
   - Financial: 与金融领域相关
+  - Image: 专用于图像或图像相关的多模态数据
+  - Multimodal: 专用于多模态数据
+
 * Language 标签
   - en: 英文
   - zh: 中文
@@ -71,6 +74,7 @@
 | alphanumeric_filter | General | en, zh | 保留字母数字比例在指定范围内的样本 |
 | average_line_length_filter | Code | en, zh | 保留平均行长度在指定范围内的样本 |
 | character_repetition_filter | General | en, zh | 保留 char-level n-gram 重复比率在指定范围内的样本 |
+| clip_similarity_filter | Multimodal | - | 保留文本图像相似度在指定范围内的样本 |
 | flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 |
 | image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 |
 | language_id_score_filter | General | en, zh | 保留特定语言的样本,通过预测的置信度得分来判断 |
diff --git a/tests/ops/data/cat.jpg b/tests/ops/data/cat.jpg
new file mode 100644
index 000000000..e131e8ecd
Binary files /dev/null and b/tests/ops/data/cat.jpg differ
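For context on the texts the tests below construct: `compute_stats` splits the sample text on `SpecialTokens.eoc` and pairs each chunk, in order, with as many images as the chunk contains `SpecialTokens.image` tokens. A small sketch of that bookkeeping:

```python
from data_juicer.utils.mm_utils import SpecialTokens

# Two chunks, each consuming one image from the sample's image list.
text = (f'{SpecialTokens.image}a photo of a cat {SpecialTokens.eoc} '
        f'{SpecialTokens.image}a photo of a dog {SpecialTokens.eoc}')

offset = 0
for chunk in text.split(SpecialTokens.eoc):
    count = chunk.count(SpecialTokens.image)
    if count == 0 or len(chunk) == 0:
        continue  # skipped chunks produce no similarity score
    print(f'chunk {chunk!r} -> images[{offset}:{offset + count}]')
    offset += count
```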
diff --git a/tests/ops/filter/test_clip_similarity_filter.py b/tests/ops/filter/test_clip_similarity_filter.py
new file mode 100644
index 000000000..b17c3d613
--- /dev/null
+++ b/tests/ops/filter/test_clip_similarity_filter.py
@@ -0,0 +1,187 @@
+import os
+import unittest
+
+from datasets import Dataset
+
+from data_juicer.ops.filter.clip_similarity_filter import (
+    ClipSimilarityFilter, SpecialTokens)
+from data_juicer.utils.constant import Fields
+
+
+class ClipSimilarityFilterTest(unittest.TestCase):
+
+    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                             '..', 'data')
+
+    cat_path = os.path.join(data_path, 'cat.jpg')
+    img3_path = os.path.join(data_path, 'img3.jpg')
+    hf_clip = 'openai/clip-vit-base-patch32'
+
+    def _run_filter(self, dataset: Dataset, target_list, op):
+
+        if Fields.stats not in dataset.features:
+            # TODO:
+            # this is a temp solution,
+            # only add stats when calling filter op
+            dataset = dataset.add_column(name=Fields.stats,
+                                         column=[{}] * dataset.num_rows)
+
+        dataset = dataset.map(op.compute_stats)
+        dataset = dataset.filter(op.process)
+        dataset = dataset.select_columns(column_names=['text', 'images'])
+        res_list = dataset.to_list()
+        self.assertEqual(res_list, target_list)
+
+    def test_no_eoc_special_token(self):
+
+        ds_list = [{
+            'text': f'{SpecialTokens.image}a photo of a cat',
+            'images': [self.cat_path]
+        }, {
+            'text': f'{SpecialTokens.image}a photo of a dog',
+            'images': [self.cat_path]
+        }]
+        tgt_list = [{
+            'text': f'{SpecialTokens.image}a photo of a cat',
+            'images': [self.cat_path]
+        }]
+
+        dataset = Dataset.from_list(ds_list)
+        op = ClipSimilarityFilter(hf_clip=self.hf_clip,
+                                  reduce_mode='avg',
+                                  any_or_all='any',
+                                  min_ratio=0.2,
+                                  max_ratio=0.9)
+        self._run_filter(dataset, tgt_list, op)
+
+    def test_eoc_special_token(self):
+
+        ds_list = [{
+            'text':
+            f'{SpecialTokens.image}a photo of a cat{SpecialTokens.eoc}',
+            'images': [self.cat_path]
+        }, {
+            'text': f'{SpecialTokens.image}a photo of a dog',
+            'images': [self.cat_path]
+        }]
+        tgt_list = [{
+            'text':
+            f'{SpecialTokens.image}a photo of a cat{SpecialTokens.eoc}',
+            'images': [self.cat_path]
+        }]
+
+        dataset = Dataset.from_list(ds_list)
+        op = ClipSimilarityFilter(hf_clip=self.hf_clip,
+                                  reduce_mode='avg',
+                                  any_or_all='any',
+                                  min_ratio=0.2,
+                                  max_ratio=0.9)
+        self._run_filter(dataset, tgt_list, op)
+
+    def test_keep_any(self):
+
+        ds_list = [{
+            'text':
+            f'{SpecialTokens.image}a photo of a cat {SpecialTokens.eoc} '
+            f'{SpecialTokens.image}a photo of a dog {SpecialTokens.eoc}',
+            'images': [self.cat_path, self.cat_path]
+        }]
+        tgt_list = [{
+            'text':
+            f'{SpecialTokens.image}a photo of a cat {SpecialTokens.eoc} '
+            f'{SpecialTokens.image}a photo of a dog {SpecialTokens.eoc}',
+            'images': [self.cat_path, self.cat_path]
+        }]
+        dataset = Dataset.from_list(ds_list)
+        op = ClipSimilarityFilter(hf_clip=self.hf_clip,
+                                  reduce_mode='avg',
+                                  any_or_all='any',
+                                  min_ratio=0.2,
+                                  max_ratio=0.9)
+        self._run_filter(dataset, tgt_list, op)
+
+    def test_keep_all(self):
+
+        ds_list = [{
+            'text':
+            f'{SpecialTokens.image}a photo of a cat {SpecialTokens.eoc} '
+            f'{SpecialTokens.image}a photo of a dog {SpecialTokens.eoc}',
+            'images': [self.cat_path, self.cat_path]
+        }]
+        tgt_list = []
+        dataset = Dataset.from_list(ds_list)
+        op = ClipSimilarityFilter(hf_clip=self.hf_clip,
+                                  reduce_mode='avg',
+                                  any_or_all='all',
+                                  min_ratio=0.2,
+                                  max_ratio=0.9)
+        self._run_filter(dataset, tgt_list, op)
+
+    def test_reduce_avg(self):
+
+        ds_list = [{
+            'text': f'{SpecialTokens.image}a photo of a cat '
+            f'{SpecialTokens.image} {SpecialTokens.eoc}',
+            'images': [self.cat_path, self.img3_path]
+        }]
+        tgt_list = [{
+            'text': f'{SpecialTokens.image}a photo of a cat '
+            f'{SpecialTokens.image} {SpecialTokens.eoc}',
+            'images': [self.cat_path, self.img3_path]
+        }]
+        dataset = Dataset.from_list(ds_list)
+        op = ClipSimilarityFilter(hf_clip=self.hf_clip,
+                                  reduce_mode='avg',
+                                  any_or_all='any',
+                                  min_ratio=0.2,
+                                  max_ratio=0.9)
+        self._run_filter(dataset, tgt_list, op)
+
+    def test_reduce_max(self):
+
+        ds_list = [{
+            'text': f'{SpecialTokens.image}a photo of a cat '
+            f'{SpecialTokens.image} {SpecialTokens.eoc}',
+            'images': [self.cat_path, self.img3_path]
+        }]
+        tgt_list = [{
+            'text': f'{SpecialTokens.image}a photo of a cat '
+            f'{SpecialTokens.image} {SpecialTokens.eoc}',
+            'images': [self.cat_path, self.img3_path]
+        }]
+        dataset = Dataset.from_list(ds_list)
+        op = ClipSimilarityFilter(hf_clip=self.hf_clip,
+                                  reduce_mode='max',
+                                  any_or_all='any',
+                                  min_ratio=0.2,
+                                  max_ratio=0.9)
+        self._run_filter(dataset, tgt_list, op)
+
+    def test_reduce_min(self):
+
+        ds_list = [{
+            'text': f'{SpecialTokens.image}a photo of a cat '
+            f'{SpecialTokens.image} {SpecialTokens.eoc}',
+            'images': [self.cat_path, self.img3_path]
+        }]
+        tgt_list = [{
+            'text': f'{SpecialTokens.image}a photo of a cat '
+            f'{SpecialTokens.image} {SpecialTokens.eoc}',
+            'images': [self.cat_path, self.img3_path]
+        }]
+
+        dataset = Dataset.from_list(ds_list)
+        op = ClipSimilarityFilter(hf_clip=self.hf_clip,
+                                  reduce_mode='min',
+                                  any_or_all='any',
+                                  min_ratio=0.1,
+                                  max_ratio=0.9)
+
+        self._run_filter(dataset, tgt_list, op)
+
+        op.min_ratio = 0.2
+        self._run_filter(dataset, [], op)
+
+
+if __name__ == '__main__':
+    unittest.main()
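End to end, the new op can also be exercised on a single sample without a `Dataset` wrapper. A minimal usage sketch (downloads the default checkpoint on first use; the stats dict must exist before `compute_stats`, as the tests' temp solution also shows):

```python
from data_juicer.ops.filter.clip_similarity_filter import ClipSimilarityFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens

# Same parameter values as the config_all.yaml example above.
op = ClipSimilarityFilter(hf_clip='openai/clip-vit-base-patch32',
                          min_ratio=0.24,
                          max_ratio=1.0,
                          reduce_mode='avg',
                          any_or_all='any')

sample = {
    'text': f'{SpecialTokens.image}a photo of a cat{SpecialTokens.eoc}',
    'images': ['tests/ops/data/cat.jpg'],  # path from the new test data
    Fields.stats: {},
}

sample = op.compute_stats(sample)
print(sample[Fields.stats])  # {'clip_image_text_similarity': [...]}
print(op.process(sample))    # True if any chunk's score lies in range
```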