Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ready] Add image_segment_mapper #394

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ process:
lang: en # sample in which language
tokenization: false # whether to use model to tokenize documents
substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove
- segment_mapper: # perform segment-anything on images and return the bounding box values.
fastsam_path: './FastSAM-x.pt' # model name of the FastSAM model on ultralytics
imgsz: 1024 # image resolution after image resizing
conf: 0.05 # confidence score threshold
iou: 0.5 # IoU (Intersection over Union) score threshold
- sentence_split_mapper: # split text to multiple sentences and join them with '\n'
lang: 'en' # split text in what language
- video_captioning_from_audio_mapper: # caption a video according to its audio streams based on Qwen-Audio model
Expand Down
4 changes: 3 additions & 1 deletion data_juicer/ops/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
remove_repeat_sentences_mapper, remove_specific_chars_mapper,
remove_table_text_mapper,
remove_words_with_incorrect_substrings_mapper,
replace_content_mapper, sentence_split_mapper,
replace_content_mapper, segment_mapper, sentence_split_mapper,
video_captioning_from_audio_mapper,
video_captioning_from_frames_mapper,
video_captioning_from_summarizer_mapper,
Expand Down Expand Up @@ -54,6 +54,7 @@
from .remove_words_with_incorrect_substrings_mapper import \
RemoveWordsWithIncorrectSubstringsMapper
from .replace_content_mapper import ReplaceContentMapper
from .segment_mapper import SegmentMapper
from .sentence_split_mapper import SentenceSplitMapper
from .video_captioning_from_audio_mapper import VideoCaptioningFromAudioMapper
from .video_captioning_from_frames_mapper import \
Expand Down Expand Up @@ -118,6 +119,7 @@
'AudioFFmpegWrappedMapper',
'VideoSplitByDurationMapper',
'VideoFaceBlurMapper',
'SegmentMapper'
]

# yapf: enable
88 changes: 88 additions & 0 deletions data_juicer/ops/mapper/segment_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import copy

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.ops.op_fusion import LOADED_IMAGES
from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.mm_utils import load_image
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'segment_mapper'

with AvailabilityChecking(['torch', 'transformers', 'simhash-pybind'],
OP_NAME):
import simhash # noqa: F401
import torch
import transformers # noqa: F401
Qirui-jiao marked this conversation as resolved.
Show resolved Hide resolved

# avoid hanging when calling model in multiprocessing
torch.set_num_threads(1)


