diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index 056da04cd..abce40a5b 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -17,7 +17,8 @@ video_frames_text_similarity_filter, video_motion_score_filter, video_nsfw_filter, video_ocr_area_ratio_filter, video_resolution_filter, video_tagging_from_frames_filter, - video_watermark_filter, word_num_filter, word_repetition_filter) + video_watermark_filter, word_repetition_filter, + words_num_filter) from .alphanumeric_filter import AlphanumericFilter from .audio_duration_filter import AudioDurationFilter from .audio_nmf_snr_filter import AudioNMFSNRFilter @@ -58,8 +59,8 @@ from .video_resolution_filter import VideoResolutionFilter from .video_tagging_from_frames_filter import VideoTaggingFromFramesFilter from .video_watermark_filter import VideoWatermarkFilter -from .word_num_filter import WordNumFilter from .word_repetition_filter import WordRepetitionFilter +from .words_num_filter import WordsNumFilter __all__ = [ 'ImageTextSimilarityFilter', @@ -98,7 +99,7 @@ 'SuffixFilter', 'ImageSizeFilter', 'VideoWatermarkFilter', - 'WordNumFilter', + 'WordsNumFilter', 'ImageFaceRatioFilter', 'FlaggedWordFilter', 'WordRepetitionFilter', diff --git a/data_juicer/ops/filter/image_nsfw_filter.py b/data_juicer/ops/filter/image_nsfw_filter.py index 0d3a18b7d..b08990f32 100644 --- a/data_juicer/ops/filter/image_nsfw_filter.py +++ b/data_juicer/ops/filter/image_nsfw_filter.py @@ -77,7 +77,9 @@ def compute_stats(self, sample, rank=None, context=False): inputs = processor(images=images, return_tensors='pt').to(model.device) outputs = model(**inputs) logits = outputs.logits - nsfw_scores = [scores[1] for scores in torch.softmax(logits, dim=-1)] + nsfw_scores = [ + float(scores[1]) for scores in torch.softmax(logits, dim=-1) + ] sample[Fields.stats][StatsKeys.image_nsfw_score] = nsfw_scores diff --git a/data_juicer/ops/filter/image_watermark_filter.py b/data_juicer/ops/filter/image_watermark_filter.py index 1b7cd98d2..24e69c17f 100644 --- a/data_juicer/ops/filter/image_watermark_filter.py +++ b/data_juicer/ops/filter/image_watermark_filter.py @@ -81,7 +81,9 @@ def compute_stats(self, sample, rank=None, context=False): inputs = processor(images=images, return_tensors='pt').to(model.device) outputs = model(**inputs) logits = outputs.logits - watermark_probs = [probs[1] for probs in torch.softmax(logits, dim=-1)] + watermark_probs = [ + float(probs[1]) for probs in torch.softmax(logits, dim=-1) + ] sample[Fields.stats][StatsKeys.image_watermark_prob] = watermark_probs diff --git a/data_juicer/ops/filter/video_aspect_ratio_filter.py b/data_juicer/ops/filter/video_aspect_ratio_filter.py index 4bea08827..8d1e654a2 100644 --- a/data_juicer/ops/filter/video_aspect_ratio_filter.py +++ b/data_juicer/ops/filter/video_aspect_ratio_filter.py @@ -64,9 +64,8 @@ def compute_stats(self, sample, context=False): video_aspect_ratios = {} for key, video in videos.items(): stream = video.streams.video[0] - video_aspect_ratios[key] = str( - Fraction(stream.codec_context.width, - stream.codec_context.height)) + video_aspect_ratios[ + key] = stream.codec_context.width / stream.codec_context.height if not context: video.close() diff --git a/data_juicer/ops/filter/video_nsfw_filter.py b/data_juicer/ops/filter/video_nsfw_filter.py index 27b475c53..f5244bf8b 100644 --- a/data_juicer/ops/filter/video_nsfw_filter.py +++ b/data_juicer/ops/filter/video_nsfw_filter.py @@ -144,11 +144,12 @@ def compute_stats(self, sample, rank=None, context=False): cur_scores = torch.Tensor(cur_scores) if self.reduce_mode == 'avg': - nsfw_scores.append(cur_scores.mean()) + cur_score = cur_scores.mean() elif self.reduce_mode == 'max': - nsfw_scores.append(cur_scores.max()) + cur_score = cur_scores.max() else: - nsfw_scores.append(cur_scores.min()) + cur_score = cur_scores.min() + nsfw_scores.append(float(cur_score)) sample[Fields.stats][StatsKeys.video_nsfw_score] = nsfw_scores diff --git a/data_juicer/ops/filter/video_watermark_filter.py b/data_juicer/ops/filter/video_watermark_filter.py index 12dce3f4f..a07d12ca3 100644 --- a/data_juicer/ops/filter/video_watermark_filter.py +++ b/data_juicer/ops/filter/video_watermark_filter.py @@ -145,11 +145,12 @@ def compute_stats(self, sample, rank=None, context=False): cur_probs = torch.Tensor(cur_probs) if self.reduce_mode == 'avg': - watermark_probs.append(cur_probs.mean()) + cur_prob = cur_probs.mean() elif self.reduce_mode == 'max': - watermark_probs.append(cur_probs.max()) + cur_prob = cur_probs.max() else: - watermark_probs.append(cur_probs.min()) + cur_prob = cur_probs.min() + watermark_probs.append(float(cur_prob)) sample[Fields.stats][StatsKeys.video_watermark_prob] = watermark_probs diff --git a/data_juicer/ops/filter/word_num_filter.py b/data_juicer/ops/filter/words_num_filter.py similarity index 98% rename from data_juicer/ops/filter/word_num_filter.py rename to data_juicer/ops/filter/words_num_filter.py index 08ec90ee8..3f6d02d76 100644 --- a/data_juicer/ops/filter/word_num_filter.py +++ b/data_juicer/ops/filter/words_num_filter.py @@ -19,7 +19,7 @@ @OPERATORS.register_module(OP_NAME) @INTER_WORDS.register_module(OP_NAME) -class WordNumFilter(Filter): +class WordsNumFilter(Filter): """Filter to keep samples with total words number within a specific range.""" diff --git a/docs/BadDataExhibition.md b/docs/BadDataExhibition.md index c4483a0bc..720ab25b9 100644 --- a/docs/BadDataExhibition.md +++ b/docs/BadDataExhibition.md @@ -39,7 +39,7 @@ some OPs work well on some datasets but might be useless on others. | `perplexity_filter` | [LCS-558K](#llava-15-pretrain-dataset), [Books](#books), [ArXiv](#arxiv) | | `special_characters_filter` | [Wikipedia](#wikipedia), [Books](#books) | | `text_length_filter` | [Wikipedia](#wikipedia), [ArXiv](#arxiv), [Github Code](#github-code) | -| `word_num_filter` | [Stack Exchange](#stack-exchange) | +| `words_num_filter` | [Stack Exchange](#stack-exchange) | | `word_repetition_filter` | [MMC4](#mmc4), [Wikipedia](#wikipedia) | - Everyone from community is welcome to continue to add examples to this table. @@ -298,7 +298,7 @@ The pretraining dataset of LLaVA-1.5 | item | value | |-----------|------------------------------------------------------------------| -| from OP | `word_num_filter` | +| from OP | `words_num_filter` | | num_words | 2 | | text |
text
Q: Append
| | comments | Texts with too few words might be missing content | diff --git a/docs/BadDataExhibition_ZH.md b/docs/BadDataExhibition_ZH.md index 1a73def29..c1021afef 100644 --- a/docs/BadDataExhibition_ZH.md +++ b/docs/BadDataExhibition_ZH.md @@ -35,7 +35,7 @@ | `perplexity_filter` | [LCS-558K](#llava-15-pretrain-dataset), [Books](#books), [ArXiv](#arxiv) | | `special_characters_filter` | [Wikipedia](#wikipedia), [Books](#books) | | `text_length_filter` | [Wikipedia](#wikipedia), [ArXiv](#arxiv), [Github Code](#github-code) | -| `word_num_filter` | [Stack Exchange](#stack-exchange) | +| `words_num_filter` | [Stack Exchange](#stack-exchange) | | `word_repetition_filter` | [MMC4](#mmc4), [Wikipedia](#wikipedia) | - 欢迎大家持续补充这个表格。 @@ -292,7 +292,7 @@ LLaVA-1.5 的预训练数据集。 | 数据项 | 值 | |-----------|------------------------------------------------------------------| -| 来源算子 | `word_num_filter` | +| 来源算子 | `words_num_filter` | | num_words | 2 | | 文本 |
text
Q: Append
| | 说明 | 单词数量过少的文本可能内容缺失 | diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md index cb53ca327..e887b0358 100644 --- a/docs/DeveloperGuide.md +++ b/docs/DeveloperGuide.md @@ -289,7 +289,7 @@ the corresponding documents, including the following docs: ### (Optional) Make your OP fusible - If the calculation process of some intermediate variables in the new OP is reused in other existing OPs, this new OP can be -added to the fusible OPs to accelerate the whole data processing with OP fusion technology. (e.g. both the `word_num_filter` +added to the fusible OPs to accelerate the whole data processing with OP fusion technology. (e.g. both the `words_num_filter` and `word_repetition_filter` need to split the input text into words) - When opening OP fusion, these reused calculation processes and intermediate variables can be shared in the `context` between OPs, thus reducing repeated calculations. @@ -332,7 +332,7 @@ to this intermediate variable, indicating that the intermediate variable may be ... @OPERATORS.register_module(OP_NAME) @INTER_WORDS.register_module(OP_NAME) # register this new OP into the registry group -class WordNumFilter(Filter): +class WordsNumFilter(Filter): ... ``` diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md index 6af83c3a9..9046258b1 100644 --- a/docs/DeveloperGuide_ZH.md +++ b/docs/DeveloperGuide_ZH.md @@ -276,9 +276,9 @@ if __name__ == '__main__': ### (可选)使新算子可以进行算子融合 -- 如果我们的新算子中的部分中间变量的计算过程与已有的算子重复,那么可以将其添加到可融合算子中,以在数据处理时利用算子融合进行加速。(如`word_num_filter`与`word_repetition_filter`都需要对输入文本进行分词) +- 如果我们的新算子中的部分中间变量的计算过程与已有的算子重复,那么可以将其添加到可融合算子中,以在数据处理时利用算子融合进行加速。(如`words_num_filter`与`word_repetition_filter`都需要对输入文本进行分词) - 当算子融合(OP Fusion)功能开启时,这些重复的计算过程和中间变量是可以在算子之间的`context`中共享的,从而可以减少重复计算。 -- 可通过如下步骤使包含共有中间变量的算子可进行算子融合(以`word_num_filter`算子为例)。 +- 可通过如下步骤使包含共有中间变量的算子可进行算子融合(以`words_num_filter`算子为例)。 1. (可选)如果新算子中产生了新的中间变量,需要在`utils/constant.py`中的`InterVars`类中添加新的中间变量名称。通常需要在名称前加上`DEFAULT_PREFIX`前缀。 @@ -313,7 +313,7 @@ ALL_INTER_VARS = [INTER_LINES, INTER_WORDS, LOADED_IMAGES] # 并添加到注册 ... @OPERATORS.register_module(OP_NAME) @INTER_WORDS.register_module(OP_NAME) # 将该算子注册到注册组中 -class WordNumFilter(Filter): +class WordsNumFilter(Filter): ... ``` diff --git a/docs/Operators.md b/docs/Operators.md index 016ab97df..db3e9ac50 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -138,7 +138,7 @@ All the specific operators are listed below, each featured with several capabili | video_resolution_filter | Video | - | Keeps samples containing videos with horizontal and vertical resolutions within the specified range | | video_watermark_filter | Video | - | Keeps samples containing videos with predicted watermark probabilities below the threshold | | video_tagging_from_frames_filter | Video | - | Keep samples containing videos with given tags | -| word_num_filter | General | en, zh | Keeps samples with word count within the specified range | +| words_num_filter | General | en, zh | Keeps samples with word count within the specified range | | word_repetition_filter | General | en, zh | Keeps samples with word-level n-gram repetition ratio within the specified range | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 8b1d50a39..e0a1ece87 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -136,7 +136,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | video_resolution_filter | Video | - | 保留包含视频的分辨率(包括横向分辨率和纵向分辨率)在指定范围内的样本 | | video_watermark_filter | Video | - | 保留包含视频有水印的概率在指定阈值之下的样本| | video_tagging_from_frames_filter | Video | - | 保留包含具有给定标签视频的样本 | -| word_num_filter | General | en, zh | 保留字数在指定范围内的样本 | +| words_num_filter | General | en, zh | 保留字数在指定范围内的样本 | | word_repetition_filter | General | en, zh | 保留 word-level n-gram 重复比率在指定范围内的样本 | ## Deduplicator diff --git a/tests/ops/filter/test_token_num_filter.py b/tests/ops/filter/test_token_num_filter.py index 514ce21c3..5ee78bab2 100644 --- a/tests/ops/filter/test_token_num_filter.py +++ b/tests/ops/filter/test_token_num_filter.py @@ -7,7 +7,7 @@ from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class WordNumFilterTest(DataJuicerTestCaseBase): +class TokenNumFilterTest(DataJuicerTestCaseBase): def test_token_num(self): src = [ diff --git a/tests/ops/filter/test_word_num_filter.py b/tests/ops/filter/test_word_num_filter.py index 6a4967f97..6b4522c5e 100644 --- a/tests/ops/filter/test_word_num_filter.py +++ b/tests/ops/filter/test_word_num_filter.py @@ -2,14 +2,14 @@ from datasets import Dataset -from data_juicer.ops.filter.word_num_filter import WordNumFilter +from data_juicer.ops.filter.words_num_filter import WordsNumFilter from data_juicer.utils.constant import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class WordNumFilterTest(DataJuicerTestCaseBase): +class WordsNumFilterTest(DataJuicerTestCaseBase): - def _run_word_num_filter(self, dataset: Dataset, target_list, op): + def _run_words_num_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: # TODO: # this is a temp solution, @@ -41,8 +41,8 @@ def test_case(self): 'text': 'a v s e c s f e f g a a a ' }] dataset = Dataset.from_list(ds_list) - op = WordNumFilter(min_num=5, max_num=15) - self._run_word_num_filter(dataset, tgt_list, op) + op = WordsNumFilter(min_num=5, max_num=15) + self._run_words_num_filter(dataset, tgt_list, op) def test_zh_case(self): @@ -65,11 +65,11 @@ def test_zh_case(self): 'text': '基于前一步结果,在同一个聚类中找出那些过长文档为假正例,暂不进行滤除' }] dataset = Dataset.from_list(ds_list) - op = WordNumFilter(lang='zh', + op = WordsNumFilter(lang='zh', tokenization=True, min_num=10, max_num=25) - self._run_word_num_filter(dataset, tgt_list, op) + self._run_words_num_filter(dataset, tgt_list, op) if __name__ == '__main__':