diff --git a/configs/config_all.yaml b/configs/config_all.yaml index e090aa99d..feeeecffc 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -183,6 +183,8 @@ process: hf_diffusion: 'stabilityai/stable-diffusion-xl-base-1.0' # model name of the SDXL model on huggingface num_inference_steps: 50 # the larger the value, the better the image generation quality guidance_scale: 7.5 # a higher guidance scale value encourages the model to generate images closely linked to the text prompt at the expense of lower image quality + text_key_second: None # used to store the first caption in the caption pair + text_key_third: None # used to store the second caption in the caption pair - sentence_split_mapper: # split text to multiple sentences and join them with '\n' lang: 'en' # split text in what language - video_captioning_from_audio_mapper: # caption a video according to its audio streams based on Qwen-Audio model diff --git a/data_juicer/ops/mapper/sdxl_prompt2prompt_mapper.py b/data_juicer/ops/mapper/sdxl_prompt2prompt_mapper.py index 8244a6de0..091fd323a 100644 --- a/data_juicer/ops/mapper/sdxl_prompt2prompt_mapper.py +++ b/data_juicer/ops/mapper/sdxl_prompt2prompt_mapper.py @@ -1,4 +1,5 @@ import abc +import logging from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -14,11 +15,15 @@ from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.model_utils import get_model, prepare_model +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + OP_NAME = 'sdxl_prompt2prompt_mapper' -check_list = ['diffusers', 'torch', 'transformers', 'simhash-pybind'] +check_list = ['diffusers', 'torch', 'transformers'] with AvailabilityChecking(check_list, OP_NAME): import diffusers # noqa: F401 + import transformers # noqa: F401 # avoid hanging when calling stable diffusion in multiprocessing torch.set_num_threads(1) @@ -40,6 +45,8 @@ def __init__( torch_dtype: str = 'fp32', num_inference_steps: float = 50, guidance_scale: float = 7.5, + text_key_second=None, + text_key_third=None, *args, **kwargs): """ @@ -55,6 +62,10 @@ def __init__( :param guidance_scale: A higher guidance scale value encourages the model to generate images closely linked to the text prompt at the expense of lower image quality. Guidance scale is enabled when + :param text_key_second: used to store the first caption + in the caption pair. + :param text_key_third: used to store the second caption + in the caption pair. """ super().__init__(*args, **kwargs) @@ -68,14 +79,20 @@ def __init__( pretrained_model_name_or_path=hf_diffusion, pipe_func=Prompt2PromptPipeline, torch_dtype=torch_dtype) - self.new_sample_key = ['caption1', 'caption2'] + self.text_key_second = text_key_second + self.text_key_third = text_key_third def process(self, sample, rank=None, context=False): - for temp_new_key in self.new_sample_key: - if temp_new_key not in sample: - raise ValueError( - f'Key \'{temp_new_key}\' is not found in sample. ') + if self.text_key_second is None: + logger.error('This OP (sdxl_prompt2prompt_mapper) requires \ + processing multiple fields, and you need to specify \ + valid `text_key_second`') + + if self.text_key_third is None: + logger.error('This OP (sdxl_prompt2prompt_mapper) requires \ + processing multiple fields, and you need to specify \ + valid `text_key_third`') model = get_model(model_key=self.model_key, rank=rank, @@ -96,7 +113,9 @@ def process(self, sample, rank=None, context=False): sample['images'] = [] with torch.no_grad(): - prompts = [sample['caption1'], sample['caption2']] + prompts = [ + sample[self.text_key_second], sample[self.text_key_third] + ] image = model(prompts, cross_attention_kwargs=cross_attention_kwargs, guidance_scale=self.guidance_scale, diff --git a/tests/ops/mapper/test_sdxl_prompt2prompt_mapper.py b/tests/ops/mapper/test_sdxl_prompt2prompt_mapper.py index 4410f361c..73d0f4e77 100644 --- a/tests/ops/mapper/test_sdxl_prompt2prompt_mapper.py +++ b/tests/ops/mapper/test_sdxl_prompt2prompt_mapper.py @@ -10,17 +10,22 @@ class SDXLPrompt2PromptMapperTest(DataJuicerTestCaseBase): text_key = 'text' + text_key_second = "caption1" + text_key_third = "caption2" + def _run_sdxl_prompt2prompt(self, enable_vllm=False): op = SDXLPrompt2PromptMapper( hf_diffusion='stabilityai/stable-diffusion-xl-base-1.0', - torch_dtype="fp16" + torch_dtype="fp16", + text_key_second=self.text_key_second, + text_key_third=self.text_key_third ) - ds_list = [{"caption1": "a chocolate cake", - "caption2": "a confetti apple cake"}, - {"caption1": "a chocolate", - "caption2": "bread"}] + ds_list = [{self.text_key_second: "a chocolate cake", + self.text_key_third: "a confetti apple bread"}, + {self.text_key_second: "a chocolate", + self.text_key_third: "bread"}] dataset = Dataset.from_list(ds_list) dataset = dataset.map(op.process, num_proc=2, with_rank=True) @@ -36,6 +41,5 @@ def test_sdxl_prompt2prompt(self): self._run_sdxl_prompt2prompt() - if __name__ == '__main__': unittest.main() \ No newline at end of file