@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class SegmentMapper(Mapper):
    """Perform segment-anything on images and return the bounding boxes.

    Every image referenced by a sample is run through a FastSAM model and
    the boxes of the predicted segments are stored in the sample under the
    ``'bboxes'`` key.
    """

    _accelerator = 'cuda'
    # NOTE(review): ``process`` below handles one sample dict at a time --
    # confirm that flagging this op as batched is really intended.
    _batched_op = True

    def __init__(self,
                 fastsam_path='./FastSAM-x.pt',
                 imgsz=1024,
                 conf=0.05,
                 iou=0.5,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param fastsam_path: location of FastSAM
        :param imgsz: image resolution after image resizing
        :param conf: confidence score threshold
        :param iou: IoU (Intersection over Union) score threshold
        :param args: extra positional args forwarded to ``Mapper``
        :param kwargs: extra keyword args forwarded to ``Mapper``
        """
        super().__init__(*args, **kwargs)

        self.model_key = prepare_model(
            model_type='fastsam', pretrained_model_name_or_path=fastsam_path)

        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou

    def process(self, ori_sample, rank=None):
        """
        Segment every image of a sample and record the bounding boxes.

        :param ori_sample: the sample to process; its image paths are read
            from ``ori_sample[self.image_key]``
        :param rank: forwarded to ``get_model`` when fetching the model
        :return: a deep copy of ``ori_sample`` with an extra ``'bboxes'``
            field holding one entry per image entry (same order), each a
            plain nested list taken from the model's ``boxes.xyxy``. The
            field is an empty list when the sample has no image.
        """
        generated_samples = copy.deepcopy(ori_sample)
        generated_samples['bboxes'] = []

        # there is no image in this sample: hand the sample back (with an
        # empty result field) rather than an empty list, so that every
        # processed sample exposes the same keys
        if self.image_key not in ori_sample or \
                not ori_sample[self.image_key]:
            return generated_samples

        loaded_image_keys = ori_sample[self.image_key]
        images = {}
        for loaded_image_key in loaded_image_keys:
            if loaded_image_key not in images:
                # avoid loading the same images
                images[loaded_image_key] = load_image(loaded_image_key)

        model = get_model(self.model_key, rank=rank, use_cuda=self.use_cuda())

        # run the model once per distinct image, on the image that was
        # actually loaded above
        # NOTE(review): assumes the object returned by ``load_image`` is an
        # input type the FastSAM model accepts -- confirm.
        boxes_by_key = {}
        for loaded_image_key, image in images.items():
            masks = model([image],
                          retina_masks=True,
                          imgsz=self.imgsz,
                          conf=self.conf,
                          iou=self.iou,
                          verbose=False)[0]
            # ``tolist`` yields plain nested lists ([] when nothing was
            # detected), so empty and non-empty results share one type
            boxes_by_key[loaded_image_key] = masks.boxes.xyxy.tolist()

        # keep 'bboxes' aligned with the image list, duplicates included;
        # each entry gets its own copy so entries never share a list object
        generated_samples['bboxes'] = [
            copy.deepcopy(boxes_by_key[loaded_image_key])
            for loaded_image_key in loaded_image_keys
        ]

        return generated_samples
9 changes: 8 additions & 1 deletion data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,12 @@ def prepare_recognizeAnything_model(
return model


def prepare_fastsam_model(pretrained_model_name_or_path):
    """
    Prepare and load a FastSAM model through ultralytics.

    :param pretrained_model_name_or_path: value handed unchanged to the
        ``FastSAM`` constructor
    :return: the constructed ``FastSAM`` object
    """
    # imported inside the function so that ultralytics is only needed when
    # this model type is actually requested
    from ultralytics import FastSAM

    model = FastSAM(pretrained_model_name_or_path)
    return model
Qirui-jiao marked this conversation as resolved.
Show resolved Hide resolved


def prepare_opencv_classifier(model_path):
import cv2
model = cv2.CascadeClassifier(model_path)
Expand All @@ -570,7 +576,8 @@ def prepare_opencv_classifier(model_path):
'diffusion': prepare_diffusion_model,
'video_blip': prepare_video_blip_model,
'recognizeAnything': prepare_recognizeAnything_model,
'opencv_classifier': prepare_opencv_classifier,
'fastsam': prepare_fastsam_model,
'opencv_classifier': prepare_opencv_classifier
}


Expand Down
3 changes: 2 additions & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
| Type | Number | Description |
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 43 | Edits and transforms samples |
| [ Mapper ]( #mapper ) | 44 | Edits and transforms samples |
| [ Filter ]( #filter ) | 41 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |
Expand Down Expand Up @@ -77,6 +77,7 @@ All the specific operators are listed below, each featured with several capabili
| remove_table_text_mapper | General, Financial | en | Detects and removes possible table contents (:warning: relies on regular expression matching and thus fragile)|
| remove_words_with_incorrect_<br />substrings_mapper | General | en, zh | Removes words containing specified substrings |
| replace_content_mapper | General | en, zh | Replace all content in the text that matches a specific regular expression pattern with a designated replacement string |
| segment_mapper | Image | - | Perform segment-anything on images and return the bounding box values. |
| sentence_split_mapper | General | en | Splits and reorganizes sentences according to semantics |
| video_captioning_from_audio_mapper | Multimodal | - | Caption a video according to its audio streams based on Qwen-Audio model |
| video_captioning_from_frames_mapper | Multimodal | - | generate samples whose captions are generated based on an image-to-text model and sampled video frames. Captions from different frames will be concatenated to a single string |
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| 类型 | 数量 | 描述 |
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 43 | 对数据样本进行编辑和转换 |
| [ Mapper ]( #mapper ) | 44 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 41 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 |
Expand Down Expand Up @@ -76,6 +76,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| remove_table_text_mapper | General, Financial | en | 检测并删除可能的表格内容(:warning: 依赖正则表达式匹配,因此很脆弱) |
| remove_words_with_incorrect_<br />substrings_mapper | General | en, zh | 删除包含指定子字符串的单词 |
| replace_content_mapper | General | en, zh | 使用一个指定的替换字符串替换文本中满足特定正则表达式模版的所有内容 |
| segment_mapper | Image | - | 对图像实施“分割万物”(segment-anything)的语义分割,并返回bounding box数值 |
| sentence_split_mapper | General | en | 根据语义拆分和重组句子 |
| video_captioning_from_audio_mapper | Multimodal | - | 基于 Qwen-Audio 模型根据视频的音频流为视频生成新的标题描述 |
| video_captioning_from_frames_mapper | Multimodal | - | 生成样本,其标题是基于一个文字生成图片的模型和原始样本视频中指定帧的图像。不同帧产出的标题会拼接为一条单独的字符串。 |
Expand Down
46 changes: 46 additions & 0 deletions tests/ops/mapper/test_segment_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
import unittest

from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.utils.mm_utils import SpecialTokens
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
DataJuicerTestCaseBase)

from data_juicer.ops.mapper.segment_mapper import SegmentMapper
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
DataJuicerTestCaseBase)



class SDXLPrompt2PromptMapperTest(DataJuicerTestCaseBase):
    """Smoke test for ``SegmentMapper``.

    NOTE(review): the class name looks copied from another operator's test;
    it is kept unchanged here -- consider renaming it after the op under
    test.
    """

    text_key = 'text'

    def _run_segment_mapper(self):
        """Run the op on two samples and check the shape of its output."""
        op = SegmentMapper(fastsam_path='FastSAM-x.pt')

        # NOTE(review): these paths are relative to the current working
        # directory -- confirm the images are available where the test runs.
        img1_path = './crayon.jpg'
        img2_path = './ipod.jpg'
        img3_path = './0_19_0_0.jpg'

        ds_list = [{
            'images': [img1_path, img3_path]
        }, {
            'images': [img2_path]
        }]

        for sample in ds_list:
            result = op.process(sample)
            print(f'Output results: {result}')
            # every image of the sample must get its own bounding-box entry
            self.assertIn('bboxes', result)
            self.assertEqual(len(result['bboxes']), len(sample['images']))

    def test_segment_mapper(self):
        self._run_segment_mapper()



# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
Loading