Skip to content

Commit

Permalink
Dev/fix video filter for bench (#318)
Browse files Browse the repository at this point in the history
* fix video_aspect_ratio_filter

* tensor stats to float

* fix words num filter
  • Loading branch information
BeachWang authored May 24, 2024
1 parent 370e620 commit b5bd283
Show file tree
Hide file tree
Showing 15 changed files with 40 additions and 34 deletions.
7 changes: 4 additions & 3 deletions data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
video_frames_text_similarity_filter, video_motion_score_filter,
video_nsfw_filter, video_ocr_area_ratio_filter,
video_resolution_filter, video_tagging_from_frames_filter,
video_watermark_filter, word_num_filter, word_repetition_filter)
video_watermark_filter, word_repetition_filter,
words_num_filter)
from .alphanumeric_filter import AlphanumericFilter
from .audio_duration_filter import AudioDurationFilter
from .audio_nmf_snr_filter import AudioNMFSNRFilter
Expand Down Expand Up @@ -58,8 +59,8 @@
from .video_resolution_filter import VideoResolutionFilter
from .video_tagging_from_frames_filter import VideoTaggingFromFramesFilter
from .video_watermark_filter import VideoWatermarkFilter
from .word_num_filter import WordNumFilter
from .word_repetition_filter import WordRepetitionFilter
from .words_num_filter import WordsNumFilter

