demo refactor for clip
drcege committed Jan 2, 2024
1 parent b8dadc7 commit 1cd59f0
Showing 4 changed files with 67 additions and 7 deletions.
12 changes: 7 additions & 5 deletions data_juicer/ops/filter/image_text_similarity_filter.py
@@ -6,7 +6,9 @@
 from data_juicer.utils.constant import Fields, StatsKeys
 from data_juicer.utils.mm_utils import (SpecialTokens, load_image,
                                         remove_special_tokens)
-from data_juicer.utils.model_utils import get_model, prepare_model
+from data_juicer.utils.model_utils import _get_model as get_model
+from data_juicer.utils.model_utils import _prepare_model as prepare_model
+from data_juicer.utils.model_utils import huggingface_clip
 
 from ..base_op import OPERATORS, Filter
 from ..op_fusion import LOADED_IMAGES
@@ -69,12 +71,12 @@ def __init__(self,
             raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                              f'Can only be one of ["any", "all"].')
         self.any = (any_or_all == 'any')
-        self.model_key = prepare_model(model_type='hf_clip', model_key=hf_clip)
+        self.model_key = prepare_model(huggingface_clip, clip_name=hf_clip)
         self.reduce_mode = reduce_mode
         self.horizontal_flip = horizontal_flip
         self.vertical_flip = vertical_flip
 
-    def compute_stats(self, sample, context=False):
+    def compute_stats(self, sample, rank, context=False):
         # check if it's computed already
         if StatsKeys.image_text_similarity in sample[Fields.stats]:
             return sample
@@ -105,7 +107,7 @@ def compute_stats(self, sample, context=False):
         text = sample[self.text_key]
         offset = 0
         similarity = []
-        model, processor = get_model(self.model_key)
+        model, processor = get_model(self.model_key, rank)
 
         for chunk in text.split(SpecialTokens.eoc):
             count = chunk.count(SpecialTokens.image)
@@ -130,7 +132,7 @@ def compute_stats(self, sample, context=False):
                     truncation=True,
                     max_length=model.config.text_config.
                     max_position_embeddings,
-                    padding=True)
+                    padding=True).to(model.device)
 
                 outputs = model(**inputs)
                 chunk_logits = outputs.logits_per_text.detach().cpu() / 100.0
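
A minimal single-sample sketch of how the refactored filter is driven after this change (the CLIP checkpoint name and image path are illustrative placeholders, not part of the commit):

from data_juicer.ops.filter.image_text_similarity_filter import \
    ImageTextSimilarityFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens

sample = {
    'text': f'{SpecialTokens.image} a photo of a cat {SpecialTokens.eoc}',
    'images': ['./cat.jpg'],  # placeholder path
    Fields.stats: {},
}

op = ImageTextSimilarityFilter(hf_clip='openai/clip-vit-base-patch32')
# `rank` is forwarded to _get_model(); when CUDA is available the CLIP model
# lands on cuda:<rank>, and the processor outputs now follow it via
# `.to(model.device)`.
sample = op.compute_stats(sample, rank=0)
keep_sample = op.process(sample)
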
23 changes: 23 additions & 0 deletions data_juicer/utils/__init__.py
@@ -0,0 +1,23 @@
+import multiprocess as mp
+import torch
+from loguru import logger
+
+
+def set_mp_start_method(method=None):
+    if torch.cuda.is_available():
+        desired_method = 'spawn'
+    else:
+        desired_method = method or mp.get_start_method(
+            allow_none=True) or 'fork'
+
+    try:
+        mp.set_start_method(desired_method, force=True)
+        logger.info(
+            f"Setting multiprocess start method to '{desired_method}'.")
+    except RuntimeError as e:
+        logger.error(f'Error setting multiprocess start method: {e}')
+
+
+def initialize(**kw_args):
+    mp_start_method = kw_args.pop('mp_start_method', None)
+    set_mp_start_method(mp_start_method)
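
The new helper is meant to run once in the driver process, before any dataset workers exist. A minimal sketch, assuming you want to force 'spawn' on a CPU-only machine (on CUDA hosts set_mp_start_method already picks 'spawn' on its own):

from data_juicer.utils import initialize

# 'mp_start_method' is popped by initialize(); any other keyword is ignored here.
initialize(mp_start_method='spawn')
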
31 changes: 31 additions & 0 deletions data_juicer/utils/model_utils.py
@@ -1,5 +1,8 @@
 import os
+from functools import partial
 
+import multiprocess as mp
+import torch
 import wget
 from loguru import logger
 
@@ -186,6 +189,9 @@ def prepare_huggingface_clip(clip_name):
     return (model, processor)
 
 
+huggingface_clip = prepare_huggingface_clip
+
+
 def prepare_huggingface_blip(blip_name):
     """
     Prepare and load a blip and processor from HuggingFace.
@@ -290,3 +296,28 @@ def get_model(model_key, lang='en', model_type='sentencepiece'):
     if model_key not in MODEL_ZOO:
         prepare_model(lang=lang, model_type=model_type, model_key=model_key)
     return MODEL_ZOO.get(model_key, None)
+
+
+def _prepare_model(model_func, **model_kwargs):
+    global MODEL_ZOO
+    func = partial(model_func, **model_kwargs)
+    if mp.get_start_method() == 'fork':
+        model = func()
+        MODEL_ZOO[func] = model
+    return func
+
+
+def _get_model(model_key, rank=-1):
+    global MODEL_ZOO
+    if model_key not in MODEL_ZOO:
+        MODEL_ZOO[model_key] = model_key()
+    if torch.cuda.is_available() and isinstance(rank, int):
+        for index, module in enumerate(MODEL_ZOO[model_key]):
+            if callable(getattr(module, 'to', None)):
+                logger.info(f'Move {module.__class__} to cuda:{rank}')
+                module.to(f'cuda:{rank}')
+                ref_module = MODEL_ZOO[model_key][index]
+                logger.debug(
+                    f'{ref_module.__class__} in device {ref_module.device}'
+                )
+    return MODEL_ZOO[model_key]
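
Under the new helpers the MODEL_ZOO key is itself a zero-argument callable (a functools.partial), so each worker can lazily build its own model copy and pin it to its GPU. A sketch of the intended call pattern (the checkpoint name is illustrative):

from data_juicer.utils.model_utils import (_get_model, _prepare_model,
                                           huggingface_clip)

# Driver process: with a 'spawn' start method this only records the partial;
# with 'fork' the model is built eagerly and inherited by the workers.
clip_key = _prepare_model(huggingface_clip,
                          clip_name='openai/clip-vit-base-patch32')

# Worker process with rank 0: builds (model, processor) on first use and, if
# CUDA is available, moves every module that has .to() onto cuda:0.
model, processor = _get_model(clip_key, rank=0)
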
8 changes: 6 additions & 2 deletions tests/ops/filter/test_image_text_similarity_filter.py
@@ -1,9 +1,12 @@
 import os
 import unittest
 
+import torch
 from datasets import Dataset
 
-from data_juicer.ops.filter.image_text_similarity_filter import ImageTextSimilarityFilter
+from data_juicer.ops.filter.image_text_similarity_filter import \
+    ImageTextSimilarityFilter
+from data_juicer.utils import initialize
 from data_juicer.utils.constant import Fields
 from data_juicer.utils.mm_utils import SpecialTokens

@@ -26,7 +29,7 @@ def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1):
         dataset = dataset.add_column(name=Fields.stats,
                                      column=[{}] * dataset.num_rows)
 
-        dataset = dataset.map(op.compute_stats, num_proc=num_proc)
+        dataset = dataset.map(op.compute_stats, num_proc=num_proc, with_rank=True)
         dataset = dataset.filter(op.process, num_proc=num_proc)
         dataset = dataset.select_columns(column_names=['text', 'images'])
         res_list = dataset.to_list()
@@ -274,4 +277,5 @@ def test_multi_process(self):
 
 
 if __name__ == '__main__':
+    initialize()
     unittest.main()
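
Putting the pieces together, a multi-process run would look roughly like the sketch below; the two workers and the placeholder image paths are illustrative assumptions, while the actual coverage lives in the test above:

from datasets import Dataset

from data_juicer.ops.filter.image_text_similarity_filter import \
    ImageTextSimilarityFilter
from data_juicer.utils import initialize
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens

if __name__ == '__main__':
    initialize()  # on CUDA hosts this switches the start method to 'spawn'

    ds = Dataset.from_list([
        {'text': f'{SpecialTokens.image} a cat {SpecialTokens.eoc}',
         'images': ['./cat.jpg']},
        {'text': f'{SpecialTokens.image} a dog {SpecialTokens.eoc}',
         'images': ['./dog.jpg']},
    ])
    ds = ds.add_column(name=Fields.stats, column=[{}] * ds.num_rows)

    op = ImageTextSimilarityFilter(hf_clip='openai/clip-vit-base-patch32')
    # with_rank=True hands each worker its rank, which compute_stats uses to
    # place its CLIP copy on cuda:<rank>.
    ds = ds.map(op.compute_stats, num_proc=2, with_rank=True)
    ds = ds.filter(op.process, num_proc=2)
    print(ds.to_list())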
