diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py
index 056da04cd..abce40a5b 100644
--- a/data_juicer/ops/filter/__init__.py
+++ b/data_juicer/ops/filter/__init__.py
@@ -17,7 +17,8 @@
video_frames_text_similarity_filter, video_motion_score_filter,
video_nsfw_filter, video_ocr_area_ratio_filter,
video_resolution_filter, video_tagging_from_frames_filter,
- video_watermark_filter, word_num_filter, word_repetition_filter)
+ video_watermark_filter, word_repetition_filter,
+ words_num_filter)
from .alphanumeric_filter import AlphanumericFilter
from .audio_duration_filter import AudioDurationFilter
from .audio_nmf_snr_filter import AudioNMFSNRFilter
@@ -58,8 +59,8 @@
from .video_resolution_filter import VideoResolutionFilter
from .video_tagging_from_frames_filter import VideoTaggingFromFramesFilter
from .video_watermark_filter import VideoWatermarkFilter
-from .word_num_filter import WordNumFilter
from .word_repetition_filter import WordRepetitionFilter
+from .words_num_filter import WordsNumFilter
__all__ = [
'ImageTextSimilarityFilter',
@@ -98,7 +99,7 @@
'SuffixFilter',
'ImageSizeFilter',
'VideoWatermarkFilter',
- 'WordNumFilter',
+ 'WordsNumFilter',
'ImageFaceRatioFilter',
'FlaggedWordFilter',
'WordRepetitionFilter',
diff --git a/data_juicer/ops/filter/image_nsfw_filter.py b/data_juicer/ops/filter/image_nsfw_filter.py
index 0d3a18b7d..b08990f32 100644
--- a/data_juicer/ops/filter/image_nsfw_filter.py
+++ b/data_juicer/ops/filter/image_nsfw_filter.py
@@ -77,7 +77,9 @@ def compute_stats(self, sample, rank=None, context=False):
inputs = processor(images=images, return_tensors='pt').to(model.device)
outputs = model(**inputs)
logits = outputs.logits
- nsfw_scores = [scores[1] for scores in torch.softmax(logits, dim=-1)]
+ nsfw_scores = [
+ float(scores[1]) for scores in torch.softmax(logits, dim=-1)
+ ]
sample[Fields.stats][StatsKeys.image_nsfw_score] = nsfw_scores
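
Note: the `float(...)` cast above (and the identical change in `image_watermark_filter.py` just below) converts each 0-dim tensor produced by `torch.softmax` into a plain Python float before it lands in the per-sample stats. A minimal self-contained sketch of the pattern, with an invented logits tensor purely for illustration:

```python
import torch

# hypothetical 2-class logits for a batch of three images
logits = torch.tensor([[2.0, -1.0], [0.5, 0.5], [-3.0, 4.0]])

# softmax over the class dimension; index 1 is taken as the positive class
probs = torch.softmax(logits, dim=-1)

# without float(), each entry would be a 0-dim torch.Tensor; the cast keeps
# the stats as plain numbers that serialize cleanly back into the dataset
nsfw_scores = [float(p[1]) for p in probs]
print(nsfw_scores)  # e.g. [0.047..., 0.5, 0.999...]
```
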
diff --git a/data_juicer/ops/filter/image_watermark_filter.py b/data_juicer/ops/filter/image_watermark_filter.py
index 1b7cd98d2..24e69c17f 100644
--- a/data_juicer/ops/filter/image_watermark_filter.py
+++ b/data_juicer/ops/filter/image_watermark_filter.py
@@ -81,7 +81,9 @@ def compute_stats(self, sample, rank=None, context=False):
inputs = processor(images=images, return_tensors='pt').to(model.device)
outputs = model(**inputs)
logits = outputs.logits
- watermark_probs = [probs[1] for probs in torch.softmax(logits, dim=-1)]
+ watermark_probs = [
+ float(probs[1]) for probs in torch.softmax(logits, dim=-1)
+ ]
sample[Fields.stats][StatsKeys.image_watermark_prob] = watermark_probs
diff --git a/data_juicer/ops/filter/video_aspect_ratio_filter.py b/data_juicer/ops/filter/video_aspect_ratio_filter.py
index 4bea08827..8d1e654a2 100644
--- a/data_juicer/ops/filter/video_aspect_ratio_filter.py
+++ b/data_juicer/ops/filter/video_aspect_ratio_filter.py
@@ -64,9 +64,8 @@ def compute_stats(self, sample, context=False):
video_aspect_ratios = {}
for key, video in videos.items():
stream = video.streams.video[0]
- video_aspect_ratios[key] = str(
- Fraction(stream.codec_context.width,
- stream.codec_context.height))
+            video_aspect_ratios[key] = (stream.codec_context.width /
+                                        stream.codec_context.height)
if not context:
video.close()
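
Note: replacing `str(Fraction(width, height))` with plain division changes the stored stat from a string such as `'16/9'` to a numeric ratio such as `1.78`, which is presumably what lets the filter apply ordinary numeric min/max bounds to it. A small standard-library-only sketch of the difference between the two representations:

```python
from fractions import Fraction

width, height = 1920, 1080

as_string = str(Fraction(width, height))  # '16/9' -- compares only lexically
as_number = width / height                # 1.777... -- supports min <= x <= max

print(as_string, as_number)
```
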
diff --git a/data_juicer/ops/filter/video_nsfw_filter.py b/data_juicer/ops/filter/video_nsfw_filter.py
index 27b475c53..f5244bf8b 100644
--- a/data_juicer/ops/filter/video_nsfw_filter.py
+++ b/data_juicer/ops/filter/video_nsfw_filter.py
@@ -144,11 +144,12 @@ def compute_stats(self, sample, rank=None, context=False):
cur_scores = torch.Tensor(cur_scores)
if self.reduce_mode == 'avg':
- nsfw_scores.append(cur_scores.mean())
+ cur_score = cur_scores.mean()
elif self.reduce_mode == 'max':
- nsfw_scores.append(cur_scores.max())
+ cur_score = cur_scores.max()
else:
- nsfw_scores.append(cur_scores.min())
+ cur_score = cur_scores.min()
+ nsfw_scores.append(float(cur_score))
sample[Fields.stats][StatsKeys.video_nsfw_score] = nsfw_scores
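
Note: this refactor (mirrored in `video_watermark_filter.py` below) narrows the branches to just choosing the reduction, so the `float(...)` cast and the `append` happen exactly once. A standalone sketch of the pattern, with an invented per-frame score list:

```python
import torch

cur_scores = torch.Tensor([0.1, 0.7, 0.3])  # hypothetical per-frame scores
reduce_mode = 'avg'                         # one of 'avg', 'max', 'min'

if reduce_mode == 'avg':
    cur_score = cur_scores.mean()
elif reduce_mode == 'max':
    cur_score = cur_scores.max()
else:
    cur_score = cur_scores.min()

# a single cast at the end keeps the stored stat a plain Python float
nsfw_score = float(cur_score)
```
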
diff --git a/data_juicer/ops/filter/video_watermark_filter.py b/data_juicer/ops/filter/video_watermark_filter.py
index 12dce3f4f..a07d12ca3 100644
--- a/data_juicer/ops/filter/video_watermark_filter.py
+++ b/data_juicer/ops/filter/video_watermark_filter.py
@@ -145,11 +145,12 @@ def compute_stats(self, sample, rank=None, context=False):
cur_probs = torch.Tensor(cur_probs)
if self.reduce_mode == 'avg':
- watermark_probs.append(cur_probs.mean())
+ cur_prob = cur_probs.mean()
elif self.reduce_mode == 'max':
- watermark_probs.append(cur_probs.max())
+ cur_prob = cur_probs.max()
else:
- watermark_probs.append(cur_probs.min())
+ cur_prob = cur_probs.min()
+ watermark_probs.append(float(cur_prob))
sample[Fields.stats][StatsKeys.video_watermark_prob] = watermark_probs
diff --git a/data_juicer/ops/filter/word_num_filter.py b/data_juicer/ops/filter/words_num_filter.py
similarity index 98%
rename from data_juicer/ops/filter/word_num_filter.py
rename to data_juicer/ops/filter/words_num_filter.py
index 08ec90ee8..3f6d02d76 100644
--- a/data_juicer/ops/filter/word_num_filter.py
+++ b/data_juicer/ops/filter/words_num_filter.py
@@ -19,7 +19,7 @@
@OPERATORS.register_module(OP_NAME)
@INTER_WORDS.register_module(OP_NAME)
-class WordNumFilter(Filter):
+class WordsNumFilter(Filter):
"""Filter to keep samples with total words number within a specific
range."""
diff --git a/docs/BadDataExhibition.md b/docs/BadDataExhibition.md
index c4483a0bc..720ab25b9 100644
--- a/docs/BadDataExhibition.md
+++ b/docs/BadDataExhibition.md
@@ -39,7 +39,7 @@ some OPs work well on some datasets but might be useless on others.
| `perplexity_filter` | [LCS-558K](#llava-15-pretrain-dataset), [Books](#books), [ArXiv](#arxiv) |
| `special_characters_filter` | [Wikipedia](#wikipedia), [Books](#books) |
| `text_length_filter` | [Wikipedia](#wikipedia), [ArXiv](#arxiv), [Github Code](#github-code) |
-| `word_num_filter` | [Stack Exchange](#stack-exchange) |
+| `words_num_filter` | [Stack Exchange](#stack-exchange) |
| `word_repetition_filter` | [MMC4](#mmc4), [Wikipedia](#wikipedia) |
- Everyone from the community is welcome to continue adding examples to this table.
@@ -298,7 +298,7 @@ The pretraining dataset of LLaVA-1.5
| item | value |
|-----------|------------------------------------------------------------------|
-| from OP | `word_num_filter` |
+| from OP | `words_num_filter` |
| num_words | 2 |
| text      | text<br>Q: Append |
| comments | Texts with too few words might be missing content |
diff --git a/docs/BadDataExhibition_ZH.md b/docs/BadDataExhibition_ZH.md
index 1a73def29..c1021afef 100644
--- a/docs/BadDataExhibition_ZH.md
+++ b/docs/BadDataExhibition_ZH.md
@@ -35,7 +35,7 @@
| `perplexity_filter` | [LCS-558K](#llava-15-pretrain-dataset), [Books](#books), [ArXiv](#arxiv) |
| `special_characters_filter` | [Wikipedia](#wikipedia), [Books](#books) |
| `text_length_filter` | [Wikipedia](#wikipedia), [ArXiv](#arxiv), [Github Code](#github-code) |
-| `word_num_filter` | [Stack Exchange](#stack-exchange) |
+| `words_num_filter` | [Stack Exchange](#stack-exchange) |
| `word_repetition_filter` | [MMC4](#mmc4), [Wikipedia](#wikipedia) |
- 欢迎大家持续补充这个表格。
@@ -292,7 +292,7 @@ LLaVA-1.5 的预训练数据集。
| 数据项 | 值 |
|-----------|------------------------------------------------------------------|
-| 来源算子 | `word_num_filter` |
+| 来源算子 | `words_num_filter` |
| num_words | 2 |
| 文本 | text<br>Q: Append |
| 说明 | 单词数量过少的文本可能内容缺失 |
diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md
index cb53ca327..e887b0358 100644
--- a/docs/DeveloperGuide.md
+++ b/docs/DeveloperGuide.md
@@ -289,7 +289,7 @@ the corresponding documents, including the following docs:
### (Optional) Make your OP fusible
- If the calculation process of some intermediate variables in the new OP is reused in other existing OPs, this new OP can be
-added to the fusible OPs to accelerate the whole data processing with OP fusion technology. (e.g. both the `word_num_filter`
+added to the fusible OPs to accelerate the whole data processing with OP fusion technology. (e.g. both the `words_num_filter`
and `word_repetition_filter` need to split the input text into words)
- When OP fusion is enabled, these reused calculation processes and intermediate variables can be shared in the `context` between
OPs, thus reducing repeated calculations.
@@ -332,7 +332,7 @@ to this intermediate variable, indicating that the intermediate variable may be
...
@OPERATORS.register_module(OP_NAME)
@INTER_WORDS.register_module(OP_NAME) # register this new OP into the registry group
-class WordNumFilter(Filter):
+class WordsNumFilter(Filter):
...
```
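
Note: once an OP is registered into `INTER_WORDS` as shown above, its `compute_stats` can look the shared words list up in the sample's context before re-tokenizing. The following is only a hedged sketch of that lookup-or-compute pattern; the `'__dj__context__'`/`'text'` field names and the whitespace tokenizer are illustrative assumptions, not the exact `words_num_filter` implementation:

```python
from typing import Any, Dict, List

def get_words(sample: Dict[str, Any], context: bool, words_key: str) -> List[str]:
    # '__dj__context__' stands in for Fields.context here (an assumption)
    ctx = sample.setdefault('__dj__context__', {})
    if context and words_key in ctx:
        # another OP in the fused group already split this text into words
        return ctx[words_key]
    words = sample['text'].split()  # stand-in for the real tokenizer
    if context:
        ctx[words_key] = words  # cache the words for later OPs in the group
    return words
```
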
diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md
index 6af83c3a9..9046258b1 100644
--- a/docs/DeveloperGuide_ZH.md
+++ b/docs/DeveloperGuide_ZH.md
@@ -276,9 +276,9 @@ if __name__ == '__main__':
### (可选)使新算子可以进行算子融合
-- 如果我们的新算子中的部分中间变量的计算过程与已有的算子重复,那么可以将其添加到可融合算子中,以在数据处理时利用算子融合进行加速。(如`word_num_filter`与`word_repetition_filter`都需要对输入文本进行分词)
+- 如果我们的新算子中的部分中间变量的计算过程与已有的算子重复,那么可以将其添加到可融合算子中,以在数据处理时利用算子融合进行加速。(如`words_num_filter`与`word_repetition_filter`都需要对输入文本进行分词)
- 当算子融合(OP Fusion)功能开启时,这些重复的计算过程和中间变量是可以在算子之间的`context`中共享的,从而可以减少重复计算。
-- 可通过如下步骤使包含共有中间变量的算子可进行算子融合(以`word_num_filter`算子为例)。
+- 可通过如下步骤使包含共有中间变量的算子可进行算子融合(以`words_num_filter`算子为例)。
1. (可选)如果新算子中产生了新的中间变量,需要在`utils/constant.py`中的`InterVars`类中添加新的中间变量名称。通常需要在名称前加上`DEFAULT_PREFIX`前缀。
@@ -313,7 +313,7 @@ ALL_INTER_VARS = [INTER_LINES, INTER_WORDS, LOADED_IMAGES] # 并添加到注册
...
@OPERATORS.register_module(OP_NAME)
@INTER_WORDS.register_module(OP_NAME) # 将该算子注册到注册组中
-class WordNumFilter(Filter):
+class WordsNumFilter(Filter):
...
```
diff --git a/docs/Operators.md b/docs/Operators.md
index 016ab97df..db3e9ac50 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -138,7 +138,7 @@ All the specific operators are listed below, each featured with several capabili
| video_resolution_filter | Video | - | Keeps samples containing videos with horizontal and vertical resolutions within the specified range |
| video_watermark_filter | Video | - | Keeps samples containing videos with predicted watermark probabilities below the threshold |
| video_tagging_from_frames_filter | Video | - | Keeps samples containing videos with the given tags |
-| word_num_filter | General | en, zh | Keeps samples with word count within the specified range |
+| words_num_filter | General | en, zh | Keeps samples with word count within the specified range |
| word_repetition_filter | General | en, zh | Keeps samples with word-level n-gram repetition ratio within the specified range |
diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md
index 8b1d50a39..e0a1ece87 100644
--- a/docs/Operators_ZH.md
+++ b/docs/Operators_ZH.md
@@ -136,7 +136,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| video_resolution_filter | Video | - | 保留包含视频的分辨率(包括横向分辨率和纵向分辨率)在指定范围内的样本 |
| video_watermark_filter | Video | - | 保留包含视频有水印的概率在指定阈值之下的样本|
| video_tagging_from_frames_filter | Video | - | 保留包含具有给定标签视频的样本 |
-| word_num_filter | General | en, zh | 保留字数在指定范围内的样本 |
+| words_num_filter | General | en, zh | 保留字数在指定范围内的样本 |
| word_repetition_filter | General | en, zh | 保留 word-level n-gram 重复比率在指定范围内的样本 |
## Deduplicator
diff --git a/tests/ops/filter/test_token_num_filter.py b/tests/ops/filter/test_token_num_filter.py
index 514ce21c3..5ee78bab2 100644
--- a/tests/ops/filter/test_token_num_filter.py
+++ b/tests/ops/filter/test_token_num_filter.py
@@ -7,7 +7,7 @@
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
-class WordNumFilterTest(DataJuicerTestCaseBase):
+class TokenNumFilterTest(DataJuicerTestCaseBase):
def test_token_num(self):
src = [
diff --git a/tests/ops/filter/test_word_num_filter.py b/tests/ops/filter/test_word_num_filter.py
index 6a4967f97..6b4522c5e 100644
--- a/tests/ops/filter/test_word_num_filter.py
+++ b/tests/ops/filter/test_word_num_filter.py
@@ -2,14 +2,14 @@
from datasets import Dataset
-from data_juicer.ops.filter.word_num_filter import WordNumFilter
+from data_juicer.ops.filter.words_num_filter import WordsNumFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
-class WordNumFilterTest(DataJuicerTestCaseBase):
+class WordsNumFilterTest(DataJuicerTestCaseBase):
- def _run_word_num_filter(self, dataset: Dataset, target_list, op):
+ def _run_words_num_filter(self, dataset: Dataset, target_list, op):
if Fields.stats not in dataset.features:
# TODO:
# this is a temp solution,
@@ -41,8 +41,8 @@ def test_case(self):
'text': 'a v s e c s f e f g a a a '
}]
dataset = Dataset.from_list(ds_list)
- op = WordNumFilter(min_num=5, max_num=15)
- self._run_word_num_filter(dataset, tgt_list, op)
+ op = WordsNumFilter(min_num=5, max_num=15)
+ self._run_words_num_filter(dataset, tgt_list, op)
def test_zh_case(self):
@@ -65,11 +65,11 @@ def test_zh_case(self):
'text': '基于前一步结果,在同一个聚类中找出那些过长文档为假正例,暂不进行滤除'
}]
dataset = Dataset.from_list(ds_list)
-        op = WordNumFilter(lang='zh',
-                           tokenization=True,
-                           min_num=10,
-                           max_num=25)
-        self._run_word_num_filter(dataset, tgt_list, op)
+        op = WordsNumFilter(lang='zh',
+                            tokenization=True,
+                            min_num=10,
+                            max_num=25)
+        self._run_words_num_filter(dataset, tgt_list, op)
if __name__ == '__main__':