__all__ = [
'ImageTextSimilarityFilter',
Expand Down Expand Up @@ -98,7 +99,7 @@
'SuffixFilter',
'ImageSizeFilter',
'VideoWatermarkFilter',
'WordNumFilter',
'WordsNumFilter',
'ImageFaceRatioFilter',
'FlaggedWordFilter',
'WordRepetitionFilter',
Expand Down
4 changes: 3 additions & 1 deletion data_juicer/ops/filter/image_nsfw_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def compute_stats(self, sample, rank=None, context=False):
inputs = processor(images=images, return_tensors='pt').to(model.device)
outputs = model(**inputs)
logits = outputs.logits
nsfw_scores = [scores[1] for scores in torch.softmax(logits, dim=-1)]
nsfw_scores = [
float(scores[1]) for scores in torch.softmax(logits, dim=-1)
]

sample[Fields.stats][StatsKeys.image_nsfw_score] = nsfw_scores

Expand Down
4 changes: 3 additions & 1 deletion data_juicer/ops/filter/image_watermark_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def compute_stats(self, sample, rank=None, context=False):
inputs = processor(images=images, return_tensors='pt').to(model.device)
outputs = model(**inputs)
logits = outputs.logits
watermark_probs = [probs[1] for probs in torch.softmax(logits, dim=-1)]
watermark_probs = [
float(probs[1]) for probs in torch.softmax(logits, dim=-1)
]

sample[Fields.stats][StatsKeys.image_watermark_prob] = watermark_probs

Expand Down
5 changes: 2 additions & 3 deletions data_juicer/ops/filter/video_aspect_ratio_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,8 @@ def compute_stats(self, sample, context=False):
video_aspect_ratios = {}
for key, video in videos.items():
stream = video.streams.video[0]
video_aspect_ratios[key] = str(
Fraction(stream.codec_context.width,
stream.codec_context.height))
video_aspect_ratios[
key] = stream.codec_context.width / stream.codec_context.height
if not context:
video.close()

Expand Down
7 changes: 4 additions & 3 deletions data_juicer/ops/filter/video_nsfw_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,12 @@ def compute_stats(self, sample, rank=None, context=False):
cur_scores = torch.Tensor(cur_scores)

if self.reduce_mode == 'avg':
nsfw_scores.append(cur_scores.mean())
cur_score = cur_scores.mean()
elif self.reduce_mode == 'max':
nsfw_scores.append(cur_scores.max())
cur_score = cur_scores.max()
else:
nsfw_scores.append(cur_scores.min())
cur_score = cur_scores.min()
nsfw_scores.append(float(cur_score))

sample[Fields.stats][StatsKeys.video_nsfw_score] = nsfw_scores

Expand Down
7 changes: 4 additions & 3 deletions data_juicer/ops/filter/video_watermark_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,12 @@ def compute_stats(self, sample, rank=None, context=False):
cur_probs = torch.Tensor(cur_probs)

if self.reduce_mode == 'avg':
watermark_probs.append(cur_probs.mean())
cur_prob = cur_probs.mean()
elif self.reduce_mode == 'max':
watermark_probs.append(cur_probs.max())
cur_prob = cur_probs.max()
else:
watermark_probs.append(cur_probs.min())
cur_prob = cur_probs.min()
watermark_probs.append(float(cur_prob))

sample[Fields.stats][StatsKeys.video_watermark_prob] = watermark_probs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@OPERATORS.register_module(OP_NAME)
@INTER_WORDS.register_module(OP_NAME)
class WordNumFilter(Filter):
class WordsNumFilter(Filter):
"""Filter to keep samples with total words number within a specific
range."""

Expand Down
4 changes: 2 additions & 2 deletions docs/BadDataExhibition.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ some OPs work well on some datasets but might be useless on others.
| `perplexity_filter` | [LCS-558K](#llava-15-pretrain-dataset), [Books](#books), [ArXiv](#arxiv) |
| `special_characters_filter` | [Wikipedia](#wikipedia), [Books](#books) |
| `text_length_filter` | [Wikipedia](#wikipedia), [ArXiv](#arxiv), [Github Code](#github-code) |
| `word_num_filter` | [Stack Exchange](#stack-exchange) |
| `words_num_filter` | [Stack Exchange](#stack-exchange) |
| `word_repetition_filter` | [MMC4](#mmc4), [Wikipedia](#wikipedia) |

- Everyone from the community is welcome to continue adding examples to this table.
Expand Down Expand Up @@ -298,7 +298,7 @@ The pretraining dataset of LLaVA-1.5

| item | value |
|-----------|------------------------------------------------------------------|
| from OP | `word_num_filter` |
| from OP | `words_num_filter` |
| num_words | 2 |
| text | <details><summary>text</summary> <pre>Q: Append</pre> </details> |
| comments | Texts with too few words might be missing content |
Expand Down
4 changes: 2 additions & 2 deletions docs/BadDataExhibition_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
| `perplexity_filter` | [LCS-558K](#llava-15-pretrain-dataset), [Books](#books), [ArXiv](#arxiv) |
| `special_characters_filter` | [Wikipedia](#wikipedia), [Books](#books) |
| `text_length_filter` | [Wikipedia](#wikipedia), [ArXiv](#arxiv), [Github Code](#github-code) |
| `word_num_filter` | [Stack Exchange](#stack-exchange) |
| `words_num_filter` | [Stack Exchange](#stack-exchange) |
| `word_repetition_filter` | [MMC4](#mmc4), [Wikipedia](#wikipedia) |

- 欢迎大家持续补充这个表格。
Expand Down Expand Up @@ -292,7 +292,7 @@ LLaVA-1.5 的预训练数据集。

| 数据项 ||
|-----------|------------------------------------------------------------------|
| 来源算子 | `word_num_filter` |
| 来源算子 | `words_num_filter` |
| num_words | 2 |
| 文本 | <details><summary>text</summary> <pre>Q: Append</pre> </details> |
| 说明 | 单词数量过少的文本可能内容缺失 |
Expand Down
4 changes: 2 additions & 2 deletions docs/DeveloperGuide.md
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ the corresponding documents, including the following docs:
### (Optional) Make your OP fusible

- If the calculation process of some intermediate variables in the new OP is reused in other existing OPs, this new OP can be
added to the fusible OPs to accelerate the whole data processing with OP fusion technology. (e.g. both the `word_num_filter`
added to the fusible OPs to accelerate the whole data processing with OP fusion technology. (e.g. both the `words_num_filter`
and `word_repetition_filter` need to split the input text into words)
- When opening OP fusion, these reused calculation processes and intermediate variables can be shared in the `context` between
OPs, thus reducing repeated calculations.
Expand Down Expand Up @@ -332,7 +332,7 @@ to this intermediate variable, indicating that the intermediate variable may be
...
@OPERATORS.register_module(OP_NAME)
@INTER_WORDS.register_module(OP_NAME) # register this new OP into the registry group
class WordNumFilter(Filter):
class WordsNumFilter(Filter):
...
```

Expand Down
6 changes: 3 additions & 3 deletions docs/DeveloperGuide_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,9 @@ if __name__ == '__main__':

### (可选)使新算子可以进行算子融合

- 如果我们的新算子中的部分中间变量的计算过程与已有的算子重复,那么可以将其添加到可融合算子中,以在数据处理时利用算子融合进行加速。(如`word_num_filter`与`word_repetition_filter`都需要对输入文本进行分词)
- 如果我们的新算子中的部分中间变量的计算过程与已有的算子重复,那么可以将其添加到可融合算子中,以在数据处理时利用算子融合进行加速。(如`words_num_filter`与`word_repetition_filter`都需要对输入文本进行分词)
- 当算子融合(OP Fusion)功能开启时,这些重复的计算过程和中间变量是可以在算子之间的`context`中共享的,从而可以减少重复计算。
- 可通过如下步骤使包含共有中间变量的算子可进行算子融合(以`word_num_filter`算子为例)。
- 可通过如下步骤使包含共有中间变量的算子可进行算子融合(以`words_num_filter`算子为例)。

1. (可选)如果新算子中产生了新的中间变量,需要在`utils/constant.py`中的`InterVars`类中添加新的中间变量名称。通常需要在名称前加上`DEFAULT_PREFIX`前缀。

Expand Down Expand Up @@ -313,7 +313,7 @@ ALL_INTER_VARS = [INTER_LINES, INTER_WORDS, LOADED_IMAGES] # 并添加到注册
...
@OPERATORS.register_module(OP_NAME)
@INTER_WORDS.register_module(OP_NAME) # 将该算子注册到注册组中
class WordNumFilter(Filter):
class WordsNumFilter(Filter):
...
```

Expand Down
2 changes: 1 addition & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ All the specific operators are listed below, each featured with several capabili
| video_resolution_filter | Video | - | Keeps samples containing videos with horizontal and vertical resolutions within the specified range |
| video_watermark_filter | Video | - | Keeps samples containing videos with predicted watermark probabilities below the threshold |
| video_tagging_from_frames_filter | Video | - | Keeps samples containing videos with given tags |
| word_num_filter | General | en, zh | Keeps samples with word count within the specified range |
| words_num_filter | General | en, zh | Keeps samples with word count within the specified range |
| word_repetition_filter | General | en, zh | Keeps samples with word-level n-gram repetition ratio within the specified range |


Expand Down
2 changes: 1 addition & 1 deletion docs/Operators_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| video_resolution_filter | Video | - | 保留包含视频的分辨率(包括横向分辨率和纵向分辨率)在指定范围内的样本 |
| video_watermark_filter | Video | - | 保留包含视频有水印的概率在指定阈值之下的样本|
| video_tagging_from_frames_filter | Video | - | 保留包含具有给定标签视频的样本 |
| word_num_filter | General | en, zh | 保留字数在指定范围内的样本 |
| words_num_filter | General | en, zh | 保留字数在指定范围内的样本 |
| word_repetition_filter | General | en, zh | 保留 word-level n-gram 重复比率在指定范围内的样本 |

## Deduplicator <a name="deduplicator"/>
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/filter/test_token_num_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


class WordNumFilterTest(DataJuicerTestCaseBase):
class TokenNumFilterTest(DataJuicerTestCaseBase):

def test_token_num(self):
src = [
Expand Down
14 changes: 7 additions & 7 deletions tests/ops/filter/test_word_num_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from datasets import Dataset

from data_juicer.ops.filter.word_num_filter import WordNumFilter
from data_juicer.ops.filter.words_num_filter import WordsNumFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


class WordNumFilterTest(DataJuicerTestCaseBase):
class WordsNumFilterTest(DataJuicerTestCaseBase):

def _run_word_num_filter(self, dataset: Dataset, target_list, op):
def _run_words_num_filter(self, dataset: Dataset, target_list, op):
if Fields.stats not in dataset.features:
# TODO:
# this is a temp solution,
Expand Down Expand Up @@ -41,8 +41,8 @@ def test_case(self):
'text': 'a v s e c s f e f g a a a '
}]
dataset = Dataset.from_list(ds_list)
op = WordNumFilter(min_num=5, max_num=15)
self._run_word_num_filter(dataset, tgt_list, op)
op = WordsNumFilter(min_num=5, max_num=15)
self._run_words_num_filter(dataset, tgt_list, op)

def test_zh_case(self):

Expand All @@ -65,11 +65,11 @@ def test_zh_case(self):
'text': '基于前一步结果,在同一个聚类中找出那些过长文档为假正例,暂不进行滤除'
}]
dataset = Dataset.from_list(ds_list)
op = WordNumFilter(lang='zh',
op = WordsNumFilter(lang='zh',
tokenization=True,
min_num=10,
max_num=25)
self._run_word_num_filter(dataset, tgt_list, op)
self._run_words_num_filter(dataset, tgt_list, op)


if __name__ == '__main__':
Expand Down

0 comments on commit b5bd283

Please sign in to comment.