Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Qirui-jiao committed Sep 2, 2024
1 parent 786c856 commit 1f62278
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 13 deletions.
2 changes: 2 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ process:
hf_diffusion: 'stabilityai/stable-diffusion-xl-base-1.0' # model name of the SDXL model on huggingface
num_inference_steps: 50 # the larger the value, the better the image generation quality
guidance_scale: 7.5 # a higher guidance scale value encourages the model to generate images closely linked to the text prompt at the expense of lower image quality
text_key_second: None # used to store the first caption in the caption pair
text_key_third: None # used to store the second caption in the caption pair
- sentence_split_mapper: # split text to multiple sentences and join them with '\n'
lang: 'en' # split text in what language
- video_captioning_from_audio_mapper: # caption a video according to its audio streams based on Qwen-Audio model
Expand Down
33 changes: 26 additions & 7 deletions data_juicer/ops/mapper/sdxl_prompt2prompt_mapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
Expand All @@ -14,11 +15,15 @@
from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.model_utils import get_model, prepare_model

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

OP_NAME = 'sdxl_prompt2prompt_mapper'

check_list = ['diffusers', 'torch', 'transformers', 'simhash-pybind']
check_list = ['diffusers', 'torch', 'transformers']
with AvailabilityChecking(check_list, OP_NAME):
import diffusers # noqa: F401
import transformers # noqa: F401

# avoid hanging when calling stable diffusion in multiprocessing
torch.set_num_threads(1)
Expand All @@ -40,6 +45,8 @@ def __init__(
torch_dtype: str = 'fp32',
num_inference_steps: float = 50,
guidance_scale: float = 7.5,
text_key_second=None,
text_key_third=None,
*args,
**kwargs):
"""
Expand All @@ -55,6 +62,10 @@ def __init__(
:param guidance_scale: A higher guidance scale value encourages the
model to generate images closely linked to the text prompt at the
expense of lower image quality. Guidance scale is enabled when
:param text_key_second: used to store the first caption
in the caption pair.
:param text_key_third: used to store the second caption
in the caption pair.
"""
super().__init__(*args, **kwargs)
Expand All @@ -68,14 +79,20 @@ def __init__(
pretrained_model_name_or_path=hf_diffusion,
pipe_func=Prompt2PromptPipeline,
torch_dtype=torch_dtype)
self.new_sample_key = ['caption1', 'caption2']
self.text_key_second = text_key_second
self.text_key_third = text_key_third

def process(self, sample, rank=None, context=False):

for temp_new_key in self.new_sample_key:
if temp_new_key not in sample:
raise ValueError(
f'Key \'{temp_new_key}\' is not found in sample. ')
if self.text_key_second is None:
logger.error('This OP (sdxl_prompt2prompt_mapper) requires \
processing multiple fields, and you need to specify \
valid `text_key_second`')

if self.text_key_third is None:
logger.error('This OP (sdxl_prompt2prompt_mapper) requires \
processing multiple fields, and you need to specify \
valid `text_key_third`')

model = get_model(model_key=self.model_key,
rank=rank,
Expand All @@ -96,7 +113,9 @@ def process(self, sample, rank=None, context=False):
sample['images'] = []

with torch.no_grad():
prompts = [sample['caption1'], sample['caption2']]
prompts = [
sample[self.text_key_second], sample[self.text_key_third]
]
image = model(prompts,
cross_attention_kwargs=cross_attention_kwargs,
guidance_scale=self.guidance_scale,
Expand Down
16 changes: 10 additions & 6 deletions tests/ops/mapper/test_sdxl_prompt2prompt_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,22 @@ class SDXLPrompt2PromptMapperTest(DataJuicerTestCaseBase):

text_key = 'text'

text_key_second = "caption1"
text_key_third = "caption2"

def _run_sdxl_prompt2prompt(self, enable_vllm=False):
op = SDXLPrompt2PromptMapper(
hf_diffusion='stabilityai/stable-diffusion-xl-base-1.0',
torch_dtype="fp16"
torch_dtype="fp16",
text_key_second=self.text_key_second,
text_key_third=self.text_key_third
)


ds_list = [{"caption1": "a chocolate cake",
"caption2": "a confetti apple cake"},
{"caption1": "a chocolate",
"caption2": "bread"}]
ds_list = [{self.text_key_second: "a chocolate cake",
self.text_key_third: "a confetti apple bread"},
{self.text_key_second: "a chocolate",
self.text_key_third: "bread"}]

dataset = Dataset.from_list(ds_list)
dataset = dataset.map(op.process, num_proc=2, with_rank=True)
Expand All @@ -36,6 +41,5 @@ def test_sdxl_prompt2prompt(self):
self._run_sdxl_prompt2prompt()



if __name__ == '__main__':
unittest.main()

0 comments on commit 1f62278

Please sign in to comment.