diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index f7d9dde5258d..ff687db5c18b 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -115,7 +115,7 @@ jobs:
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
- examples/test_examples.py
+ examples
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml
index 798fa777c6c6..6ea873d0a79c 100644
--- a/.github/workflows/push_tests_fast.yml
+++ b/.github/workflows/push_tests_fast.yml
@@ -100,7 +100,7 @@ jobs:
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
- examples/test_examples.py
+ examples
- name: Failure short reports
if: ${{ failure() }}
diff --git a/PHILOSOPHY.md b/PHILOSOPHY.md
index 38c735480664..8baf11103d84 100644
--- a/PHILOSOPHY.md
+++ b/PHILOSOPHY.md
@@ -82,7 +82,7 @@ Models are designed as configurable toolboxes that are natural extensions of [Py
The following design principles are followed:
- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc...
-- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modelling files and shows that models do not really follow the single-file policy.
+- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
- Models all inherit from `ModelMixin` and `ConfigMixin`.
- Models can be optimized for performance when it doesnβt demand major code changes, keep backward compatibility, and give significant memory or compute gain.
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index ec901085d931..7b7779dcb035 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -72,6 +72,8 @@
title: Overview
- local: using-diffusers/sdxl
title: Stable Diffusion XL
+ - local: using-diffusers/sdxl_turbo
+ title: SDXL Turbo
- local: using-diffusers/kandinsky
title: Kandinsky
- local: using-diffusers/controlnet
@@ -94,6 +96,8 @@
title: Latent Consistency Model-LoRA
- local: using-diffusers/inference_with_lcm
title: Latent Consistency Model
+ - local: using-diffusers/svd
+ title: Stable Video Diffusion
title: Specific pipeline examples
- sections:
- local: training/overview
@@ -129,6 +133,8 @@
title: LoRA
- local: training/custom_diffusion
title: Custom Diffusion
+ - local: training/lcm_distill
+ title: Latent Consistency Distillation
- local: training/ddpo
title: Reinforcement learning training with DDPO
title: Methods
@@ -329,6 +335,8 @@
title: Stable Diffusion 2
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
title: Stable Diffusion XL
+ - local: api/pipelines/stable_diffusion/sdxl_turbo
+ title: SDXL Turbo
- local: api/pipelines/stable_diffusion/latent_upscale
title: Latent upscaler
- local: api/pipelines/stable_diffusion/upscale
diff --git a/docs/source/en/api/pipelines/kandinsky3.md b/docs/source/en/api/pipelines/kandinsky3.md
index cc4f87d47f58..ff57843ec052 100644
--- a/docs/source/en/api/pipelines/kandinsky3.md
+++ b/docs/source/en/api/pipelines/kandinsky3.md
@@ -9,7 +9,32 @@ specific language governing permissions and limitations under the License.
# Kandinsky 3
-TODO
+Kandinsky 3 is created by [Vladimir Arkhipkin](https://github.com/oriBetelgeuse),[Anastasia Maltseva](https://github.com/NastyaMittseva),[Igor Pavlov](https://github.com/boomb0om),[Andrei Filatov](https://github.com/anvilarth),[Arseniy Shakhmatov](https://github.com/cene555),[Andrey Kuznetsov](https://github.com/kuznetsoffandrey),[Denis Dimitrov](https://github.com/denndimitrov), [Zein Shaheen](https://github.com/zeinsh)
+
+The description from it's Github page:
+
+*Kandinsky 3.0 is an open-source text-to-image diffusion model built upon the Kandinsky2-x model family. In comparison to its predecessors, enhancements have been made to the text understanding and visual quality of the model, achieved by increasing the size of the text encoder and Diffusion U-Net models, respectively.*
+
+Its architecture includes 3 main components:
+1. [FLAN-UL2](https://huggingface.co/google/flan-ul2), which is an encoder decoder model based on the T5 architecture.
+2. New U-Net architecture featuring BigGAN-deep blocks doubles depth while maintaining the same number of parameters.
+3. Sber-MoVQGAN is a decoder proven to have superior results in image restoration.
+
+
+
+The original codebase can be found at [ai-forever/Kandinsky-3](https://github.com/ai-forever/Kandinsky-3).
+
+
+
+Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.
+
+
+
+
+
+Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
## Kandinsky3Pipeline
diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md
index ef864c6f8a2e..7dab22469dc2 100644
--- a/docs/source/en/api/pipelines/overview.md
+++ b/docs/source/en/api/pipelines/overview.md
@@ -51,6 +51,7 @@ The table below lists all the pipelines currently available in π€ Diffusers an
| [InstructPix2Pix](pix2pix) | image editing |
| [Kandinsky 2.1](kandinsky) | text2image, image2image, inpainting, interpolation |
| [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting |
+| [Kandinsky 3](kandinsky3) | text2image, image2image |
| [Latent Consistency Models](latent_consistency_models) | text2image |
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
| [LDM3D](stable_diffusion/ldm3d_diffusion) | text2image, text-to-3D, text-to-pano, upscaling |
diff --git a/docs/source/en/api/pipelines/stable_diffusion/sdxl_turbo.md b/docs/source/en/api/pipelines/stable_diffusion/sdxl_turbo.md
new file mode 100644
index 000000000000..cb5b84ee06df
--- /dev/null
+++ b/docs/source/en/api/pipelines/stable_diffusion/sdxl_turbo.md
@@ -0,0 +1,53 @@
+
+
+# SDXL Turbo
+
+Stable Diffusion XL (SDXL) Turbo was proposed in [Adversarial Diffusion Distillation](https://stability.ai/research/adversarial-diffusion-distillation) by Axel Sauer, Dominik Lorenz, Andreas Blattmann, and Robin Rombach.
+
+The abstract from the paper is:
+
+*We introduce Adversarial Diffusion Distillation (ADD), a novel training approach that efficiently samples large-scale foundational image diffusion models in just 1β4 steps while maintaining high image quality. We use score distillation to leverage large-scale off-the-shelf image diffusion models as a teacher signal in combination with an adversarial loss to ensure high image fidelity even in the low-step regime of one or two sampling steps. Our analyses show that our model clearly outperforms existing few-step methods (GANs,Latent Consistency Models) in a single step and reaches the performance of state-of-the-art diffusion models (SDXL) in only four steps. ADD is the first method to unlock single-step, real-time image synthesis with foundation models.*
+
+## Tips
+
+- SDXL Turbo uses the exact same architecture as [SDXL](./stable_diffusion_xl).
+- SDXL Turbo should disable guidance scale by setting `guidance_scale=0.0`
+- SDXL Turbo should use `timestep_spacing='trailing'` for the scheduler and use between 1 and 4 steps.
+- SDXL Turbo has been trained to generate images of size 512x512.
+- SDXL Turbo is open-access, but not open-source meaning that one might have to buy a model license in order to use it for commercial applications. Make sure to read the [official model card](https://huggingface.co/stabilityai/sdxl-turbo) to learn more.
+
+
+
+To learn how to use SDXL Turbo for various tasks, how to optimize performance, and other usage examples, take a look at the [Stable Diffusion XL](../../../using-diffusers/sdxl_turbo) guide.
+
+Check out the [Stability AI](https://huggingface.co/stabilityai) Hub organization for the official base and refiner model checkpoints!
+
+
+
+## StableDiffusionXLPipeline
+
+[[autodoc]] StableDiffusionXLPipeline
+ - all
+ - __call__
+
+## StableDiffusionXLImg2ImgPipeline
+
+[[autodoc]] StableDiffusionXLImg2ImgPipeline
+ - all
+ - __call__
+
+## StableDiffusionXLInpaintPipeline
+
+[[autodoc]] StableDiffusionXLInpaintPipeline
+ - all
+ - __call__
diff --git a/docs/source/en/api/pipelines/text_to_video_zero.md b/docs/source/en/api/pipelines/text_to_video_zero.md
index 626e75f94936..f58a151c2a51 100644
--- a/docs/source/en/api/pipelines/text_to_video_zero.md
+++ b/docs/source/en/api/pipelines/text_to_video_zero.md
@@ -92,6 +92,19 @@ imageio.mimsave("video.mp4", result, fps=4)
```
+- #### SDXL Support
+In order to use the SDXL model when generating a video from prompt, use the `TextToVideoZeroSDXLPipeline` pipeline:
+
+```python
+import torch
+from diffusers import TextToVideoZeroSDXLPipeline
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
+ model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+```
+
### Text-To-Video with Pose Control
To generate a video from prompt with additional pose control
@@ -141,7 +154,33 @@ To generate a video from prompt with additional pose control
result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```
-
+- #### SDXL Support
+
+ Since our attention processor also works with SDXL, it can be utilized to generate a video from prompt using ControlNet models powered by SDXL:
+ ```python
+ import torch
+ from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
+ from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor
+
+ controlnet_model_id = 'thibaud/controlnet-openpose-sdxl-1.0'
+ model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
+
+ controlnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch.float16)
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ model_id, controlnet=controlnet, torch_dtype=torch.float16
+ ).to('cuda')
+
+ # Set the attention processor
+ pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
+ pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
+
+ # fix latents for all frames
+ latents = torch.randn((1, 4, 128, 128), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)
+
+ prompt = "Darth Vader dancing in a desert"
+ result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
+ imageio.mimsave("video.mp4", result, fps=4)
+ ```
### Text-To-Video with Edge Control
@@ -253,5 +292,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- all
- __call__
+## TextToVideoZeroSDXLPipeline
+[[autodoc]] TextToVideoZeroSDXLPipeline
+ - all
+ - __call__
+
## TextToVideoPipelineOutput
[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput
diff --git a/docs/source/en/training/lcm_distill.md b/docs/source/en/training/lcm_distill.md
new file mode 100644
index 000000000000..e83400651584
--- /dev/null
+++ b/docs/source/en/training/lcm_distill.md
@@ -0,0 +1,255 @@
+
+
+# Latent Consistency Distillation
+
+[Latent Consistency Models (LCMs)](https://hf.co/papers/2310.04378) are able to generate high-quality images in just a few steps, representing a big leap forward because many pipelines require at least 25+ steps. LCMs are produced by applying the latent consistency distillation method to any Stable Diffusion model. This method works by applying *one-stage guided distillation* to the latent space, and incorporating a *skipping-step* method to consistently skip timesteps to accelerate the distillation process (refer to section 4.1, 4.2, and 4.3 of the paper for more details).
+
+If you're training on a GPU with limited vRAM, try enabling `gradient_checkpointing`, `gradient_accumulation_steps`, and `mixed_precision` to reduce memory-usage and speedup training. You can reduce your memory-usage even more by enabling memory-efficient attention with [xFormers](../optimization/xformers) and [bitsandbytes'](https://github.com/TimDettmers/bitsandbytes) 8-bit optimizer.
+
+This guide will explore the [train_lcm_distill_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
+
+```bash
+cd examples/consistency_distillation
+pip install -r requirements.txt
+```
+
+
+
+π€ Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the π€ Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an π€ Accelerate environment (try enabling `torch.compile` to significantly speedup training):
+
+```bash
+accelerate config
+```
+
+To setup a default π€ Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```bash
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+## Script parameters
+
+
+
+The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) and let us know if you have any questions or concerns.
+
+
+
+The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L419) function. This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
+
+For example, to speedup training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:
+
+```bash
+accelerate launch train_lcm_distill_sd_wds.py \
+ --mixed_precision="fp16"
+```
+
+Most of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so you'll focus on the parameters that are relevant to latent consistency distillation in this guide.
+
+- `--pretrained_teacher_model`: the path to a pretrained latent diffusion model to use as the teacher model
+- `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter allows you to specify an alternative VAE (like this [VAE]((https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)) by madebyollin which works in fp16)
+- `--w_min` and `--w_max`: the minimum and maximum guidance scale values for guidance scale sampling
+- `--num_ddim_timesteps`: the number of timesteps for DDIM sampling
+- `--loss_type`: the type of loss (L2 or Huber) to calculate for latent consistency distillation; Huber loss is generally preferred because it's more robust to outliers
+- `--huber_c`: the Huber loss parameter
+
+## Training script
+
+The training script starts by creating a dataset class - [`Text2ImageDataset`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L141) - for preprocessing the images and creating a training dataset.
+
+```py
+def transform(example):
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ return example
+```
+
+For improved performance on reading and writing large datasets stored in the cloud, this script uses the [WebDataset](https://github.com/webdataset/webdataset) format to create a preprocessing pipeline to apply transforms and create a dataset and dataloader for training. Images are processed and fed to the training loop without having to download the full dataset first.
+
+```py
+processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue),
+ wds.map(filter_keys({"image", "text"})),
+ wds.map(transform),
+ wds.to_tuple("image", "text"),
+]
+```
+
+In the [`main()`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L768) function, all the necessary components like the noise scheduler, tokenizers, text encoders, and VAE are loaded. The teacher UNet is also loaded here and then you can create a student UNet from the teacher UNet. The student UNet is updated by the optimizer during training.
+
+```py
+teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+)
+
+unet = UNet2DConditionModel(**teacher_unet.config)
+unet.load_state_dict(teacher_unet.state_dict(), strict=False)
+unet.train()
+```
+
+Now you can create the [optimizer](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L979) to update the UNet parameters:
+
+```py
+optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+Create the [dataset](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L994):
+
+```py
+dataset = Text2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+)
+train_dataloader = dataset.train_dataloader
+```
+
+Next, you're ready to setup the [training loop](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1049) and implement the latent consistency distillation method (see Algorithm 1 in the paper for more details). This section of the script takes care of adding noise to the latents, sampling and creating a guidance scale embedding, and predicting the original image from the noise.
+
+```py
+pred_x_0 = predicted_origin(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+)
+
+model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
+```
+
+It gets the [teacher model predictions](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1172) and the [LCM predictions](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1209) next, calculates the loss, and then backpropagates it to the LCM.
+
+```py
+if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+elif args.loss_type == "huber":
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+```
+
+If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers tutorial](../using-diffusers/write_own_pipeline) which breaks down the basic pattern of the denoising process.
+
+## Launch the script
+
+Now you're ready to launch the training script and start distilling!
+
+For this guide, you'll use the `--train_shards_path_or_url` to specify the path to the [Conceptual Captions 12M](https://github.com/google-research-datasets/conceptual-12m) dataset stored on the Hub [here](https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset). Set the `MODEL_DIR` environment variable to the name of the teacher model and `OUTPUT_DIR` to where you want to save the model.
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path/to/saved/model"
+
+accelerate launch train_lcm_distill_sd_wds.py \
+ --pretrained_teacher_model=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=512 \
+ --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+ --push_to_hub
+```
+
+Once training is complete, you can use your new LCM for inference.
+
+```py
+from diffusers import UNet2DConditionModel, DiffusionPipeline, LCMScheduler
+import torch
+
+unet = UNet2DConditionModel.from_pretrained("your-username/your-model", torch_dtype=torch.float16, variant="fp16")
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16, variant="fp16")
+
+pipeline.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+pipeline.to("cuda")
+
+prompt = "sushi rolls in the form of panda heads, sushi platter"
+
+image = pipeline(prompt, num_inference_steps=4, guidance_scale=1.0).images[0]
+```
+
+## LoRA
+
+LoRA is a training technique for significantly reducing the number of trainable parameters. As a result, training is faster and it is easier to store the resulting weights because they are a lot smaller (~100MBs). Use the [train_lcm_distill_lora_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py) or [train_lcm_distill_lora_sdxl.wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py) script to train with LoRA.
+
+The LoRA training script is discussed in more detail in the [LoRA training](lora) guide.
+
+## Stable Diffusion XL
+
+Stable Diffusion XL (SDXL) is a powerful text-to-image model that generates high-resolution images, and it adds a second text-encoder to its architecture. Use the [train_lcm_distill_sdxl_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py) script to train a SDXL model with LoRA.
+
+The SDXL training script is discussed in more detail in the [SDXL training](sdxl) guide.
+
+## Next steps
+
+Congratulations on distilling a LCM model! To learn more about LCM, the following may be helpful:
+
+- Learn how to use [LCMs for inference](../using-diffusers/lcm) for text-to-image, image-to-image, and with LoRA checkpoints.
+- Read the [SDXL in 4 steps with Latent Consistency LoRAs](https://huggingface.co/blog/lcm_lora) blog post to learn more about SDXL LCM-LoRA's for super fast inference, quality comparisons, benchmarks, and more.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/sdxl_turbo.md b/docs/source/en/using-diffusers/sdxl_turbo.md
new file mode 100644
index 000000000000..99e1c7000e3f
--- /dev/null
+++ b/docs/source/en/using-diffusers/sdxl_turbo.md
@@ -0,0 +1,116 @@
+
+
+# Stable Diffusion XL Turbo
+
+[[open-in-colab]]
+
+SDXL Turbo is an adversarial time-distilled [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) (SDXL) model capable
+of running inference in as little as 1 step.
+
+This guide will show you how to use SDXL-Turbo for text-to-image and image-to-image.
+
+Before you begin, make sure you have the following libraries installed:
+
+```py
+# uncomment to install the necessary libraries in Colab
+#!pip install -q diffusers transformers accelerate omegaconf
+```
+
+## Load model checkpoints
+
+Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~StableDiffusionXLPipeline.from_pretrained`] method:
+
+```py
+from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
+pipeline = pipeline.to("cuda")
+```
+
+You can also use the [`~StableDiffusionXLPipeline.from_single_file`] method to load a model checkpoint stored in a single file format (`.ckpt` or `.safetensors`) from the Hub or locally:
+
+```py
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+pipeline = StableDiffusionXLPipeline.from_single_file(
+ "https://huggingface.co/stabilityai/sdxl-turbo/blob/main/sd_xl_turbo_1.0_fp16.safetensors", torch_dtype=torch.float16)
+pipeline = pipeline.to("cuda")
+```
+
+## Text-to-image
+
+For text-to-image, pass a text prompt. By default, SDXL Turbo generates a 512x512 image, and that resolution gives the best results. You can try setting the `height` and `width` parameters to 768x768 or 1024x1024, but you should expect quality degradations when doing so.
+
+Make sure to set `guidance_scale` to 0.0 to disable, as the model was trained without it. A single inference step is enough to generate high quality images.
+Increasing the number of steps to 2, 3 or 4 should improve image quality.
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline_text2image = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
+pipeline_text2image = pipeline_text2image.to("cuda")
+
+prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
+
+image = pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=1).images[0]
+image
+```
+
+
+
+
+
+## Image-to-image
+
+For image-to-image generation, make sure that `num_inference_steps * strength` is larger or equal to 1.
+The image-to-image pipeline will run for `int(num_inference_steps * strength)` steps, e.g. `0.5 * 2.0 = 1` step in
+our example below.
+
+```py
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import load_image, make_image_grid
+
+# use from_pipe to avoid consuming additional memory when loading a checkpoint
+pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to("cuda")
+
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
+init_image = init_image.resize((512, 512))
+
+prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
+
+image = pipeline(prompt, image=init_image, strength=0.5, guidance_scale=0.0, num_inference_steps=2).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+
+## Speed-up SDXL Turbo even more
+
+- Compile the UNet if you are using PyTorch version 2 or better. The first inference run will be very slow, but subsequent ones will be much faster.
+
+```py
+pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+- When using the default VAE, keep it in `float32` to avoid costly `dtype` conversions before and after each generation. You only need to do this one before your first generation:
+
+```py
+pipe.upcast_vae()
+```
+
+As an alternative, you can also use a [16-bit VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) created by community member [`@madebyollin`](https://huggingface.co/madebyollin) that does not need to be upcasted to `float32`.
diff --git a/docs/source/en/using-diffusers/svd.md b/docs/source/en/using-diffusers/svd.md
new file mode 100644
index 000000000000..4fdb2608aa76
--- /dev/null
+++ b/docs/source/en/using-diffusers/svd.md
@@ -0,0 +1,133 @@
+
+
+# Stable Video Diffusion
+
+[[open-in-colab]]
+
+[Stable Video Diffusion](https://static1.squarespace.com/static/6213c340453c3f502425776e/t/655ce779b9d47d342a93c890/1700587395994/stable_video_diffusion.pdf) is a powerful image-to-video generation model that can generate high resolution (576x1024) 2-4 second videos conditioned on the input image.
+
+This guide will show you how to use SVD to short generate videos from images.
+
+Before you begin, make sure you have the following libraries installed:
+
+```py
+!pip install -q -U diffusers transformers accelerate
+```
+
+## Image to Video Generation
+
+The are two variants of SVD. [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid)
+and [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt). The svd checkpoint is trained to generate 14 frames and the svd-xt checkpoint is further
+finetuned to generate 25 frames.
+
+We will use the `svd-xt` checkpoint for this guide.
+
+```python
+import torch
+
+from diffusers import StableVideoDiffusionPipeline
+from diffusers.utils import load_image, export_to_video
+
+pipe = StableVideoDiffusionPipeline.from_pretrained(
+ "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
+)
+pipe.enable_model_cpu_offload()
+
+# Load the conditioning image
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
+image = image.resize((1024, 576))
+
+generator = torch.manual_seed(42)
+frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
+
+export_to_video(frames, "generated.mp4", fps=7)
+```
+
+
+
+
+Since generating videos is more memory intensive we can use the `decode_chunk_size` argument to control how many frames are decoded at once. This will reduce the memory usage. It's recommended to tweak this value based on your GPU memory.
+Setting `decode_chunk_size=1` will decode one frame at a time and will use the least amount of memory but the video might have some flickering.
+
+Additionally, we also use [model cpu offloading](../../optimization/memory#model-offloading) to reduce the memory usage.
+
+
+
+### Torch.compile
+
+You can achieve a 20-25% speed-up at the expense of slightly increased memory by compiling the UNet as follows:
+
+```diff
+- pipe.enable_model_cpu_offload()
++ pipe.to("cuda")
++ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+### Low-memory
+
+Video generation is very memory intensive as we have to essentially generate `num_frames` all at once. The mechanism is very comparable to text-to-image generation with a high batch size. To reduce the memory requirement you have multiple options. The following options trade inference speed against lower memory requirement:
+- enable model offloading: Each component of the pipeline is offloaded to CPU once it's not needed anymore.
+- enable feed-forward chunking: The feed-forward layer runs in a loop instead of running with a single huge feed-forward batch size
+- reduce `decode_chunk_size`: This means that the VAE decodes frames in chunks instead of decoding them all together. **Note**: In addition to leading to a small slowdown, this method also slightly leads to video quality deterioration
+
+You can enable them as follows:
+
+```diff
+-pipe.enable_model_cpu_offload()
+-frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
++pipe.enable_model_cpu_offload()
++pipe.unet.enable_forward_chunking()
++frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]
+```
+
+
+Including all these tricks should lower the memory requirement to less than 8GB VRAM.
+
+### Micro-conditioning
+
+Along with conditioning image Stable Diffusion Video also allows providing micro-conditioning that allows more control over the generated video.
+It accepts the following arguments:
+
+- `fps`: The frames per second of the generated video.
+- `motion_bucket_id`: The motion bucket id to use for the generated video. This can be used to control the motion of the generated video. Increasing the motion bucket id will increase the motion of the generated video.
+- `noise_aug_strength`: The amount of noise added to the conditioning image. The higher the values the less the video will resemble the conditioning image. Increasing this value will also increase the motion of the generated video.
+
+Here is an example of using micro-conditioning to generate a video with more motion.
+
+
+```python
+import torch
+
+from diffusers import StableVideoDiffusionPipeline
+from diffusers.utils import load_image, export_to_video
+
+pipe = StableVideoDiffusionPipeline.from_pretrained(
+ "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
+)
+pipe.enable_model_cpu_offload()
+
+# Load the conditioning image
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
+image = image.resize((1024, 576))
+
+generator = torch.manual_seed(42)
+frames = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=180, noise_aug_strength=0.1).frames[0]
+export_to_video(frames, "generated.mp4", fps=7)
+```
+
+
+
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
index bc612edbc20e..ca2a1521d3d5 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
@@ -1470,7 +1470,15 @@ def __call__(
height, width = self._default_height_width(height, width, adapter_image)
device = self._execution_device
- adapter_input = _preprocess_adapter_image(adapter_image, height, width).to(device)
+ if isinstance(adapter, MultiAdapter):
+ adapter_input = []
+ for one_image in adapter_image:
+ one_image = _preprocess_adapter_image(one_image, height, width)
+ one_image = one_image.to(device=device, dtype=adapter.dtype)
+ adapter_input.append(one_image)
+ else:
+ adapter_input = _preprocess_adapter_image(adapter_image, height, width)
+ adapter_input = adapter_input.to(device=device, dtype=adapter.dtype)
original_size = original_size or (height, width)
target_size = target_size or (height, width)
@@ -1643,10 +1651,14 @@ def denoising_value_valid(dnv):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 10. Prepare added time ids & embeddings & adapter features
- adapter_input = adapter_input.type(latents.dtype)
- adapter_state = adapter(adapter_input)
- for k, v in enumerate(adapter_state):
- adapter_state[k] = v * adapter_conditioning_scale
+ if isinstance(adapter, MultiAdapter):
+ adapter_state = adapter(adapter_input, adapter_conditioning_scale)
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v
+ else:
+ adapter_state = adapter(adapter_input)
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v * adapter_conditioning_scale
if num_images_per_prompt > 1:
for k, v in enumerate(adapter_state):
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
diff --git a/examples/controlnet/test_controlnet.py b/examples/controlnet/test_controlnet.py
new file mode 100644
index 000000000000..e62d095adaa2
--- /dev/null
+++ b/examples/controlnet/test_controlnet.py
@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class ControlNet(ExamplesTestsAccelerate):
+ def test_controlnet_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-8", "checkpoint-10", "checkpoint-12"},
+ )
+
+
+class ControlNetSDXL(ExamplesTestsAccelerate):
+ def test_controlnet_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet_sdxl.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
diff --git a/examples/custom_diffusion/test_custom_diffusion.py b/examples/custom_diffusion/test_custom_diffusion.py
new file mode 100644
index 000000000000..78f24c5172d6
--- /dev/null
+++ b/examples/custom_diffusion/test_custom_diffusion.py
@@ -0,0 +1,130 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class CustomDiffusion(ExamplesTestsAccelerate):
+ def test_custom_diffusion(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 1.0e-05
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --modifier_token
+ --no_safe_serialization
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, ".bin")))
+
+ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=
+ --resolution=64
+ --train_batch_size=1
+ --modifier_token=
+ --dataloader_num_workers=0
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=
+ --resolution=64
+ --train_batch_size=1
+ --modifier_token=
+ --dataloader_num_workers=0
+ --max_train_steps=9
+ --checkpointing_steps=2
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=
+ --resolution=64
+ --train_batch_size=1
+ --modifier_token=
+ --dataloader_num_workers=0
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
diff --git a/examples/dreambooth/test_dreambooth.py b/examples/dreambooth/test_dreambooth.py
new file mode 100644
index 000000000000..0c6c2a062325
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth.py
@@ -0,0 +1,230 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBooth(ExamplesTestsAccelerate):
+ def test_dreambooth(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_if(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --pre_compute_text_embeddings
+ --tokenizer_max_length=77
+ --text_encoder_use_attention_mask
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_checkpointing(self):
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 5, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt {instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 5
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check can run the original fully trained output pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=2)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 7 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt {instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=2)
+
+ # check old checkpoints do not exist
+ self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+
+ # check new checkpoints exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
diff --git a/examples/dreambooth/test_dreambooth_lora.py b/examples/dreambooth/test_dreambooth_lora.py
new file mode 100644
index 000000000000..fc43269f732e
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora.py
@@ -0,0 +1,388 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+from diffusers import DiffusionPipeline # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRA(ExamplesTestsAccelerate):
+ def test_dreambooth_lora(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --train_text_encoder
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # check `text_encoder` is present at all.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ keys = lora_state_dict.keys()
+ is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
+ self.assertTrue(is_text_encoder_present)
+
+ # the names of the keys of the state dict should either start with `unet`
+ # or `text_encoder`.
+ is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
+ self.assertTrue(is_correct_naming)
+
+ def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_dreambooth_lora_if_model(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --pre_compute_text_embeddings
+ --tokenizer_max_length=77
+ --text_encoder_use_attention_mask
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+
+class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
+ def test_dreambooth_lora_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
+ keys = lora_state_dict.keys()
+ starts_with_unet = all(
+ k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+ )
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl_custom_captions(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --caption_column text
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --caption_column text
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe("a prompt", num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --train_text_encoder
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe("a prompt", num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 41854501144b..c3f19efbdc38 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -300,7 +300,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--output_dir",
type=str,
- default="text-inversion-model",
+ default="dreambooth-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
diff --git a/examples/instruct_pix2pix/test_instruct_pix2pix.py b/examples/instruct_pix2pix/test_instruct_pix2pix.py
new file mode 100644
index 000000000000..c4d7500723fa
--- /dev/null
+++ b/examples/instruct_pix2pix/test_instruct_pix2pix.py
@@ -0,0 +1,101 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class InstructPix2Pix(ExamplesTestsAccelerate):
+ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=7
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --output_dir {tmpdir}
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ --output_dir {tmpdir}
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --output_dir {tmpdir}
+ --seed=0
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
diff --git a/examples/t2i_adapter/test_t2i_adapter.py b/examples/t2i_adapter/test_t2i_adapter.py
new file mode 100644
index 000000000000..fe8fd9d8c2d2
--- /dev/null
+++ b/examples/t2i_adapter/test_t2i_adapter.py
@@ -0,0 +1,51 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class T2IAdapter(ExamplesTestsAccelerate):
+ def test_t2i_adapter_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/t2i_adapter/train_t2i_adapter_sdxl.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --adapter_model_name_or_path=hf-internal-testing/tiny-adapter
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
diff --git a/examples/test_examples.py b/examples/test_examples.py
deleted file mode 100644
index 292c433a3395..000000000000
--- a/examples/test_examples.py
+++ /dev/null
@@ -1,1725 +0,0 @@
-# coding=utf-8
-# Copyright 2023 HuggingFace Inc..
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-import os
-import shutil
-import subprocess
-import sys
-import tempfile
-import unittest
-from typing import List
-
-import safetensors
-from accelerate.utils import write_basic_config
-
-from diffusers import DiffusionPipeline, UNet2DConditionModel
-
-
-logging.basicConfig(level=logging.DEBUG)
-
-logger = logging.getLogger()
-
-
-# These utils relate to ensuring the right error message is received when running scripts
-class SubprocessCallException(Exception):
- pass
-
-
-def run_command(command: List[str], return_stdout=False):
- """
- Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
- if an error occurred while running `command`
- """
- try:
- output = subprocess.check_output(command, stderr=subprocess.STDOUT)
- if return_stdout:
- if hasattr(output, "decode"):
- output = output.decode("utf-8")
- return output
- except subprocess.CalledProcessError as e:
- raise SubprocessCallException(
- f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
- ) from e
-
-
-stream_handler = logging.StreamHandler(sys.stdout)
-logger.addHandler(stream_handler)
-
-
-class ExamplesTestsAccelerate(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- cls._tmpdir = tempfile.mkdtemp()
- cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
-
- write_basic_config(save_location=cls.configPath)
- cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
-
- @classmethod
- def tearDownClass(cls):
- super().tearDownClass()
- shutil.rmtree(cls._tmpdir)
-
- def test_train_unconditional(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/unconditional_image_generation/train_unconditional.py
- --dataset_name hf-internal-testing/dummy_image_class_data
- --model_config_name_or_path diffusers/ddpm_dummy
- --resolution 64
- --output_dir {tmpdir}
- --train_batch_size 2
- --num_epochs 1
- --gradient_accumulation_steps 1
- --ddpm_num_inference_steps 2
- --learning_rate 1e-3
- --lr_warmup_steps 5
- """.split()
-
- run_command(self._launch_args + test_args, return_stdout=True)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
-
- def test_textual_inversion(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/textual_inversion/textual_inversion.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
- --train_data_dir docs/source/en/imgs
- --learnable_property object
- --placeholder_token
- --initializer_token a
- --validation_prompt
- --validation_steps 1
- --save_steps 1
- --num_vectors 2
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
-
- def test_dreambooth(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir docs/source/en/imgs
- --instance_prompt photo
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
-
- def test_dreambooth_if(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
- --instance_data_dir docs/source/en/imgs
- --instance_prompt photo
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --pre_compute_text_embeddings
- --tokenizer_max_length=77
- --text_encoder_use_attention_mask
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
-
- def test_dreambooth_checkpointing(self):
- instance_prompt = "photo"
- pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
-
- with tempfile.TemporaryDirectory() as tmpdir:
- # Run training script with checkpointing
- # max_train_steps == 5, checkpointing_steps == 2
- # Should create checkpoints at steps 2, 4
-
- initial_run_args = f"""
- examples/dreambooth/train_dreambooth.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --instance_data_dir docs/source/en/imgs
- --instance_prompt {instance_prompt}
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 5
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --seed=0
- """.split()
-
- run_command(self._launch_args + initial_run_args)
-
- # check can run the original fully trained output pipeline
- pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
- pipe(instance_prompt, num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
- self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
-
- # check can run an intermediate checkpoint
- unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
- pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
- pipe(instance_prompt, num_inference_steps=2)
-
- # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
- shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
-
- # Run training script for 7 total steps resuming from checkpoint 4
-
- resume_run_args = f"""
- examples/dreambooth/train_dreambooth.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --instance_data_dir docs/source/en/imgs
- --instance_prompt {instance_prompt}
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 7
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --resume_from_checkpoint=checkpoint-4
- --seed=0
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- # check can run new fully trained pipeline
- pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
- pipe(instance_prompt, num_inference_steps=2)
-
- # check old checkpoints do not exist
- self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
-
- # check new checkpoints exist
- self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
- self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
-
- def test_dreambooth_lora(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth_lora.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir docs/source/en/imgs
- --instance_prompt photo
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
- # make sure the state_dict has the correct naming in the parameters.
- lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
- is_lora = all("lora" in k for k in lora_state_dict.keys())
- self.assertTrue(is_lora)
-
- # when not training the text encoder, all the parameters in the state dict should start
- # with `"unet"` in their names.
- starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
- self.assertTrue(starts_with_unet)
-
- def test_dreambooth_lora_with_text_encoder(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth_lora.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir docs/source/en/imgs
- --instance_prompt photo
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --train_text_encoder
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
- # check `text_encoder` is present at all.
- lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
- keys = lora_state_dict.keys()
- is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
- self.assertTrue(is_text_encoder_present)
-
- # the names of the keys of the state dict should either start with `unet`
- # or `text_encoder`.
- is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
- self.assertTrue(is_correct_naming)
-
- def test_dreambooth_lora_if_model(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth_lora.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
- --instance_data_dir docs/source/en/imgs
- --instance_prompt photo
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --pre_compute_text_embeddings
- --tokenizer_max_length=77
- --text_encoder_use_attention_mask
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
- # make sure the state_dict has the correct naming in the parameters.
- lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
- is_lora = all("lora" in k for k in lora_state_dict.keys())
- self.assertTrue(is_lora)
-
- # when not training the text encoder, all the parameters in the state dict should start
- # with `"unet"` in their names.
- starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
- self.assertTrue(starts_with_unet)
-
- def test_dreambooth_lora_sdxl(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth_lora_sdxl.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
- --instance_data_dir docs/source/en/imgs
- --instance_prompt photo
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
- # make sure the state_dict has the correct naming in the parameters.
- lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
- is_lora = all("lora" in k for k in lora_state_dict.keys())
- self.assertTrue(is_lora)
-
- # when not training the text encoder, all the parameters in the state dict should start
- # with `"unet"` in their names.
- starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
- self.assertTrue(starts_with_unet)
-
- def test_dreambooth_lora_sdxl_with_text_encoder(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth_lora_sdxl.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
- --instance_data_dir docs/source/en/imgs
- --instance_prompt photo
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --train_text_encoder
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
- # make sure the state_dict has the correct naming in the parameters.
- lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
- is_lora = all("lora" in k for k in lora_state_dict.keys())
- self.assertTrue(is_lora)
-
- # when not training the text encoder, all the parameters in the state dict should start
- # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
- keys = lora_state_dict.keys()
- starts_with_unet = all(
- k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
- )
- self.assertTrue(starts_with_unet)
-
- def test_dreambooth_lora_sdxl_custom_captions(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth_lora_sdxl.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
- --dataset_name hf-internal-testing/dummy_image_text_data
- --caption_column text
- --instance_prompt photo
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
-
- def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth_lora_sdxl.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
- --dataset_name hf-internal-testing/dummy_image_text_data
- --caption_column text
- --instance_prompt photo
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --train_text_encoder
- """.split()
-
- run_command(self._launch_args + test_args)
-
- def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
- pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
-
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth_lora_sdxl.py
- --pretrained_model_name_or_path {pipeline_path}
- --instance_data_dir docs/source/en/imgs
- --instance_prompt photo
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 7
- --checkpointing_steps=2
- --checkpoints_total_limit=2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
-
- pipe = DiffusionPipeline.from_pretrained(pipeline_path)
- pipe.load_lora_weights(tmpdir)
- pipe("a prompt", num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- # checkpoint-2 should have been deleted
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
- pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
-
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth_lora_sdxl.py
- --pretrained_model_name_or_path {pipeline_path}
- --instance_data_dir docs/source/en/imgs
- --instance_prompt photo
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 7
- --checkpointing_steps=2
- --checkpoints_total_limit=2
- --train_text_encoder
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
-
- pipe = DiffusionPipeline.from_pretrained(pipeline_path)
- pipe.load_lora_weights(tmpdir)
- pipe("a prompt", num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- # checkpoint-2 should have been deleted
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_custom_diffusion(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/custom_diffusion/train_custom_diffusion.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir docs/source/en/imgs
- --instance_prompt
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 1.0e-05
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --modifier_token
- --no_safe_serialization
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, ".bin")))
-
- def test_text_to_image(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/text_to_image/train_text_to_image.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
-
- def test_text_to_image_checkpointing(self):
- pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
- prompt = "a prompt"
-
- with tempfile.TemporaryDirectory() as tmpdir:
- # Run training script with checkpointing
- # max_train_steps == 5, checkpointing_steps == 2
- # Should create checkpoints at steps 2, 4
-
- initial_run_args = f"""
- examples/text_to_image/train_text_to_image.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 5
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --seed=0
- """.split()
-
- run_command(self._launch_args + initial_run_args)
-
- pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
- pipe(prompt, num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-2", "checkpoint-4"},
- )
-
- # check can run an intermediate checkpoint
- unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
- pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
- pipe(prompt, num_inference_steps=2)
-
- # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
- shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
-
- # Run training script for 7 total steps resuming from checkpoint 4
-
- resume_run_args = f"""
- examples/text_to_image/train_text_to_image.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 7
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --resume_from_checkpoint=checkpoint-4
- --seed=0
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- # check can run new fully trained pipeline
- pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
- pipe(prompt, num_inference_steps=2)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {
- # no checkpoint-2 -> check old checkpoints do not exist
- # check new checkpoints exist
- "checkpoint-4",
- "checkpoint-6",
- },
- )
-
- def test_text_to_image_checkpointing_use_ema(self):
- pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
- prompt = "a prompt"
-
- with tempfile.TemporaryDirectory() as tmpdir:
- # Run training script with checkpointing
- # max_train_steps == 5, checkpointing_steps == 2
- # Should create checkpoints at steps 2, 4
-
- initial_run_args = f"""
- examples/text_to_image/train_text_to_image.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 5
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --use_ema
- --seed=0
- """.split()
-
- run_command(self._launch_args + initial_run_args)
-
- pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
- pipe(prompt, num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-2", "checkpoint-4"},
- )
-
- # check can run an intermediate checkpoint
- unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
- pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
- pipe(prompt, num_inference_steps=2)
-
- # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
- shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
-
- # Run training script for 7 total steps resuming from checkpoint 4
-
- resume_run_args = f"""
- examples/text_to_image/train_text_to_image.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 7
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --resume_from_checkpoint=checkpoint-4
- --use_ema
- --seed=0
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- # check can run new fully trained pipeline
- pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
- pipe(prompt, num_inference_steps=2)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {
- # no checkpoint-2 -> check old checkpoints do not exist
- # check new checkpoints exist
- "checkpoint-4",
- "checkpoint-6",
- },
- )
-
- def test_text_to_image_checkpointing_checkpoints_total_limit(self):
- pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
- prompt = "a prompt"
-
- with tempfile.TemporaryDirectory() as tmpdir:
- # Run training script with checkpointing
- # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
- # Should create checkpoints at steps 2, 4, 6
- # with checkpoint at step 2 deleted
-
- initial_run_args = f"""
- examples/text_to_image/train_text_to_image.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 7
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --checkpoints_total_limit=2
- --seed=0
- """.split()
-
- run_command(self._launch_args + initial_run_args)
-
- pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
- pipe(prompt, num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- # checkpoint-2 should have been deleted
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
- pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
- prompt = "a prompt"
-
- with tempfile.TemporaryDirectory() as tmpdir:
- # Run training script with checkpointing
- # max_train_steps == 9, checkpointing_steps == 2
- # Should create checkpoints at steps 2, 4, 6, 8
-
- initial_run_args = f"""
- examples/text_to_image/train_text_to_image.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 9
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --seed=0
- """.split()
-
- run_command(self._launch_args + initial_run_args)
-
- pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
- pipe(prompt, num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
- )
-
- # resume and we should try to checkpoint at 10, where we'll have to remove
- # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
-
- resume_run_args = f"""
- examples/text_to_image/train_text_to_image.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 11
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --resume_from_checkpoint=checkpoint-8
- --checkpoints_total_limit=3
- --seed=0
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
- pipe(prompt, num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
- )
-
- def test_text_to_image_sdxl(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/text_to_image/train_text_to_image_sdxl.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
-
- def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
- pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
- prompt = "a prompt"
-
- with tempfile.TemporaryDirectory() as tmpdir:
- # Run training script with checkpointing
- # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
- # Should create checkpoints at steps 2, 4, 6
- # with checkpoint at step 2 deleted
-
- initial_run_args = f"""
- examples/text_to_image/train_text_to_image_lora.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 7
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --checkpoints_total_limit=2
- --seed=0
- --num_validation_images=0
- """.split()
-
- run_command(self._launch_args + initial_run_args)
-
- pipe = DiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
- )
- pipe.load_lora_weights(tmpdir)
- pipe(prompt, num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- # checkpoint-2 should have been deleted
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
- prompt = "a prompt"
- pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
-
- with tempfile.TemporaryDirectory() as tmpdir:
- # Run training script with checkpointing
- # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
- # Should create checkpoints at steps 2, 4, 6
- # with checkpoint at step 2 deleted
-
- initial_run_args = f"""
- examples/text_to_image/train_text_to_image_lora_sdxl.py
- --pretrained_model_name_or_path {pipeline_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 7
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --checkpoints_total_limit=2
- """.split()
-
- run_command(self._launch_args + initial_run_args)
-
- pipe = DiffusionPipeline.from_pretrained(pipeline_path)
- pipe.load_lora_weights(tmpdir)
- pipe(prompt, num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- # checkpoint-2 should have been deleted
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
- prompt = "a prompt"
- pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
-
- with tempfile.TemporaryDirectory() as tmpdir:
- # Run training script with checkpointing
- # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
- # Should create checkpoints at steps 2, 4, 6
- # with checkpoint at step 2 deleted
-
- initial_run_args = f"""
- examples/text_to_image/train_text_to_image_lora_sdxl.py
- --pretrained_model_name_or_path {pipeline_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 7
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --train_text_encoder
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --checkpoints_total_limit=2
- """.split()
-
- run_command(self._launch_args + initial_run_args)
-
- pipe = DiffusionPipeline.from_pretrained(pipeline_path)
- pipe.load_lora_weights(tmpdir)
- pipe(prompt, num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- # checkpoint-2 should have been deleted
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
- pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
- prompt = "a prompt"
-
- with tempfile.TemporaryDirectory() as tmpdir:
- # Run training script with checkpointing
- # max_train_steps == 9, checkpointing_steps == 2
- # Should create checkpoints at steps 2, 4, 6, 8
-
- initial_run_args = f"""
- examples/text_to_image/train_text_to_image_lora.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 9
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --seed=0
- --num_validation_images=0
- """.split()
-
- run_command(self._launch_args + initial_run_args)
-
- pipe = DiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
- )
- pipe.load_lora_weights(tmpdir)
- pipe(prompt, num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
- )
-
- # resume and we should try to checkpoint at 10, where we'll have to remove
- # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
-
- resume_run_args = f"""
- examples/text_to_image/train_text_to_image_lora.py
- --pretrained_model_name_or_path {pretrained_model_name_or_path}
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --center_crop
- --random_flip
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 11
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=2
- --resume_from_checkpoint=checkpoint-8
- --checkpoints_total_limit=3
- --seed=0
- --num_validation_images=0
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- pipe = DiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
- )
- pipe.load_lora_weights(tmpdir)
- pipe(prompt, num_inference_steps=2)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
- )
-
- def test_unconditional_checkpointing_checkpoints_total_limit(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- initial_run_args = f"""
- examples/unconditional_image_generation/train_unconditional.py
- --dataset_name hf-internal-testing/dummy_image_class_data
- --model_config_name_or_path diffusers/ddpm_dummy
- --resolution 64
- --output_dir {tmpdir}
- --train_batch_size 1
- --num_epochs 1
- --gradient_accumulation_steps 1
- --ddpm_num_inference_steps 2
- --learning_rate 1e-3
- --lr_warmup_steps 5
- --checkpointing_steps=2
- --checkpoints_total_limit=2
- """.split()
-
- run_command(self._launch_args + initial_run_args)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- # checkpoint-2 should have been deleted
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- initial_run_args = f"""
- examples/unconditional_image_generation/train_unconditional.py
- --dataset_name hf-internal-testing/dummy_image_class_data
- --model_config_name_or_path diffusers/ddpm_dummy
- --resolution 64
- --output_dir {tmpdir}
- --train_batch_size 1
- --num_epochs 1
- --gradient_accumulation_steps 1
- --ddpm_num_inference_steps 2
- --learning_rate 1e-3
- --lr_warmup_steps 5
- --checkpointing_steps=1
- """.split()
-
- run_command(self._launch_args + initial_run_args)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"},
- )
-
- resume_run_args = f"""
- examples/unconditional_image_generation/train_unconditional.py
- --dataset_name hf-internal-testing/dummy_image_class_data
- --model_config_name_or_path diffusers/ddpm_dummy
- --resolution 64
- --output_dir {tmpdir}
- --train_batch_size 1
- --num_epochs 2
- --gradient_accumulation_steps 1
- --ddpm_num_inference_steps 2
- --learning_rate 1e-3
- --lr_warmup_steps 5
- --resume_from_checkpoint=checkpoint-6
- --checkpointing_steps=2
- --checkpoints_total_limit=3
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-8", "checkpoint-10", "checkpoint-12"},
- )
-
- def test_textual_inversion_checkpointing(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/textual_inversion/textual_inversion.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
- --train_data_dir docs/source/en/imgs
- --learnable_property object
- --placeholder_token
- --initializer_token a
- --validation_prompt
- --validation_steps 1
- --save_steps 1
- --num_vectors 2
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 3
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=1
- --checkpoints_total_limit=2
- """.split()
-
- run_command(self._launch_args + test_args)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-2", "checkpoint-3"},
- )
-
- def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/textual_inversion/textual_inversion.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
- --train_data_dir docs/source/en/imgs
- --learnable_property object
- --placeholder_token
- --initializer_token a
- --validation_prompt
- --validation_steps 1
- --save_steps 1
- --num_vectors 2
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 3
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=1
- """.split()
-
- run_command(self._launch_args + test_args)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-1", "checkpoint-2", "checkpoint-3"},
- )
-
- resume_run_args = f"""
- examples/textual_inversion/textual_inversion.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
- --train_data_dir docs/source/en/imgs
- --learnable_property object
- --placeholder_token
- --initializer_token a
- --validation_prompt
- --validation_steps 1
- --save_steps 1
- --num_vectors 2
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 4
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --checkpointing_steps=1
- --resume_from_checkpoint=checkpoint-3
- --checkpoints_total_limit=2
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-3", "checkpoint-4"},
- )
-
- def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/instruct_pix2pix/train_instruct_pix2pix.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --dataset_name=hf-internal-testing/instructpix2pix-10-samples
- --resolution=64
- --random_flip
- --train_batch_size=1
- --max_train_steps=7
- --checkpointing_steps=2
- --checkpoints_total_limit=2
- --output_dir {tmpdir}
- --seed=0
- """.split()
-
- run_command(self._launch_args + test_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/instruct_pix2pix/train_instruct_pix2pix.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --dataset_name=hf-internal-testing/instructpix2pix-10-samples
- --resolution=64
- --random_flip
- --train_batch_size=1
- --max_train_steps=9
- --checkpointing_steps=2
- --output_dir {tmpdir}
- --seed=0
- """.split()
-
- run_command(self._launch_args + test_args)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
- )
-
- resume_run_args = f"""
- examples/instruct_pix2pix/train_instruct_pix2pix.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --dataset_name=hf-internal-testing/instructpix2pix-10-samples
- --resolution=64
- --random_flip
- --train_batch_size=1
- --max_train_steps=11
- --checkpointing_steps=2
- --output_dir {tmpdir}
- --seed=0
- --resume_from_checkpoint=checkpoint-8
- --checkpoints_total_limit=3
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- # check checkpoint directories exist
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
- )
-
- def test_dreambooth_checkpointing_checkpoints_total_limit(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir=docs/source/en/imgs
- --output_dir={tmpdir}
- --instance_prompt=prompt
- --resolution=64
- --train_batch_size=1
- --gradient_accumulation_steps=1
- --max_train_steps=6
- --checkpoints_total_limit=2
- --checkpointing_steps=2
- """.split()
-
- run_command(self._launch_args + test_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir=docs/source/en/imgs
- --output_dir={tmpdir}
- --instance_prompt=prompt
- --resolution=64
- --train_batch_size=1
- --gradient_accumulation_steps=1
- --max_train_steps=9
- --checkpointing_steps=2
- """.split()
-
- run_command(self._launch_args + test_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
- )
-
- resume_run_args = f"""
- examples/dreambooth/train_dreambooth.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir=docs/source/en/imgs
- --output_dir={tmpdir}
- --instance_prompt=prompt
- --resolution=64
- --train_batch_size=1
- --gradient_accumulation_steps=1
- --max_train_steps=11
- --checkpointing_steps=2
- --resume_from_checkpoint=checkpoint-8
- --checkpoints_total_limit=3
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
- )
-
- def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth_lora.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir=docs/source/en/imgs
- --output_dir={tmpdir}
- --instance_prompt=prompt
- --resolution=64
- --train_batch_size=1
- --gradient_accumulation_steps=1
- --max_train_steps=6
- --checkpoints_total_limit=2
- --checkpointing_steps=2
- """.split()
-
- run_command(self._launch_args + test_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/dreambooth/train_dreambooth_lora.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir=docs/source/en/imgs
- --output_dir={tmpdir}
- --instance_prompt=prompt
- --resolution=64
- --train_batch_size=1
- --gradient_accumulation_steps=1
- --max_train_steps=9
- --checkpointing_steps=2
- """.split()
-
- run_command(self._launch_args + test_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
- )
-
- resume_run_args = f"""
- examples/dreambooth/train_dreambooth_lora.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir=docs/source/en/imgs
- --output_dir={tmpdir}
- --instance_prompt=prompt
- --resolution=64
- --train_batch_size=1
- --gradient_accumulation_steps=1
- --max_train_steps=11
- --checkpointing_steps=2
- --resume_from_checkpoint=checkpoint-8
- --checkpoints_total_limit=3
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
- )
-
- def test_controlnet_checkpointing_checkpoints_total_limit(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/controlnet/train_controlnet.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --dataset_name=hf-internal-testing/fill10
- --output_dir={tmpdir}
- --resolution=64
- --train_batch_size=1
- --gradient_accumulation_steps=1
- --max_train_steps=6
- --checkpoints_total_limit=2
- --checkpointing_steps=2
- --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
- """.split()
-
- run_command(self._launch_args + test_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/controlnet/train_controlnet.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --dataset_name=hf-internal-testing/fill10
- --output_dir={tmpdir}
- --resolution=64
- --train_batch_size=1
- --gradient_accumulation_steps=1
- --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
- --max_train_steps=9
- --checkpointing_steps=2
- """.split()
-
- run_command(self._launch_args + test_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
- )
-
- resume_run_args = f"""
- examples/controlnet/train_controlnet.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --dataset_name=hf-internal-testing/fill10
- --output_dir={tmpdir}
- --resolution=64
- --train_batch_size=1
- --gradient_accumulation_steps=1
- --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
- --max_train_steps=11
- --checkpointing_steps=2
- --resume_from_checkpoint=checkpoint-8
- --checkpoints_total_limit=3
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-8", "checkpoint-10", "checkpoint-12"},
- )
-
- def test_controlnet_sdxl(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/controlnet/train_controlnet_sdxl.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
- --dataset_name=hf-internal-testing/fill10
- --output_dir={tmpdir}
- --resolution=64
- --train_batch_size=1
- --gradient_accumulation_steps=1
- --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
- --max_train_steps=9
- --checkpointing_steps=2
- """.split()
-
- run_command(self._launch_args + test_args)
-
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
-
- def test_t2i_adapter_sdxl(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/t2i_adapter/train_t2i_adapter_sdxl.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
- --adapter_model_name_or_path=hf-internal-testing/tiny-adapter
- --dataset_name=hf-internal-testing/fill10
- --output_dir={tmpdir}
- --resolution=64
- --train_batch_size=1
- --gradient_accumulation_steps=1
- --max_train_steps=9
- --checkpointing_steps=2
- """.split()
-
- run_command(self._launch_args + test_args)
-
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
-
- def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/custom_diffusion/train_custom_diffusion.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir=docs/source/en/imgs
- --output_dir={tmpdir}
- --instance_prompt=
- --resolution=64
- --train_batch_size=1
- --modifier_token=
- --dataloader_num_workers=0
- --max_train_steps=6
- --checkpoints_total_limit=2
- --checkpointing_steps=2
- --no_safe_serialization
- """.split()
-
- run_command(self._launch_args + test_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-4", "checkpoint-6"},
- )
-
- def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/custom_diffusion/train_custom_diffusion.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir=docs/source/en/imgs
- --output_dir={tmpdir}
- --instance_prompt=
- --resolution=64
- --train_batch_size=1
- --modifier_token=
- --dataloader_num_workers=0
- --max_train_steps=9
- --checkpointing_steps=2
- --no_safe_serialization
- """.split()
-
- run_command(self._launch_args + test_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
- )
-
- resume_run_args = f"""
- examples/custom_diffusion/train_custom_diffusion.py
- --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
- --instance_data_dir=docs/source/en/imgs
- --output_dir={tmpdir}
- --instance_prompt=
- --resolution=64
- --train_batch_size=1
- --modifier_token=
- --dataloader_num_workers=0
- --max_train_steps=11
- --checkpointing_steps=2
- --resume_from_checkpoint=checkpoint-8
- --checkpoints_total_limit=3
- --no_safe_serialization
- """.split()
-
- run_command(self._launch_args + resume_run_args)
-
- self.assertEqual(
- {x for x in os.listdir(tmpdir) if "checkpoint" in x},
- {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
- )
-
- def test_text_to_image_lora_sdxl(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/text_to_image/train_text_to_image_lora_sdxl.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
- # make sure the state_dict has the correct naming in the parameters.
- lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
- is_lora = all("lora" in k for k in lora_state_dict.keys())
- self.assertTrue(is_lora)
-
- def test_text_to_image_lora_sdxl_with_text_encoder(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- test_args = f"""
- examples/text_to_image/train_text_to_image_lora_sdxl.py
- --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
- --dataset_name hf-internal-testing/dummy_image_text_data
- --resolution 64
- --train_batch_size 1
- --gradient_accumulation_steps 1
- --max_train_steps 2
- --learning_rate 5.0e-04
- --scale_lr
- --lr_scheduler constant
- --lr_warmup_steps 0
- --output_dir {tmpdir}
- --train_text_encoder
- """.split()
-
- run_command(self._launch_args + test_args)
- # save_pretrained smoke test
- self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
- # make sure the state_dict has the correct naming in the parameters.
- lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
- is_lora = all("lora" in k for k in lora_state_dict.keys())
- self.assertTrue(is_lora)
-
- # when not training the text encoder, all the parameters in the state dict should start
- # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
- keys = lora_state_dict.keys()
- starts_with_unet = all(
- k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
- )
- self.assertTrue(starts_with_unet)
diff --git a/examples/test_examples_utils.py b/examples/test_examples_utils.py
new file mode 100644
index 000000000000..3a697c65c4c7
--- /dev/null
+++ b/examples/test_examples_utils.py
@@ -0,0 +1,61 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import subprocess
+import tempfile
+import unittest
+from typing import List
+
+from accelerate.utils import write_basic_config
+
+
+# These utils relate to ensuring the right error message is received when running scripts
+class SubprocessCallException(Exception):
+ pass
+
+
+def run_command(command: List[str], return_stdout=False):
+ """
+ Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
+ if an error occurred while running `command`
+ """
+ try:
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+ if return_stdout:
+ if hasattr(output, "decode"):
+ output = output.decode("utf-8")
+ return output
+ except subprocess.CalledProcessError as e:
+ raise SubprocessCallException(
+ f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+ ) from e
+
+
+class ExamplesTestsAccelerate(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls._tmpdir = tempfile.mkdtemp()
+ cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
+
+ write_basic_config(save_location=cls.configPath)
+ cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
+
+ @classmethod
+ def tearDownClass(cls):
+ super().tearDownClass()
+ shutil.rmtree(cls._tmpdir)
diff --git a/examples/text_to_image/test_text_to_image.py b/examples/text_to_image/test_text_to_image.py
new file mode 100644
index 000000000000..308a038b5533
--- /dev/null
+++ b/examples/text_to_image/test_text_to_image.py
@@ -0,0 +1,373 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from diffusers import DiffusionPipeline, UNet2DConditionModel # noqa: E402
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextToImage(ExamplesTestsAccelerate):
+ def test_text_to_image(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_text_to_image_checkpointing(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 5, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 5
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 7 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {
+ # no checkpoint-2 -> check old checkpoints do not exist
+ # check new checkpoints exist
+ "checkpoint-4",
+ "checkpoint-6",
+ },
+ )
+
+ def test_text_to_image_checkpointing_use_ema(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 5, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 5
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --use_ema
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 7 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --use_ema
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {
+ # no checkpoint-2 -> check old checkpoints do not exist
+ # check new checkpoints exist
+ "checkpoint-4",
+ "checkpoint-6",
+ },
+ )
+
+ def test_text_to_image_checkpointing_checkpoints_total_limit(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 9, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4, 6, 8
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 9
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ # resume and we should try to checkpoint at 10, where we'll have to remove
+ # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 11
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+
+class TextToImageSDXL(ExamplesTestsAccelerate):
+ def test_text_to_image_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
diff --git a/examples/text_to_image/test_text_to_image_lora.py b/examples/text_to_image/test_text_to_image_lora.py
new file mode 100644
index 000000000000..83cbb78b2dc6
--- /dev/null
+++ b/examples/text_to_image/test_text_to_image_lora.py
@@ -0,0 +1,308 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+from diffusers import DiffusionPipeline # noqa: E402
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextToImageLoRA(ExamplesTestsAccelerate):
+ def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+ prompt = "a prompt"
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 9, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4, 6, 8
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 9
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ # resume and we should try to checkpoint at 10, where we'll have to remove
+ # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 11
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+
+class TextToImageLoRASDXL(ExamplesTestsAccelerate):
+ def test_text_to_image_lora_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ def test_text_to_image_lora_sdxl_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
+ keys = lora_state_dict.keys()
+ starts_with_unet = all(
+ k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+ )
+ self.assertTrue(starts_with_unet)
+
+ def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+ prompt = "a prompt"
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --train_text_encoder
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
diff --git a/examples/textual_inversion/test_textual_inversion.py b/examples/textual_inversion/test_textual_inversion.py
new file mode 100644
index 000000000000..a5d7bcb65dd3
--- /dev/null
+++ b/examples/textual_inversion/test_textual_inversion.py
@@ -0,0 +1,160 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextualInversion(ExamplesTestsAccelerate):
+ def test_textual_inversion(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --validation_prompt
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
+
+ def test_textual_inversion_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --validation_prompt
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 3
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-3"},
+ )
+
+ def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --validation_prompt
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 3
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-1", "checkpoint-2", "checkpoint-3"},
+ )
+
+ resume_run_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token
+ --initializer_token a
+ --validation_prompt
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --resume_from_checkpoint=checkpoint-3
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-3", "checkpoint-4"},
+ )
diff --git a/examples/unconditional_image_generation/test_unconditional.py b/examples/unconditional_image_generation/test_unconditional.py
new file mode 100644
index 000000000000..b7e19abe9f6e
--- /dev/null
+++ b/examples/unconditional_image_generation/test_unconditional.py
@@ -0,0 +1,130 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class Unconditional(ExamplesTestsAccelerate):
+ def test_train_unconditional(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 2
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ """.split()
+
+ run_command(self._launch_args + test_args, return_stdout=True)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_unconditional_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ initial_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ initial_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --checkpointing_steps=1
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"},
+ )
+
+ resume_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 2
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --resume_from_checkpoint=checkpoint-6
+ --checkpointing_steps=2
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-8", "checkpoint-10", "checkpoint-12"},
+ )
diff --git a/examples/wuerstchen/text_to_image/requirements.txt b/examples/wuerstchen/text_to_image/requirements.txt
index a58ad09eca55..f734c1659d32 100644
--- a/examples/wuerstchen/text_to_image/requirements.txt
+++ b/examples/wuerstchen/text_to_image/requirements.txt
@@ -5,3 +5,4 @@ wandb
huggingface-cli
bitsandbytes
deepspeed
+peft>=0.6.0
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index 33de3d3bf777..bca018d8df23 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -31,14 +31,14 @@
from datasets import load_dataset
from huggingface_hub import create_repo, hf_hub_download, upload_folder
from modeling_efficient_net_encoder import EfficientNetEncoder
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPTextModel, PreTrainedTokenizerFast
from transformers.utils import ContextManagers
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
@@ -139,17 +139,17 @@ def save_model_card(
f.write(yaml + model_card)
-def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator, weight_dtype, epoch):
+def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
logger.info("Running validation... ")
pipeline = AutoPipelineForText2Image.from_pretrained(
args.pretrained_decoder_model_name_or_path,
+ prior=accelerator.unwrap_model(prior),
prior_text_encoder=accelerator.unwrap_model(text_encoder),
prior_tokenizer=tokenizer,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
- pipeline.prior_prior.set_attn_processor(attn_processors)
pipeline.set_progress_bar_config(disable=True)
if args.seed is None:
@@ -159,7 +159,7 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,
images = []
for i in range(len(args.validation_prompts)):
- with torch.autocast("cuda"):
+ with torch.cuda.amp.autocast():
image = pipeline(
args.validation_prompts[i],
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
@@ -167,7 +167,6 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,
height=args.resolution,
width=args.resolution,
).images[0]
-
images.append(image)
for tracker in accelerator.trackers:
@@ -527,11 +526,50 @@ def deepspeed_zero_init_disabled_context_manager():
prior.to(accelerator.device, dtype=weight_dtype)
# lora attn processor
- lora_attn_procs = {}
- for name in prior.attn_processors.keys():
- lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=prior.config["c"], rank=args.rank)
- prior.set_attn_processor(lora_attn_procs)
- lora_layers = AttnProcsLayers(prior.attn_processors)
+ prior_lora_config = LoraConfig(
+ r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
+ )
+ prior.add_adapter(prior_lora_config)
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ prior_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(accelerator.unwrap_model(prior))):
+ prior_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ WuerstchenPriorPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=prior_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ prior_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(prior))):
+ prior_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = WuerstchenPriorPipeline.lora_state_dict(input_dir)
+ WuerstchenPriorPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=prior_)
+ WuerstchenPriorPipeline.load_lora_into_text_encoder(
+ lora_state_dict,
+ network_alphas=network_alphas,
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
@@ -547,8 +585,9 @@ def deepspeed_zero_init_disabled_context_manager():
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
+ params_to_optimize = list(filter(lambda p: p.requires_grad, prior.parameters()))
optimizer = optimizer_cls(
- lora_layers.parameters(),
+ params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -674,8 +713,8 @@ def collate_fn(examples):
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
- lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- lora_layers, optimizer, train_dataloader, lr_scheduler
+ prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ prior, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -782,7 +821,7 @@ def collate_fn(examples):
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
- accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -828,17 +867,19 @@ def collate_fn(examples):
if accelerator.is_main_process:
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
- log_validation(
- text_encoder, tokenizer, prior.attn_processors, args, accelerator, weight_dtype, global_step
- )
+ log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
+ prior = accelerator.unwrap_model(prior)
prior = prior.to(torch.float32)
+
+ prior_lora_state_dict = get_peft_model_state_dict(prior)
+
WuerstchenPriorPipeline.save_lora_weights(
- os.path.join(args.output_dir, "prior_lora"),
- unet_lora_layers=lora_layers,
+ save_directory=args.output_dir,
+ unet_lora_layers=prior_lora_state_dict,
)
# Run a final round of inference.
@@ -849,11 +890,12 @@ def collate_fn(examples):
args.pretrained_decoder_model_name_or_path,
prior_text_encoder=accelerator.unwrap_model(text_encoder),
prior_tokenizer=tokenizer,
+ torch_dtype=weight_dtype,
)
- pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
- # load lora weights
- pipeline.prior_pipe.load_lora_weights(os.path.join(args.output_dir, "prior_lora"))
+ pipeline = pipeline.to(accelerator.device)
+ # load lora weights
+ pipeline.prior_pipe.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
pipeline.set_progress_bar_config(disable=True)
if args.seed is None:
@@ -862,7 +904,7 @@ def collate_fn(examples):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
for i in range(len(args.validation_prompts)):
- with torch.autocast("cuda"):
+ with torch.cuda.amp.autocast():
image = pipeline(
args.validation_prompts[i],
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py
new file mode 100644
index 000000000000..3243ce294b26
--- /dev/null
+++ b/scripts/convert_svd_to_diffusers.py
@@ -0,0 +1,730 @@
+from diffusers.utils import is_accelerate_available, logging
+
+
+if is_accelerate_available():
+ pass
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
+ """
+ Creates a config for the diffusers based on the config of the LDM model.
+ """
+ if controlnet:
+ unet_params = original_config.model.params.control_stage_config.params
+ else:
+ if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
+ unet_params = original_config.model.params.unet_config.params
+ else:
+ unet_params = original_config.model.params.network_config.params
+
+ vae_params = original_config.model.params.first_stage_config.params.encoder_config.params
+
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
+
+ down_block_types = []
+ resolution = 1
+ for i in range(len(block_out_channels)):
+ block_type = (
+ "CrossAttnDownBlockSpatioTemporal"
+ if resolution in unet_params.attention_resolutions
+ else "DownBlockSpatioTemporal"
+ )
+ down_block_types.append(block_type)
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ block_type = (
+ "CrossAttnUpBlockSpatioTemporal"
+ if resolution in unet_params.attention_resolutions
+ else "UpBlockSpatioTemporal"
+ )
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ if unet_params.transformer_depth is not None:
+ transformer_layers_per_block = (
+ unet_params.transformer_depth
+ if isinstance(unet_params.transformer_depth, int)
+ else list(unet_params.transformer_depth)
+ )
+ else:
+ transformer_layers_per_block = 1
+
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
+
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+ use_linear_projection = (
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+ )
+ if use_linear_projection:
+ # stable diffusion 2-base-512 and 2-768
+ if head_dim is None:
+ head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
+ head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
+
+ class_embed_type = None
+ addition_embed_type = None
+ addition_time_embed_dim = None
+ projection_class_embeddings_input_dim = None
+ context_dim = None
+
+ if unet_params.context_dim is not None:
+ context_dim = (
+ unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
+ )
+
+ if "num_classes" in unet_params:
+ if unet_params.num_classes == "sequential":
+ addition_time_embed_dim = 256
+ assert "adm_in_channels" in unet_params
+ projection_class_embeddings_input_dim = unet_params.adm_in_channels
+
+ config = {
+ "sample_size": image_size // vae_scale_factor,
+ "in_channels": unet_params.in_channels,
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": unet_params.num_res_blocks,
+ "cross_attention_dim": context_dim,
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "class_embed_type": class_embed_type,
+ "addition_embed_type": addition_embed_type,
+ "addition_time_embed_dim": addition_time_embed_dim,
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+ "transformer_layers_per_block": transformer_layers_per_block,
+ }
+
+ if "disable_self_attentions" in unet_params:
+ config["only_cross_attention"] = unet_params.disable_self_attentions
+
+ if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
+ config["num_class_embeds"] = unet_params.num_classes
+
+ if controlnet:
+ config["conditioning_channels"] = unet_params.hint_channels
+ else:
+ config["out_channels"] = unet_params.out_channels
+ config["up_block_types"] = tuple(up_block_types)
+
+ return config
+
+
+def assign_to_checkpoint(
+ paths,
+ checkpoint,
+ old_checkpoint,
+ attention_paths_to_split=None,
+ additional_replacements=None,
+ config=None,
+ mid_block_suffix="",
+):
+ """
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
+ attention layers, and takes into account additional replacements that may arise.
+
+ Assigns the weights to the new checkpoint.
+ """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ # Splits the attention layers into three variables.
+ if attention_paths_to_split is not None:
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+ if mid_block_suffix is not None:
+ mid_block_suffix = f".{mid_block_suffix}"
+ else:
+ mid_block_suffix = ""
+
+ for path in paths:
+ new_path = path["new"]
+
+ # These have already been assigned
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ # Global renaming happens here
+ new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}")
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+ new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ if new_path == "mid_block.resnets.0.spatial_res_block.norm1.weight":
+ print("yeyy")
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+ shape = old_checkpoint[path["old"]].shape
+ if is_attn_weight and len(shape) == 3:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+ elif is_attn_weight and len(shape) == 4:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
+
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+ new_item = new_item.replace("time_stack", "temporal_transformer_blocks")
+
+ new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias")
+ new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight")
+ new_item = new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias")
+ new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight")
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = new_item.replace("time_stack.", "")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def convert_ldm_unet_checkpoint(
+ checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
+):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+
+ if skip_extract_state_dict:
+ unet_state_dict = checkpoint
+ else:
+ # extract state_dict for UNet
+ unet_state_dict = {}
+ keys = list(checkpoint.keys())
+
+ unet_key = "model.diffusion_model."
+
+ # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+ logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
+ logger.warning(
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+ )
+ for key in keys:
+ if key.startswith("model.diffusion_model"):
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+ else:
+ if sum(k.startswith("model_ema") for k in keys) > 100:
+ logger.warning(
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+ )
+
+ for key in keys:
+ if key.startswith(unet_key):
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ if config["class_embed_type"] is None:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ # if config["addition_embed_type"] == "text_time":
+ new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ spatial_resnets = [
+ key
+ for key in input_blocks[i]
+ if f"input_blocks.{i}.0" in key
+ and (
+ f"input_blocks.{i}.0.op" not in key
+ and f"input_blocks.{i}.0.time_stack" not in key
+ and f"input_blocks.{i}.0.time_mixer" not in key
+ )
+ ]
+ temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key]
+ # import ipdb; ipdb.set_trace()
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(spatial_resnets)
+ meta_path = {
+ "old": f"input_blocks.{i}.0",
+ "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ paths = renew_resnet_paths(temporal_resnets)
+ meta_path = {
+ "old": f"input_blocks.{i}.0",
+ "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ # TODO resnet time_mixer.mix_factor
+ if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
+ new_checkpoint[
+ f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
+ ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ # import ipdb; ipdb.set_trace()
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key]
+ resnet_0_paths = renew_resnet_paths(resnet_0_spatial)
+ # import ipdb; ipdb.set_trace()
+ assign_to_checkpoint(
+ resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
+ )
+
+ resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key]
+ resnet_0_paths = renew_resnet_paths(resnet_0_temporal)
+ assign_to_checkpoint(
+ resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
+ )
+
+ resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key]
+ resnet_1_paths = renew_resnet_paths(resnet_1_spatial)
+ assign_to_checkpoint(
+ resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
+ )
+
+ resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key]
+ resnet_1_paths = renew_resnet_paths(resnet_1_temporal)
+ assign_to_checkpoint(
+ resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
+ )
+
+ new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
+ "middle_block.0.time_mixer.mix_factor"
+ ]
+ new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[
+ "middle_block.2.time_mixer.mix_factor"
+ ]
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ spatial_resnets = [
+ key
+ for key in output_blocks[i]
+ if f"output_blocks.{i}.0" in key
+ and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key)
+ ]
+ # import ipdb; ipdb.set_trace()
+
+ temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key]
+
+ paths = renew_resnet_paths(spatial_resnets)
+ meta_path = {
+ "old": f"output_blocks.{i}.0",
+ "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ paths = renew_resnet_paths(temporal_resnets)
+ meta_path = {
+ "old": f"output_blocks.{i}.0",
+ "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
+ new_checkpoint[
+ f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
+ ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key]
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ # import ipdb; ipdb.set_trace()
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ spatial_layers = [
+ layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer
+ ]
+ resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1)
+ # import ipdb; ipdb.set_trace()
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(
+ ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]]
+ )
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ temporal_layers = [
+ layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in key
+ ]
+ resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1)
+ # import ipdb; ipdb.set_trace()
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(
+ ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]]
+ )
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
+ f"output_blocks.{str(i)}.0.time_mixer.mix_factor"
+ ]
+
+ return new_checkpoint
+
+
+def conv_attn_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in attn_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+ elif "proj_attn.weight" in key:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # Temporal resnet
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = new_item.replace("time_stack.", "temporal_res_block.")
+
+ # Spatial resnet
+ new_item = new_item.replace("conv1", "spatial_res_block.conv1")
+ new_item = new_item.replace("norm1", "spatial_res_block.norm1")
+
+ new_item = new_item.replace("conv2", "spatial_res_block.conv2")
+ new_item = new_item.replace("norm2", "spatial_res_block.norm2")
+
+ new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut")
+
+ new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+ new_item = new_item.replace("q.weight", "to_q.weight")
+ new_item = new_item.replace("q.bias", "to_q.bias")
+
+ new_item = new_item.replace("k.weight", "to_k.weight")
+ new_item = new_item.replace("k.bias", "to_k.bias")
+
+ new_item = new_item.replace("v.weight", "to_v.weight")
+ new_item = new_item.replace("v.bias", "to_v.bias")
+
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config):
+ # extract state dict for VAE
+ vae_state_dict = {}
+ keys = list(checkpoint.keys())
+ vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
+ for key in keys:
+ if key.startswith(vae_key):
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+ new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"]
+ new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"]
+
+ # new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+ # new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+ # new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+ # new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+ return new_checkpoint
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 8a0dc2b923d3..574082c30362 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -76,6 +76,7 @@
[
"AsymmetricAutoencoderKL",
"AutoencoderKL",
+ "AutoencoderKLTemporalDecoder",
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
@@ -92,6 +93,7 @@
"UNet2DModel",
"UNet3DConditionModel",
"UNetMotionModel",
+ "UNetSpatioTemporalConditionModel",
"VQModel",
]
)
@@ -277,8 +279,10 @@
"StableDiffusionXLPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
+ "StableVideoDiffusionPipeline",
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
+ "TextToVideoZeroSDXLPipeline",
"UnCLIPImageVariationPipeline",
"UnCLIPPipeline",
"UniDiffuserModel",
@@ -446,6 +450,7 @@
from .models import (
AsymmetricAutoencoderKL,
AutoencoderKL,
+ AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
@@ -462,6 +467,7 @@
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
+ UNetSpatioTemporalConditionModel,
VQModel,
)
from .optimization import (
@@ -626,8 +632,10 @@
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
+ StableVideoDiffusionPipeline,
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
+ TextToVideoZeroSDXLPipeline,
UnCLIPImageVariationPipeline,
UnCLIPPipeline,
UniDiffuserModel,
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 06eb3af05ee2..3eb7569967a7 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-import re
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union
@@ -44,13 +43,13 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
+from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
if is_transformers_available():
- from transformers import CLIPTextModel, CLIPTextModelWithProjection
+ from transformers import PreTrainedModel
- # To be deprecated soon
- from ..models.lora import PatchedLoraProjection
+ from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
if is_accelerate_available():
from accelerate import init_empty_weights
@@ -67,37 +66,10 @@
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."
-def text_encoder_attn_modules(text_encoder):
- attn_modules = []
-
- if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
- for i, layer in enumerate(text_encoder.text_model.encoder.layers):
- name = f"text_model.encoder.layers.{i}.self_attn"
- mod = layer.self_attn
- attn_modules.append((name, mod))
- else:
- raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
-
- return attn_modules
-
-
-def text_encoder_mlp_modules(text_encoder):
- mlp_modules = []
-
- if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
- for i, layer in enumerate(text_encoder.text_model.encoder.layers):
- mlp_mod = layer.mlp
- name = f"text_model.encoder.layers.{i}.mlp"
- mlp_modules.append((name, mlp_mod))
- else:
- raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
-
- return mlp_modules
-
-
class LoraLoaderMixin:
r"""
- Load LoRA layers into [`UNet2DConditionModel`] and [`~transformers.CLIPTextModel`].
+ Load LoRA layers into [`UNet2DConditionModel`] and
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
"""
text_encoder_name = TEXT_ENCODER_NAME
@@ -123,28 +95,12 @@ def load_lora_weights(
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- A string (model id of a pretrained model hosted on the Hub), a path to a directory containing the model
- weights, or a [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
- Name for referencing the loaded adapter model. If not specified, it will use `default_{i}` where `i` is
- the total number of adapters being loaded. Must have PEFT installed to use.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
- "cuda"
- )
- pipeline.load_lora_weights(
- "Yntec/pineappleAnimeMix", weight_name="pineappleAnimeMix_pineapple10.1.safetensors", adapter_name="anime"
- )
- ```
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
"""
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
@@ -182,7 +138,15 @@ def lora_state_dict(
**kwargs,
):
r"""
- Return state dict and network alphas of the LoRA weights.
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -190,7 +154,8 @@ def lora_state_dict(
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
@@ -226,6 +191,7 @@ def lora_state_dict(
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
+
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
@@ -322,8 +288,8 @@ def lora_state_dict(
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
- state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
- state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
+ state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
+ state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
return state_dict, network_alphas
@@ -363,109 +329,6 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext
weight_name = targeted_files[0]
return weight_name
- @classmethod
- def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
- # 1. get all state_dict_keys
- all_keys = list(state_dict.keys())
- sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
-
- # 2. check if needs remapping, if not return original dict
- is_in_sgm_format = False
- for key in all_keys:
- if any(p in key for p in sgm_patterns):
- is_in_sgm_format = True
- break
-
- if not is_in_sgm_format:
- return state_dict
-
- # 3. Else remap from SGM patterns
- new_state_dict = {}
- inner_block_map = ["resnets", "attentions", "upsamplers"]
-
- # Retrieves # of down, mid and up blocks
- input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
-
- for layer in all_keys:
- if "text" in layer:
- new_state_dict[layer] = state_dict.pop(layer)
- else:
- layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
- if sgm_patterns[0] in layer:
- input_block_ids.add(layer_id)
- elif sgm_patterns[1] in layer:
- middle_block_ids.add(layer_id)
- elif sgm_patterns[2] in layer:
- output_block_ids.add(layer_id)
- else:
- raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
-
- input_blocks = {
- layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
- for layer_id in input_block_ids
- }
- middle_blocks = {
- layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
- for layer_id in middle_block_ids
- }
- output_blocks = {
- layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
- for layer_id in output_block_ids
- }
-
- # Rename keys accordingly
- for i in input_block_ids:
- block_id = (i - 1) // (unet_config.layers_per_block + 1)
- layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
-
- for key in input_blocks[i]:
- inner_block_id = int(key.split(delimiter)[block_slice_pos])
- inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
- inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
- new_key = delimiter.join(
- key.split(delimiter)[: block_slice_pos - 1]
- + [str(block_id), inner_block_key, inner_layers_in_block]
- + key.split(delimiter)[block_slice_pos + 1 :]
- )
- new_state_dict[new_key] = state_dict.pop(key)
-
- for i in middle_block_ids:
- key_part = None
- if i == 0:
- key_part = [inner_block_map[0], "0"]
- elif i == 1:
- key_part = [inner_block_map[1], "0"]
- elif i == 2:
- key_part = [inner_block_map[0], "1"]
- else:
- raise ValueError(f"Invalid middle block id {i}.")
-
- for key in middle_blocks[i]:
- new_key = delimiter.join(
- key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
- )
- new_state_dict[new_key] = state_dict.pop(key)
-
- for i in output_block_ids:
- block_id = i // (unet_config.layers_per_block + 1)
- layer_in_block_id = i % (unet_config.layers_per_block + 1)
-
- for key in output_blocks[i]:
- inner_block_id = int(key.split(delimiter)[block_slice_pos])
- inner_block_key = inner_block_map[inner_block_id]
- inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
- new_key = delimiter.join(
- key.split(delimiter)[: block_slice_pos - 1]
- + [str(block_id), inner_block_key, inner_layers_in_block]
- + key.split(delimiter)[block_slice_pos + 1 :]
- )
- new_state_dict[new_key] = state_dict.pop(key)
-
- if len(state_dict) > 0:
- raise ValueError("At this point all state dict entries have to be converted.")
-
- return new_state_dict
-
@classmethod
def _optionally_disable_offloading(cls, _pipeline):
"""
@@ -502,27 +365,25 @@ def load_lora_into_unet(
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
"""
- Load LoRA layers specified in `state_dict` into `unet`.
+ This will load the LoRA layers specified in `state_dict` into `unet`.
Parameters:
state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
- into the `unet` or prefixed with an additional `unet`, which can be used to distinguish between text
- encoder LoRA layers.
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
network_alphas (`Dict[str, float]`):
- See
- [`LoRALinearLayer`](https://github.com/huggingface/diffusers/blob/c697f524761abd2314c030221a3ad2f7791eab4e/src/diffusers/models/lora.py#L182)
- for more details.
+ See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Only load and not initialize the pretrained weights. This can speedup model loading and also tries to
- not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only
- supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to
- `True` will raise an error.
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
adapter_name (`str`, *optional*):
- Name for referencing the loaded adapter model. If not specified, it will use `default_{i}` where `i` is
- the total number of adapters being loaded.
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -616,27 +477,26 @@ def load_lora_into_text_encoder(
_pipeline=None,
):
"""
- Load LoRA layers specified in `state_dict` into `text_encoder`.
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
Parameters:
state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The key should be prefixed with an
- additional `text_encoder` to distinguish between UNet LoRA layers.
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
+ additional `text_encoder` to distinguish between unet lora layers.
network_alphas (`Dict[str, float]`):
- See
- [`LoRALinearLayer`](https://github.com/huggingface/diffusers/blob/c697f524761abd2314c030221a3ad2f7791eab4e/src/diffusers/models/lora.py#L182)
- for more details.
+ See `LoRALinearLayer` for more details.
text_encoder (`CLIPTextModel`):
The text encoder model to load the LoRA layers into.
prefix (`str`):
Expected prefix of the `text_encoder` in the `state_dict`.
lora_scale (`float`):
- Scale of `LoRALinearLayer`'s output before it is added with the output of the regular LoRA layer.
+ How much to scale the output of the lora linear layer before it is added with the output of the regular
+ lora layer.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Only load and not initialize the pretrained weights. This can speedup model loading and also tries to
- not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only
- supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to
- `True` will raise an error.
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
@@ -921,11 +781,11 @@ def save_lora_weights(
safe_serialization: bool = True,
):
r"""
- Save the UNet and text encoder LoRA parameters.
+ Save the LoRA parameters corresponding to the UNet and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to (will be created if it doesn't exist).
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `unet`.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
@@ -936,30 +796,11 @@ def save_lora_weights(
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
- The function to use to save the state dict. Useful during distributed training when you need to replace
- `torch.save` with another method. Can be configured with the environment variable
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or with `pickle`.
-
- Example:
-
- ```py
- from diffusers import StableDiffusionXLPipeline
- from peft.utils import get_peft_model_state_dict
- import torch
-
- pipeline = StableDiffusionXLPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora()
-
- # get and save unet state dict
- unet_state_dict = get_peft_model_state_dict(pipeline.unet, adapter_name="pixel")
- pipeline.save_lora_weights("fused-model", unet_lora_layers=unet_state_dict)
- pipeline.load_lora_weights("fused-model", weight_name="pytorch_lora_weights.safetensors")
- ```
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
# Create a flat dictionary.
state_dict = {}
@@ -1028,186 +869,16 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
- @classmethod
- def _convert_kohya_lora_to_diffusers(cls, state_dict):
- unet_state_dict = {}
- te_state_dict = {}
- te2_state_dict = {}
- network_alphas = {}
-
- # every down weight has a corresponding up weight and potentially an alpha weight
- lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
- for key in lora_keys:
- lora_name = key.split(".")[0]
- lora_name_up = lora_name + ".lora_up.weight"
- lora_name_alpha = lora_name + ".alpha"
-
- if lora_name.startswith("lora_unet_"):
- diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
-
- if "input.blocks" in diffusers_name:
- diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
- else:
- diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
-
- if "middle.block" in diffusers_name:
- diffusers_name = diffusers_name.replace("middle.block", "mid_block")
- else:
- diffusers_name = diffusers_name.replace("mid.block", "mid_block")
- if "output.blocks" in diffusers_name:
- diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
- else:
- diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
-
- diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
- diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
- diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
- diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
- diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
- diffusers_name = diffusers_name.replace("proj.in", "proj_in")
- diffusers_name = diffusers_name.replace("proj.out", "proj_out")
- diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
-
- # SDXL specificity.
- if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
- pattern = r"\.\d+(?=\D*$)"
- diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
- if ".in." in diffusers_name:
- diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
- if ".out." in diffusers_name:
- diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
- if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
- diffusers_name = diffusers_name.replace("op", "conv")
- if "skip" in diffusers_name:
- diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
-
- # LyCORIS specificity.
- if "time.emb.proj" in diffusers_name:
- diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
- if "conv.shortcut" in diffusers_name:
- diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
-
- # General coverage.
- if "transformer_blocks" in diffusers_name:
- if "attn1" in diffusers_name or "attn2" in diffusers_name:
- diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
- diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
- unet_state_dict[diffusers_name] = state_dict.pop(key)
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
- elif "ff" in diffusers_name:
- unet_state_dict[diffusers_name] = state_dict.pop(key)
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
- elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
- unet_state_dict[diffusers_name] = state_dict.pop(key)
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
- else:
- unet_state_dict[diffusers_name] = state_dict.pop(key)
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-
- elif lora_name.startswith("lora_te_"):
- diffusers_name = key.replace("lora_te_", "").replace("_", ".")
- diffusers_name = diffusers_name.replace("text.model", "text_model")
- diffusers_name = diffusers_name.replace("self.attn", "self_attn")
- diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
- diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
- diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
- diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
- if "self_attn" in diffusers_name:
- te_state_dict[diffusers_name] = state_dict.pop(key)
- te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
- elif "mlp" in diffusers_name:
- # Be aware that this is the new diffusers convention and the rest of the code might
- # not utilize it yet.
- diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
- te_state_dict[diffusers_name] = state_dict.pop(key)
- te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-
- # (sayakpaul): Duplicate code. Needs to be cleaned.
- elif lora_name.startswith("lora_te1_"):
- diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
- diffusers_name = diffusers_name.replace("text.model", "text_model")
- diffusers_name = diffusers_name.replace("self.attn", "self_attn")
- diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
- diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
- diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
- diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
- if "self_attn" in diffusers_name:
- te_state_dict[diffusers_name] = state_dict.pop(key)
- te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
- elif "mlp" in diffusers_name:
- # Be aware that this is the new diffusers convention and the rest of the code might
- # not utilize it yet.
- diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
- te_state_dict[diffusers_name] = state_dict.pop(key)
- te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-
- # (sayakpaul): Duplicate code. Needs to be cleaned.
- elif lora_name.startswith("lora_te2_"):
- diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
- diffusers_name = diffusers_name.replace("text.model", "text_model")
- diffusers_name = diffusers_name.replace("self.attn", "self_attn")
- diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
- diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
- diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
- diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
- if "self_attn" in diffusers_name:
- te2_state_dict[diffusers_name] = state_dict.pop(key)
- te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
- elif "mlp" in diffusers_name:
- # Be aware that this is the new diffusers convention and the rest of the code might
- # not utilize it yet.
- diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
- te2_state_dict[diffusers_name] = state_dict.pop(key)
- te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-
- # Rename the alphas so that they can be mapped appropriately.
- if lora_name_alpha in state_dict:
- alpha = state_dict.pop(lora_name_alpha).item()
- if lora_name_alpha.startswith("lora_unet_"):
- prefix = "unet."
- elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
- prefix = "text_encoder."
- else:
- prefix = "text_encoder_2."
- new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
- network_alphas.update({new_name: alpha})
-
- if len(state_dict) > 0:
- raise ValueError(
- f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}"
- )
-
- logger.info("Kohya-style checkpoint detected.")
- unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
- te_state_dict = {
- f"{cls.text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()
- }
- te2_state_dict = (
- {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
- if len(te2_state_dict) > 0
- else None
- )
- if te2_state_dict is not None:
- te_state_dict.update(te2_state_dict)
-
- new_state_dict = {**unet_state_dict, **te_state_dict}
- return new_state_dict, network_alphas
-
def unload_lora_weights(self):
"""
- Unload the LoRA parameters from a pipeline.
+ Unloads the LoRA parameters.
Examples:
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.unload_lora_weights()
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+ >>> pipeline.unload_lora_weights()
+ >>> ...
```
"""
if not USE_PEFT_BACKEND:
@@ -1236,7 +907,7 @@ def fuse_lora(
safe_fusing: bool = False,
):
r"""
- Fuse the LoRA parameters with the original parameters in their corresponding blocks.
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
@@ -1250,23 +921,9 @@ def fuse_lora(
Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
lora_scale (`float`, defaults to 1.0):
- Controls LoRA influence on the outputs.
+ Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for `NaN` values before fusing and if values are `NaN`, then don't fuse
- them.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
"""
if fuse_unet or fuse_text_encoder:
self.num_fused_loras += 1
@@ -1315,7 +972,8 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
r"""
- Unfuse the LoRA parameters from the original parameters in their corresponding blocks.
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).
@@ -1328,20 +986,6 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- pipeline.unfuse_lora()
- ```
"""
if unfuse_unet:
if not USE_PEFT_BACKEND:
@@ -1393,32 +1037,16 @@ def set_adapters_for_text_encoder(
text_encoder_weights: List[float] = None,
):
"""
- Set the currently active adapter for use in the text encoder.
+ Sets the adapter layers for the text encoder.
Args:
adapter_names (`List[str]` or `str`):
- The adapter to activate.
+ The names of the adapters to use.
text_encoder (`torch.nn.Module`, *optional*):
- The text encoder module to activate the adapter layers for. If `None`, it will try to get the
- `text_encoder` attribute.
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
text_encoder_weights (`List[float]`, *optional*):
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.load_lora_weights(
- "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
- )
- pipeline.set_adapters_for_text_encoder("pixel")
- ```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1444,27 +1072,14 @@ def process_weights(adapter_names, weights):
)
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
- def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821
+ def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
"""
- Disable the text encoder's LoRA layers.
+ Disables the LoRA layers for the text encoder.
Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
`text_encoder` attribute.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.disable_lora_for_text_encoder()
- ```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1474,27 +1089,14 @@ def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=False)
- def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821
+ def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
"""
- Enables the text encoder's LoRA layers.
+ Enables the LoRA layers for the text encoder.
Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
attribute.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.enable_lora_for_text_encoder()
- ```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1545,24 +1147,10 @@ def enable_lora(self):
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
- Delete an adapter's LoRA layers from the UNet and text encoder(s).
-
Args:
+ Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
adapter_names (`Union[List[str], str]`):
- The names (single string or list of strings) of the adapter to delete.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.delete_adapters("pixel")
- ```
+ The names of the adapter to delete. Can be a single string or a list of strings
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1582,7 +1170,7 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
def get_active_adapters(self) -> List[str]:
"""
- Get a list of currently active adapters.
+ Gets the list of the current active adapters.
Example:
@@ -1614,22 +1202,7 @@ def get_active_adapters(self) -> List[str]:
def get_list_adapters(self) -> Dict[str, List[str]]:
"""
- Get a list of all currently available adapters for each component in the pipeline.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- ).to("cuda")
- pipeline.load_lora_weights(
- "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
- )
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.get_list_adapters()
- ```
+ Gets the current list of all available adapters in the pipeline.
"""
if not USE_PEFT_BACKEND:
raise ValueError(
@@ -1651,27 +1224,14 @@ def get_list_adapters(self) -> Dict[str, List[str]]:
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
"""
- Move a LoRA to a target device. Useful for offloading a LoRA to the CPU in case you want to load multiple
- adapters and free some GPU memory.
+ Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
+ you want to load multiple adapters and free some GPU memory.
Args:
adapter_names (`List[str]`):
- List of adapters to send to device.
+ List of adapters to send device to.
device (`Union[torch.device, str, int]`):
- Device (can be a `torch.device`, `str` or `int`) to place adapters on.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.set_lora_device(["pixel"], device="cuda")
- ```
+ Device to send the adapters to. Can be either a torch device, a str or an integer.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1703,7 +1263,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
- """This class overrides [`LoraLoaderMixin`] with LoRA loading/saving code that's specific to SDXL."""
+ """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
# Overrride to properly handle the loading and unloading of the additional text encoder.
def load_lora_weights(
@@ -1728,26 +1288,12 @@ def load_lora_weights(
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- A string (model id of a pretrained model hosted on the Hub), a path to a directory containing the model
- weights, or a [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
- kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
- Name for referencing the loaded adapter model. If not specified, it will use `default_{i}` where `i` is
- the total number of adapters being loaded. Must have PEFT installed to use.
-
- Example:
-
- ```py
- from diffusers import StableDiffusionXLPipeline
- import torch
-
- pipeline = StableDiffusionXLPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- ```
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
new file mode 100644
index 000000000000..4a89fc20b56b
--- /dev/null
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -0,0 +1,284 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
+ # 1. get all state_dict_keys
+ all_keys = list(state_dict.keys())
+ sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
+
+ # 2. check if needs remapping, if not return original dict
+ is_in_sgm_format = False
+ for key in all_keys:
+ if any(p in key for p in sgm_patterns):
+ is_in_sgm_format = True
+ break
+
+ if not is_in_sgm_format:
+ return state_dict
+
+ # 3. Else remap from SGM patterns
+ new_state_dict = {}
+ inner_block_map = ["resnets", "attentions", "upsamplers"]
+
+ # Retrieves # of down, mid and up blocks
+ input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
+
+ for layer in all_keys:
+ if "text" in layer:
+ new_state_dict[layer] = state_dict.pop(layer)
+ else:
+ layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
+ if sgm_patterns[0] in layer:
+ input_block_ids.add(layer_id)
+ elif sgm_patterns[1] in layer:
+ middle_block_ids.add(layer_id)
+ elif sgm_patterns[2] in layer:
+ output_block_ids.add(layer_id)
+ else:
+ raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
+
+ input_blocks = {
+ layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
+ for layer_id in input_block_ids
+ }
+ middle_blocks = {
+ layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
+ for layer_id in middle_block_ids
+ }
+ output_blocks = {
+ layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
+ for layer_id in output_block_ids
+ }
+
+ # Rename keys accordingly
+ for i in input_block_ids:
+ block_id = (i - 1) // (unet_config.layers_per_block + 1)
+ layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
+
+ for key in input_blocks[i]:
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
+ inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
+ inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
+ new_key = delimiter.join(
+ key.split(delimiter)[: block_slice_pos - 1]
+ + [str(block_id), inner_block_key, inner_layers_in_block]
+ + key.split(delimiter)[block_slice_pos + 1 :]
+ )
+ new_state_dict[new_key] = state_dict.pop(key)
+
+ for i in middle_block_ids:
+ key_part = None
+ if i == 0:
+ key_part = [inner_block_map[0], "0"]
+ elif i == 1:
+ key_part = [inner_block_map[1], "0"]
+ elif i == 2:
+ key_part = [inner_block_map[0], "1"]
+ else:
+ raise ValueError(f"Invalid middle block id {i}.")
+
+ for key in middle_blocks[i]:
+ new_key = delimiter.join(
+ key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
+ )
+ new_state_dict[new_key] = state_dict.pop(key)
+
+ for i in output_block_ids:
+ block_id = i // (unet_config.layers_per_block + 1)
+ layer_in_block_id = i % (unet_config.layers_per_block + 1)
+
+ for key in output_blocks[i]:
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
+ inner_block_key = inner_block_map[inner_block_id]
+ inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
+ new_key = delimiter.join(
+ key.split(delimiter)[: block_slice_pos - 1]
+ + [str(block_id), inner_block_key, inner_layers_in_block]
+ + key.split(delimiter)[block_slice_pos + 1 :]
+ )
+ new_state_dict[new_key] = state_dict.pop(key)
+
+ if len(state_dict) > 0:
+ raise ValueError("At this point all state dict entries have to be converted.")
+
+ return new_state_dict
+
+
+def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+ unet_state_dict = {}
+ te_state_dict = {}
+ te2_state_dict = {}
+ network_alphas = {}
+
+ # every down weight has a corresponding up weight and potentially an alpha weight
+ lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
+ for key in lora_keys:
+ lora_name = key.split(".")[0]
+ lora_name_up = lora_name + ".lora_up.weight"
+ lora_name_alpha = lora_name + ".alpha"
+
+ if lora_name.startswith("lora_unet_"):
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+
+ if "input.blocks" in diffusers_name:
+ diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
+ else:
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+
+ if "middle.block" in diffusers_name:
+ diffusers_name = diffusers_name.replace("middle.block", "mid_block")
+ else:
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+ if "output.blocks" in diffusers_name:
+ diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
+ else:
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+ diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+ diffusers_name = diffusers_name.replace("proj.out", "proj_out")
+ diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
+
+ # SDXL specificity.
+ if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
+ pattern = r"\.\d+(?=\D*$)"
+ diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
+ if ".in." in diffusers_name:
+ diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
+ if ".out." in diffusers_name:
+ diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
+ if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
+ diffusers_name = diffusers_name.replace("op", "conv")
+ if "skip" in diffusers_name:
+ diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
+
+ # LyCORIS specificity.
+ if "time.emb.proj" in diffusers_name:
+ diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
+ if "conv.shortcut" in diffusers_name:
+ diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
+
+ # General coverage.
+ if "transformer_blocks" in diffusers_name:
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ elif "ff" in diffusers_name:
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ else:
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+ elif lora_name.startswith("lora_te_"):
+ diffusers_name = key.replace("lora_te_", "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+ if "self_attn" in diffusers_name:
+ te_state_dict[diffusers_name] = state_dict.pop(key)
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ elif "mlp" in diffusers_name:
+ # Be aware that this is the new diffusers convention and the rest of the code might
+ # not utilize it yet.
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+ te_state_dict[diffusers_name] = state_dict.pop(key)
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+ # (sayakpaul): Duplicate code. Needs to be cleaned.
+ elif lora_name.startswith("lora_te1_"):
+ diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+ if "self_attn" in diffusers_name:
+ te_state_dict[diffusers_name] = state_dict.pop(key)
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ elif "mlp" in diffusers_name:
+ # Be aware that this is the new diffusers convention and the rest of the code might
+ # not utilize it yet.
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+ te_state_dict[diffusers_name] = state_dict.pop(key)
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+ # (sayakpaul): Duplicate code. Needs to be cleaned.
+ elif lora_name.startswith("lora_te2_"):
+ diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+ if "self_attn" in diffusers_name:
+ te2_state_dict[diffusers_name] = state_dict.pop(key)
+ te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ elif "mlp" in diffusers_name:
+ # Be aware that this is the new diffusers convention and the rest of the code might
+ # not utilize it yet.
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+ te2_state_dict[diffusers_name] = state_dict.pop(key)
+ te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+ # Rename the alphas so that they can be mapped appropriately.
+ if lora_name_alpha in state_dict:
+ alpha = state_dict.pop(lora_name_alpha).item()
+ if lora_name_alpha.startswith("lora_unet_"):
+ prefix = "unet."
+ elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
+ prefix = "text_encoder."
+ else:
+ prefix = "text_encoder_2."
+ new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
+ network_alphas.update({new_name: alpha})
+
+ if len(state_dict) > 0:
+ raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}")
+
+ logger.info("Kohya-style checkpoint detected.")
+ unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
+ te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
+ te2_state_dict = (
+ {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
+ if len(te2_state_dict) > 0
+ else None
+ )
+ if te2_state_dict is not None:
+ te_state_dict.update(te2_state_dict)
+
+ new_state_dict = {**unet_state_dict, **te_state_dict}
+ return new_state_dict, network_alphas
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index de2e2848b848..839045001bb0 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -14,7 +14,12 @@
from typing import TYPE_CHECKING
-from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
+from ..utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ _LazyModule,
+ is_flax_available,
+ is_torch_available,
+)
_import_structure = {}
@@ -23,6 +28,7 @@
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
+ _import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
@@ -38,6 +44,7 @@
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
+ _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
@@ -51,6 +58,7 @@
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
+ from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
@@ -66,6 +74,7 @@
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandi3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
+ from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel
if is_flax_available():
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 0c4c5de6e31a..f02b5e249eee 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -25,6 +25,31 @@
from .normalization import AdaLayerNorm, AdaLayerNormZero
+def _chunked_feed_forward(
+ ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
+):
+ # "feed_forward_chunk_size" can be used to save memory
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+ if lora_scale is None:
+ ff_output = torch.cat(
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+ dim=chunk_dim,
+ )
+ else:
+ # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
+ ff_output = torch.cat(
+ [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+ dim=chunk_dim,
+ )
+
+ return ff_output
+
+
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
r"""
@@ -194,7 +219,12 @@ def __init__(
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ )
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
@@ -208,7 +238,7 @@ def __init__(
self._chunk_size = None
self._chunk_dim = 0
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
@@ -311,18 +341,8 @@ def forward(
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
- raise ValueError(
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
- )
-
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
- ff_output = torch.cat(
- [
- self.ff(hid_slice, scale=lora_scale)
- for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
- ],
- dim=self._chunk_dim,
+ ff_output = _chunked_feed_forward(
+ self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
)
else:
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
@@ -339,6 +359,137 @@ def forward(
return hidden_states
+@maybe_allow_in_graph
+class TemporalBasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block for video like data.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ time_mix_inner_dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.is_res = dim == time_mix_inner_dim
+
+ self.norm_in = nn.LayerNorm(dim)
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ self.norm_in = nn.LayerNorm(dim)
+ self.ff_in = FeedForward(
+ dim,
+ dim_out=time_mix_inner_dim,
+ activation_fn="geglu",
+ )
+
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
+ self.attn1 = Attention(
+ query_dim=time_mix_inner_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ cross_attention_dim=None,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
+ self.attn2 = Attention(
+ query_dim=time_mix_inner_dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = None
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
+ self._chunk_dim = 1
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ num_frames: int,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ batch_frames, seq_length, channels = hidden_states.shape
+ batch_size = batch_frames // num_frames
+
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
+
+ residual = hidden_states
+ hidden_states = self.norm_in(hidden_states)
+
+ if self._chunk_size is not None:
+ hidden_states = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ hidden_states = self.ff_in(hidden_states)
+
+ if self.is_res:
+ hidden_states = hidden_states + residual
+
+ norm_hidden_states = self.norm1(hidden_states)
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+ hidden_states = attn_output + hidden_states
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = self.norm2(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self._chunk_size is not None:
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.is_res:
+ hidden_states = ff_output + hidden_states
+ else:
+ hidden_states = ff_output
+
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
+
+ return hidden_states
+
+
class FeedForward(nn.Module):
r"""
A feed-forward layer.
diff --git a/src/diffusers/models/autoencoder_asym_kl.py b/src/diffusers/models/autoencoder_asym_kl.py
index 818e181fcdf0..678e47234096 100644
--- a/src/diffusers/models/autoencoder_asym_kl.py
+++ b/src/diffusers/models/autoencoder_asym_kl.py
@@ -18,7 +18,7 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils.accelerate_utils import apply_forward_hook
-from .autoencoder_kl import AutoencoderKLOutput
+from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py
index 9003d982b32f..464bff9189dd 100644
--- a/src/diffusers/models/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoder_kl.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import torch
@@ -19,7 +18,6 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
-from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -28,24 +26,11 @@
AttnAddedKVProcessor,
AttnProcessor,
)
+from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
-@dataclass
-class AutoencoderKLOutput(BaseOutput):
- """
- Output of AutoencoderKL encoding method.
-
- Args:
- latent_dist (`DiagonalGaussianDistribution`):
- Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
- `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
- """
-
- latent_dist: "DiagonalGaussianDistribution"
-
-
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
diff --git a/src/diffusers/models/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoder_kl_temporal_decoder.py
new file mode 100644
index 000000000000..176b6e0df924
--- /dev/null
+++ b/src/diffusers/models/autoencoder_kl_temporal_decoder.py
@@ -0,0 +1,402 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalVAEMixin
+from ..utils import is_torch_version
+from ..utils.accelerate_utils import apply_forward_hook
+from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from .modeling_outputs import AutoencoderKLOutput
+from .modeling_utils import ModelMixin
+from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
+from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
+
+
+class TemporalDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 4,
+ out_channels: int = 3,
+ block_out_channels: Tuple[int] = (128, 256, 512, 512),
+ layers_per_block: int = 2,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
+ self.mid_block = MidBlockTemporalDecoder(
+ num_layers=self.layers_per_block,
+ in_channels=block_out_channels[-1],
+ out_channels=block_out_channels[-1],
+ attention_head_dim=block_out_channels[-1],
+ )
+
+ # up
+ self.up_blocks = nn.ModuleList([])
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i in range(len(block_out_channels)):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+ up_block = UpBlockTemporalDecoder(
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ add_upsample=not is_final_block,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6)
+
+ self.conv_act = nn.SiLU()
+ self.conv_out = torch.nn.Conv2d(
+ in_channels=block_out_channels[0],
+ out_channels=out_channels,
+ kernel_size=3,
+ padding=1,
+ )
+
+ conv_out_kernel_size = (3, 1, 1)
+ padding = [int(k // 2) for k in conv_out_kernel_size]
+ self.time_conv_out = torch.nn.Conv3d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=conv_out_kernel_size,
+ padding=padding,
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ image_only_indicator: torch.FloatTensor,
+ num_frames: int = 1,
+ ) -> torch.FloatTensor:
+ r"""The forward method of the `Decoder` class."""
+
+ sample = self.conv_in(sample)
+
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block),
+ sample,
+ image_only_indicator,
+ use_reentrant=False,
+ )
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(up_block),
+ sample,
+ image_only_indicator,
+ use_reentrant=False,
+ )
+ else:
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block),
+ sample,
+ image_only_indicator,
+ )
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(up_block),
+ sample,
+ image_only_indicator,
+ )
+ else:
+ # middle
+ sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = up_block(sample, image_only_indicator=image_only_indicator)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ batch_frames, channels, height, width = sample.shape
+ batch_size = batch_frames // num_frames
+ sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+ sample = self.time_conv_out(sample)
+
+ sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
+
+ return sample
+
+
+class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block.
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ force_upcast (`bool`, *optional*, default to `True`):
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ latent_channels: int = 4,
+ sample_size: int = 32,
+ scaling_factor: float = 0.18215,
+ force_upcast: float = True,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ double_z=True,
+ )
+
+ # pass init params to Decoder
+ self.decoder = TemporalDecoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ )
+
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+
+ sample_size = (
+ self.config.sample_size[0]
+ if isinstance(self.config.sample_size, (list, tuple))
+ else self.config.sample_size
+ )
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor = 0.25
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (Encoder, TemporalDecoder)):
+ module.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.FloatTensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ @apply_forward_hook
+ def decode(
+ self,
+ z: torch.FloatTensor,
+ num_frames: int,
+ return_dict: bool = True,
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+
+ """
+ batch_size = z.shape[0] // num_frames
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
+ decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ num_frames: int = 1,
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+
+ dec = self.decode(z, num_frames=num_frames).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index 220e34593c23..3139bb2a5c6c 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -30,12 +30,7 @@
)
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
-from .unet_2d_blocks import (
- CrossAttnDownBlock2D,
- DownBlock2D,
- UNetMidBlock2DCrossAttn,
- get_down_block,
-)
+from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block
from .unet_2d_condition import UNet2DConditionModel
@@ -191,6 +186,7 @@ def __init__(
"CrossAttnDownBlock2D",
"DownBlock2D",
),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
@@ -409,20 +405,35 @@ def __init__(
controlnet_block = zero_module(controlnet_block)
self.controlnet_mid_block = controlnet_block
- self.mid_block = UNetMidBlock2DCrossAttn(
- transformer_layers_per_block=transformer_layers_per_block[-1],
- in_channels=mid_block_channel,
- temb_channels=time_embed_dim,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_time_scale_shift=resnet_time_scale_shift,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads[-1],
- resnet_groups=norm_num_groups,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- )
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ elif mid_block_type == "UNetMidBlock2D":
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ num_layers=0,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=False,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
@classmethod
def from_unet(
@@ -431,6 +442,7 @@ def from_unet(
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
+ conditioning_channels: int = 3,
):
r"""
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
@@ -477,8 +489,10 @@ def from_unet(
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ mid_block_type=unet.config.mid_block_type,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
)
if load_weights_from_unet:
@@ -797,13 +811,16 @@ def forward(
# 4. mid
if self.mid_block is not None:
- sample = self.mid_block(
- sample,
- emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- )
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
# 5. Control net blocks
diff --git a/src/diffusers/models/modeling_outputs.py b/src/diffusers/models/modeling_outputs.py
new file mode 100644
index 000000000000..8dfee5fec181
--- /dev/null
+++ b/src/diffusers/models/modeling_outputs.py
@@ -0,0 +1,17 @@
+from dataclasses import dataclass
+
+from ..utils import BaseOutput
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`DiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "DiagonalGaussianDistribution" # noqa: F821
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 7a48d343a531..970d2be05b7a 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -165,7 +165,10 @@ def __init__(
self.Conv2d_0 = conv
def forward(
- self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, scale: float = 1.0
+ self,
+ hidden_states: torch.FloatTensor,
+ output_size: Optional[int] = None,
+ scale: float = 1.0,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
@@ -379,7 +382,11 @@ def _upsample_2d(
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
inverse_conv = F.conv_transpose2d(
- hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
+ hidden_states,
+ weight,
+ stride=stride,
+ output_padding=output_padding,
+ padding=0,
)
output = upfirdn2d_native(
@@ -530,7 +537,14 @@ def __init__(self, pad_mode: str = "reflect"):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
- weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+ weight = inputs.new_zeros(
+ [
+ inputs.shape[1],
+ inputs.shape[1],
+ self.kernel.shape[0],
+ self.kernel.shape[1],
+ ]
+ )
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
@@ -553,7 +567,14 @@ def __init__(self, pad_mode: str = "reflect"):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
- weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+ weight = inputs.new_zeros(
+ [
+ inputs.shape[1],
+ inputs.shape[1],
+ self.kernel.shape[0],
+ self.kernel.shape[1],
+ ]
+ )
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
@@ -690,11 +711,19 @@ def __init__(
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
- in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
+ in_channels,
+ conv_2d_out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=conv_shortcut_bias,
)
def forward(
- self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, scale: float = 1.0
+ self,
+ input_tensor: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ scale: float = 1.0,
) -> torch.FloatTensor:
hidden_states = input_tensor
@@ -866,7 +895,10 @@ def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
def upsample_2d(
- hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+ hidden_states: torch.FloatTensor,
+ kernel: Optional[torch.FloatTensor] = None,
+ factor: int = 2,
+ gain: float = 1,
) -> torch.FloatTensor:
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
@@ -910,7 +942,10 @@ def upsample_2d(
def downsample_2d(
- hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+ hidden_states: torch.FloatTensor,
+ kernel: Optional[torch.FloatTensor] = None,
+ factor: int = 2,
+ gain: float = 1,
) -> torch.FloatTensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
@@ -946,13 +981,20 @@ def downsample_2d(
kernel = kernel * gain
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
- hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
+ hidden_states,
+ kernel.to(device=hidden_states.device),
+ down=factor,
+ pad=((pad_value + 1) // 2, pad_value // 2),
)
return output
def upfirdn2d_native(
- tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
+ tensor: torch.Tensor,
+ kernel: torch.Tensor,
+ up: int = 1,
+ down: int = 1,
+ pad: Tuple[int, int] = (0, 0),
) -> torch.Tensor:
up_x = up_y = up
down_x = down_y = down
@@ -1008,7 +1050,13 @@ class TemporalConvLayer(nn.Module):
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
"""
- def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, norm_num_groups: int = 32):
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: Optional[int] = None,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ ):
super().__init__()
out_dim = out_dim or in_dim
self.in_dim = in_dim
@@ -1016,7 +1064,9 @@ def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float =
# conv layers
self.conv1 = nn.Sequential(
- nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
+ nn.GroupNorm(norm_num_groups, in_dim),
+ nn.SiLU(),
+ nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
)
self.conv2 = nn.Sequential(
nn.GroupNorm(norm_num_groups, out_dim),
@@ -1058,3 +1108,261 @@ def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Ten
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
)
return hidden_states
+
+
+class TemporalResnetBlock(nn.Module):
+ r"""
+ A Resnet block.
+
+ Parameters:
+ in_channels (`int`): The number of channels in the input.
+ out_channels (`int`, *optional*, default to be `None`):
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ temb_channels: int = 512,
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+
+ kernel_size = (3, 1, 1)
+ padding = [k // 2 for k in kernel_size]
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
+ self.conv1 = nn.Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding,
+ )
+
+ if temb_channels is not None:
+ self.time_emb_proj = nn.Linear(temb_channels, out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)
+
+ self.dropout = torch.nn.Dropout(0.0)
+ self.conv2 = nn.Conv3d(
+ out_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding,
+ )
+
+ self.nonlinearity = get_activation("silu")
+
+ self.use_in_shortcut = self.in_channels != out_channels
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = nn.Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ if self.time_emb_proj is not None:
+ temb = self.nonlinearity(temb)
+ temb = self.time_emb_proj(temb)[:, :, :, None, None]
+ temb = temb.permute(0, 2, 1, 3, 4)
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = input_tensor + hidden_states
+
+ return output_tensor
+
+
+# VideoResBlock
+class SpatioTemporalResBlock(nn.Module):
+ r"""
+ A SpatioTemporal Resnet block.
+
+ Parameters:
+ in_channels (`int`): The number of channels in the input.
+ out_channels (`int`, *optional*, default to be `None`):
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resenet.
+ temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
+ merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
+ merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
+ The merge strategy to use for the temporal mixing.
+ switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
+ If `True`, switch the spatial and temporal mixing.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ temb_channels: int = 512,
+ eps: float = 1e-6,
+ temporal_eps: Optional[float] = None,
+ merge_factor: float = 0.5,
+ merge_strategy="learned_with_images",
+ switch_spatial_to_temporal_mix: bool = False,
+ ):
+ super().__init__()
+
+ self.spatial_res_block = ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=eps,
+ )
+
+ self.temporal_res_block = TemporalResnetBlock(
+ in_channels=out_channels if out_channels is not None else in_channels,
+ out_channels=out_channels if out_channels is not None else in_channels,
+ temb_channels=temb_channels,
+ eps=temporal_eps if temporal_eps is not None else eps,
+ )
+
+ self.time_mixer = AlphaBlender(
+ alpha=merge_factor,
+ merge_strategy=merge_strategy,
+ switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ):
+ num_frames = image_only_indicator.shape[-1]
+ hidden_states = self.spatial_res_block(hidden_states, temb)
+
+ batch_frames, channels, height, width = hidden_states.shape
+ batch_size = batch_frames // num_frames
+
+ hidden_states_mix = (
+ hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+ )
+ hidden_states = (
+ hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+ )
+
+ if temb is not None:
+ temb = temb.reshape(batch_size, num_frames, -1)
+
+ hidden_states = self.temporal_res_block(hidden_states, temb)
+ hidden_states = self.time_mixer(
+ x_spatial=hidden_states_mix,
+ x_temporal=hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
+ return hidden_states
+
+
+class AlphaBlender(nn.Module):
+ r"""
+ A module to blend spatial and temporal features.
+
+ Parameters:
+ alpha (`float`): The initial value of the blending factor.
+ merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
+ The merge strategy to use for the temporal mixing.
+ switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
+ If `True`, switch the spatial and temporal mixing.
+ """
+
+ strategies = ["learned", "fixed", "learned_with_images"]
+
+ def __init__(
+ self,
+ alpha: float,
+ merge_strategy: str = "learned_with_images",
+ switch_spatial_to_temporal_mix: bool = False,
+ ):
+ super().__init__()
+ self.merge_strategy = merge_strategy
+ self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix # For TemporalVAE
+
+ if merge_strategy not in self.strategies:
+ raise ValueError(f"merge_strategy needs to be in {self.strategies}")
+
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
+ self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
+ else:
+ raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
+
+ def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
+ if self.merge_strategy == "fixed":
+ alpha = self.mix_factor
+
+ elif self.merge_strategy == "learned":
+ alpha = torch.sigmoid(self.mix_factor)
+
+ elif self.merge_strategy == "learned_with_images":
+ if image_only_indicator is None:
+ raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")
+
+ alpha = torch.where(
+ image_only_indicator.bool(),
+ torch.ones(1, 1, device=image_only_indicator.device),
+ torch.sigmoid(self.mix_factor)[..., None],
+ )
+
+ # (batch, channel, frames, height, width)
+ if ndims == 5:
+ alpha = alpha[:, None, :, None, None]
+ # (batch*frames, height*width, channels)
+ elif ndims == 3:
+ alpha = alpha.reshape(-1)[:, None, None]
+ else:
+ raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
+
+ else:
+ raise NotImplementedError
+
+ return alpha
+
+ def forward(
+ self,
+ x_spatial: torch.Tensor,
+ x_temporal: torch.Tensor,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
+ alpha = alpha.to(x_spatial.dtype)
+
+ if self.switch_spatial_to_temporal_mix:
+ alpha = 1.0 - alpha
+
+ x = alpha * x_spatial + (1.0 - alpha) * x_temporal
+ return x
diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py
index 2e053d70eaa7..26e899a9b908 100644
--- a/src/diffusers/models/transformer_temporal.py
+++ b/src/diffusers/models/transformer_temporal.py
@@ -19,8 +19,10 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
-from .attention import BasicTransformerBlock
+from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
+from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
+from .resnet import AlphaBlender
@dataclass
@@ -195,3 +197,183 @@ def forward(
return (output,)
return TransformerTemporalModelOutput(sample=output)
+
+
+class TransformerSpatioTemporalModel(nn.Module):
+ """
+ A Transformer model for video-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ out_channels (`int`, *optional*):
+ The number of channels in the output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ """
+
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: int = 320,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ cross_attention_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+
+ inner_dim = num_attention_heads * attention_head_dim
+ self.inner_dim = inner_dim
+
+ # 2. Define input layers
+ self.in_channels = in_channels
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ time_mix_inner_dim = inner_dim
+ self.temporal_transformer_blocks = nn.ModuleList(
+ [
+ TemporalBasicTransformerBlock(
+ inner_dim,
+ time_mix_inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ time_embed_dim = in_channels * 4
+ self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
+ self.time_proj = Timesteps(in_channels, True, 0)
+ self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+ # TODO: should use out_channels for continuous projections
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ):
+ """
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input hidden_states.
+ num_frames (`int`):
+ The number of frames to be processed per batch. This is used to reshape the hidden states.
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
+ A tensor indicating whether the input contains only images. 1 indicates that the input contains only
+ images, 0 indicates that the input contains video frames.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
+ returned, otherwise a `tuple` where the first element is the sample tensor.
+ """
+ # 1. Input
+ batch_frames, _, height, width = hidden_states.shape
+ num_frames = image_only_indicator.shape[-1]
+ batch_size = batch_frames // num_frames
+
+ time_context = encoder_hidden_states
+ time_context_first_timestep = time_context[None, :].reshape(
+ batch_size, num_frames, -1, time_context.shape[-1]
+ )[:, 0]
+ time_context = time_context_first_timestep[None, :].broadcast_to(
+ height * width, batch_size, 1, time_context.shape[-1]
+ )
+ time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
+
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
+ num_frames_emb = num_frames_emb.repeat(batch_size, 1)
+ num_frames_emb = num_frames_emb.reshape(-1)
+ t_emb = self.time_proj(num_frames_emb)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+
+ emb = self.time_pos_embed(t_emb)
+ emb = emb[:, None, :]
+
+ # 2. Blocks
+ for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ block,
+ hidden_states,
+ None,
+ encoder_hidden_states,
+ None,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ hidden_states_mix = hidden_states
+ hidden_states_mix = hidden_states_mix + emb
+
+ hidden_states_mix = temporal_block(
+ hidden_states_mix,
+ num_frames=num_frames,
+ encoder_hidden_states=time_context,
+ )
+ hidden_states = self.time_mixer(
+ x_spatial=hidden_states,
+ x_temporal=hidden_states_mix,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 3. Output
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+ output = hidden_states + residual
+
+ if not return_dict:
+ return (output,)
+
+ return TransformerTemporalModelOutput(sample=output)
diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py
index 767ab846d5dc..e9c505c347b0 100644
--- a/src/diffusers/models/unet_3d_blocks.py
+++ b/src/diffusers/models/unet_3d_blocks.py
@@ -19,10 +19,20 @@
from ..utils import is_torch_version
from ..utils.torch_utils import apply_freeu
+from .attention import Attention
from .dual_transformer_2d import DualTransformer2DModel
-from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
+from .resnet import (
+ Downsample2D,
+ ResnetBlock2D,
+ SpatioTemporalResBlock,
+ TemporalConvLayer,
+ Upsample2D,
+)
from .transformer_2d import Transformer2DModel
-from .transformer_temporal import TransformerTemporalModel
+from .transformer_temporal import (
+ TransformerSpatioTemporalModel,
+ TransformerTemporalModel,
+)
def get_down_block(
@@ -45,7 +55,15 @@ def get_down_block(
resnet_time_scale_shift: str = "default",
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
-) -> Union["DownBlock3D", "CrossAttnDownBlock3D", "DownBlockMotion", "CrossAttnDownBlockMotion"]:
+ transformer_layers_per_block: int = 1,
+) -> Union[
+ "DownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlockMotion",
+ "CrossAttnDownBlockMotion",
+ "DownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+]:
if down_block_type == "DownBlock3D":
return DownBlock3D(
num_layers=num_layers,
@@ -118,6 +136,29 @@ def get_down_block(
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
)
+ elif down_block_type == "DownBlockSpatioTemporal":
+ # added for SDV
+ return DownBlockSpatioTemporal(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ )
+ elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
+ # added for SDV
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
+ return CrossAttnDownBlockSpatioTemporal(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ add_downsample=add_downsample,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ )
raise ValueError(f"{down_block_type} does not exist.")
@@ -144,7 +185,16 @@ def get_up_block(
temporal_num_attention_heads: int = 8,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
-) -> Union["UpBlock3D", "CrossAttnUpBlock3D", "UpBlockMotion", "CrossAttnUpBlockMotion"]:
+ transformer_layers_per_block: int = 1,
+ dropout: float = 0.0,
+) -> Union[
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "UpBlockMotion",
+ "CrossAttnUpBlockMotion",
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+]:
if up_block_type == "UpBlock3D":
return UpBlock3D(
num_layers=num_layers,
@@ -221,6 +271,34 @@ def get_up_block(
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
)
+ elif up_block_type == "UpBlockSpatioTemporal":
+ # added for SDV
+ return UpBlockSpatioTemporal(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ add_upsample=add_upsample,
+ )
+ elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
+ # added for SDV
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
+ return CrossAttnUpBlockSpatioTemporal(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ add_upsample=add_upsample,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ resolution_idx=resolution_idx,
+ )
+
raise ValueError(f"{up_block_type} does not exist.")
@@ -347,7 +425,10 @@ def forward(
return_dict=False,
)[0]
hidden_states = temp_attn(
- hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+ hidden_states,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
@@ -443,7 +524,11 @@ def __init__(
self.downsamplers = nn.ModuleList(
[
Downsample2D(
- out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
)
]
)
@@ -476,7 +561,10 @@ def forward(
return_dict=False,
)[0]
hidden_states = temp_attn(
- hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+ hidden_states,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
)[0]
output_states += (hidden_states,)
@@ -543,7 +631,11 @@ def __init__(
self.downsamplers = nn.ModuleList(
[
Downsample2D(
- out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
)
]
)
@@ -553,7 +645,10 @@ def __init__(
self.gradient_checkpointing = False
def forward(
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, num_frames: int = 1
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ num_frames: int = 1,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
@@ -716,7 +811,10 @@ def forward(
return_dict=False,
)[0]
hidden_states = temp_attn(
- hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+ hidden_states,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
)[0]
if self.upsamplers is not None:
@@ -890,7 +988,11 @@ def __init__(
self.downsamplers = nn.ModuleList(
[
Downsample2D(
- out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
)
]
)
@@ -920,14 +1022,20 @@ def custom_forward(*inputs):
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, scale
)
hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, num_frames
+ create_custom_forward(motion_module),
+ hidden_states.requires_grad_(),
+ temb,
+ num_frames,
)
else:
@@ -1047,7 +1155,11 @@ def __init__(
self.downsamplers = nn.ModuleList(
[
Downsample2D(
- out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
)
]
)
@@ -1442,7 +1554,10 @@ def custom_forward(*inputs):
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
@@ -1636,3 +1751,645 @@ def custom_forward(*inputs):
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
+
+
+class MidBlockTemporalDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ attention_head_dim: int = 512,
+ num_layers: int = 1,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+
+ resnets = []
+ attentions = []
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=1e-6,
+ temporal_eps=1e-5,
+ merge_factor=0.0,
+ merge_strategy="learned",
+ switch_spatial_to_temporal_mix=True,
+ )
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=in_channels,
+ heads=in_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ eps=1e-6,
+ upcast_attention=upcast_attention,
+ norm_num_groups=32,
+ bias=True,
+ residual_connection=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ image_only_indicator: torch.FloatTensor,
+ ):
+ hidden_states = self.resnets[0](
+ hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ for resnet, attn in zip(self.resnets[1:], self.attentions):
+ hidden_states = attn(hidden_states)
+ hidden_states = resnet(
+ hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ return hidden_states
+
+
+class UpBlockTemporalDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=1e-6,
+ temporal_eps=1e-5,
+ merge_factor=0.0,
+ merge_strategy="learned",
+ switch_spatial_to_temporal_mix=True,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ image_only_indicator: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(
+ hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UNetMidBlockSpatioTemporal(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ # support for variable transformer layers per block
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ # there is always at least one resnet
+ resnets = [
+ SpatioTemporalResBlock(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=1e-5,
+ )
+ ]
+ attentions = []
+
+ for i in range(num_layers):
+ attentions.append(
+ TransformerSpatioTemporalModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ )
+ )
+
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=1e-5,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ hidden_states = self.resnets[0](
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if self.training and self.gradient_checkpointing: # TODO
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+ hidden_states = resnet(
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ return hidden_states
+
+
+class DownBlockSpatioTemporal(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ num_layers: int = 1,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=1e-5,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ name="op",
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ )
+ else:
+ hidden_states = resnet(
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlockSpatioTemporal(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=1e-6,
+ )
+ )
+ attentions.append(
+ TransformerSpatioTemporalModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=1,
+ name="op",
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+
+ blocks = list(zip(self.resnets, self.attentions))
+ for resnet, attn in blocks:
+ if self.training and self.gradient_checkpointing: # TODO
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ **ckpt_kwargs,
+ )
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class UpBlockSpatioTemporal(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ )
+ else:
+ hidden_states = resnet(
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CrossAttnUpBlockSpatioTemporal(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ resnet_eps: float = 1e-6,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ )
+ )
+ attentions.append(
+ TransformerSpatioTemporalModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing: # TODO
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
diff --git a/src/diffusers/models/unet_spatio_temporal_condition.py b/src/diffusers/models/unet_spatio_temporal_condition.py
new file mode 100644
index 000000000000..8d0d3e61d879
--- /dev/null
+++ b/src/diffusers/models/unet_spatio_temporal_condition.py
@@ -0,0 +1,489 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import UNet2DConditionLoadersMixin
+from ..utils import BaseOutput, logging
+from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from .embeddings import TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNetSpatioTemporalConditionOutput(BaseOutput):
+ """
+ The output of [`UNetSpatioTemporalConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ addition_time_embed_dim: (`int`, defaults to 256):
+ Dimension to to encode the additional time ids.
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
+ The dimension of the projection of encoded `added_time_ids`.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
+ The number of attention heads.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 8,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ ),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ addition_time_embed_dim: int = 256,
+ projection_class_embeddings_input_dim: int = 768,
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
+ num_frames: int = 25,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ padding=1,
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-5,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlockSpatioTemporal(
+ block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=1e-5,
+ resolution_idx=i,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
+ self.conv_act = nn.SiLU()
+
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0],
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str,
+ module: torch.nn.Module,
+ processors: Dict[str, AttentionProcessor],
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ added_time_ids: torch.Tensor,
+ return_dict: bool = True,
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
+ r"""
+ The [`UNetSpatioTemporalConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
+ added_time_ids: (`torch.FloatTensor`):
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
+ embeddings and added to the time embeddings.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
+ tuple.
+ Returns:
+ [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ # 7. Reshape back to original shape
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
+
+ if not return_dict:
+ return (sample,)
+
+ return UNetSpatioTemporalConditionOutput(sample=sample)
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index 0f849a66eaea..0049456e2187 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -22,7 +22,12 @@
from ..utils.torch_utils import randn_tensor
from .activations import get_activation
from .attention_processor import SpatialNorm
-from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block
+from .unet_2d_blocks import (
+ AutoencoderTinyBlock,
+ UNetMidBlock2D,
+ get_down_block,
+ get_up_block,
+)
@dataclass
@@ -274,7 +279,9 @@ def __init__(
self.gradient_checkpointing = False
def forward(
- self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None
+ self,
+ sample: torch.FloatTensor,
+ latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
@@ -292,14 +299,20 @@ def custom_forward(*inputs):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
+ create_custom_forward(self.mid_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
+ create_custom_forward(up_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
)
else:
# middle
@@ -540,7 +553,10 @@ def custom_forward(*inputs):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
+ create_custom_forward(self.mid_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
)
sample = sample.to(upscale_dtype)
@@ -548,7 +564,10 @@ def custom_forward(*inputs):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False
+ create_custom_forward(self.condition_encoder),
+ masked_image,
+ mask,
+ use_reentrant=False,
)
# up
@@ -558,7 +577,10 @@ def custom_forward(*inputs):
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
+ create_custom_forward(up_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
@@ -573,7 +595,9 @@ def custom_forward(*inputs):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.condition_encoder), masked_image, mask
+ create_custom_forward(self.condition_encoder),
+ masked_image,
+ mask,
)
# up
@@ -754,7 +778,10 @@ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
- self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
+ self.mean.shape,
+ generator=generator,
+ device=self.parameters.device,
+ dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
@@ -764,7 +791,10 @@ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
return torch.Tensor([0.0])
else:
if other is None:
- return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
@@ -779,7 +809,10 @@ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
- return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
def mode(self) -> torch.Tensor:
return self.mean
@@ -820,7 +853,16 @@ def __init__(
if i == 0:
layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
else:
- layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))
+ layers.append(
+ nn.Conv2d(
+ num_channels,
+ num_channels,
+ kernel_size=3,
+ padding=1,
+ stride=2,
+ bias=False,
+ )
+ )
for _ in range(num_block):
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
@@ -899,7 +941,15 @@ def __init__(
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
conv_out_channel = num_channels if not is_final_block else out_channels
- layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
+ layers.append(
+ nn.Conv2d(
+ num_channels,
+ conv_out_channel,
+ kernel_size=3,
+ padding=1,
+ bias=is_final_block,
+ )
+ )
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 78c1b7c6285d..5bb6a301ca4a 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -17,7 +17,12 @@
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
-_import_structure = {"stable_diffusion": [], "stable_diffusion_xl": [], "latent_diffusion": [], "controlnet": []}
+_import_structure = {
+ "controlnet": [],
+ "latent_diffusion": [],
+ "stable_diffusion": [],
+ "stable_diffusion_xl": [],
+}
try:
if not is_torch_available():
@@ -39,7 +44,11 @@
_import_structure["dit"] = ["DiTPipeline"]
_import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
_import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
- _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"]
+ _import_structure["pipeline_utils"] = [
+ "AudioPipelineOutput",
+ "DiffusionPipeline",
+ "ImagePipelineOutput",
+ ]
_import_structure["pndm"] = ["PNDMPipeline"]
_import_structure["repaint"] = ["RePaintPipeline"]
_import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
@@ -61,7 +70,10 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
- _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
+ _import_structure["alt_diffusion"] = [
+ "AltDiffusionImg2ImgPipeline",
+ "AltDiffusionPipeline",
+ ]
_import_structure["animatediff"] = ["AnimateDiffPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -110,7 +122,10 @@
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
]
- _import_structure["kandinsky3"] = ["Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline"]
+ _import_structure["kandinsky3"] = [
+ "Kandinsky3Img2ImgPipeline",
+ "Kandinsky3Pipeline",
+ ]
_import_structure["latent_consistency_models"] = [
"LatentConsistencyModelImg2ImgPipeline",
"LatentConsistencyModelPipeline",
@@ -150,6 +165,7 @@
]
)
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
+ _import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"]
_import_structure["stable_diffusion_xl"].extend(
[
"StableDiffusionXLImg2ImgPipeline",
@@ -158,10 +174,14 @@
"StableDiffusionXLPipeline",
]
)
- _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
+ _import_structure["t2i_adapter"] = [
+ "StableDiffusionAdapterPipeline",
+ "StableDiffusionXLAdapterPipeline",
+ ]
_import_structure["text_to_video_synthesis"] = [
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
+ "TextToVideoZeroSDXLPipeline",
"VideoToVideoSDPipeline",
]
_import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
@@ -215,7 +235,9 @@
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
- from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
+ from ..utils import (
+ dummy_torch_and_transformers_and_k_diffusion_objects,
+ )
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
@@ -258,7 +280,10 @@
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
- _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]
+ _import_structure["spectrogram_diffusion"] = [
+ "MidiProcessor",
+ "SpectrogramDiffusionPipeline",
+ ]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -268,7 +293,11 @@
from ..utils.dummy_pt_objects import * # noqa F403
else:
- from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
+ from .auto_pipeline import (
+ AutoPipelineForImage2Image,
+ AutoPipelineForInpainting,
+ AutoPipelineForText2Image,
+ )
from .consistency_models import ConsistencyModelPipeline
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
@@ -276,7 +305,11 @@
from .dit import DiTPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .latent_diffusion_uncond import LDMPipeline
- from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
+ from .pipeline_utils import (
+ AudioPipelineOutput,
+ DiffusionPipeline,
+ ImagePipelineOutput,
+ )
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
@@ -299,7 +332,11 @@
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .animatediff import AnimateDiffPipeline
from .audioldm import AudioLDMPipeline
- from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
+ from .audioldm2 import (
+ AudioLDM2Pipeline,
+ AudioLDM2ProjectionModel,
+ AudioLDM2UNet2DConditionModel,
+ )
from .blip_diffusion import BlipDiffusionPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
@@ -343,7 +380,10 @@
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
)
- from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
+ from .latent_consistency_models import (
+ LatentConsistencyModelImg2ImgPipeline,
+ LatentConsistencyModelPipeline,
+ )
from .latent_diffusion import LDMTextToImagePipeline
from .musicldm import MusicLDMPipeline
from .paint_by_example import PaintByExamplePipeline
@@ -382,10 +422,15 @@
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)
- from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
+ from .stable_video_diffusion import StableVideoDiffusionPipeline
+ from .t2i_adapter import (
+ StableDiffusionAdapterPipeline,
+ StableDiffusionXLAdapterPipeline,
+ )
from .text_to_video_synthesis import (
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
+ TextToVideoZeroSDXLPipeline,
VideoToVideoSDPipeline,
)
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
@@ -471,7 +516,10 @@
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else:
- from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
+ from .spectrogram_diffusion import (
+ MidiProcessor,
+ SpectrogramDiffusionPipeline,
+ )
else:
import sys
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 7bbc4889e7ac..72c2250dd5ac 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -21,10 +21,10 @@
import PIL.Image
import torch
import torch.nn.functional as F
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -241,7 +241,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
class StableDiffusionControlNetInpaintPipeline(
- DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
):
r"""
Pipeline for image inpainting using Stable Diffusion with ControlNet guidance.
@@ -251,6 +251,7 @@ class StableDiffusionControlNetInpaintPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
@@ -288,7 +289,7 @@ class StableDiffusionControlNetInpaintPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
- _optional_components = ["safety_checker", "feature_extractor"]
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -302,6 +303,7 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -334,6 +336,7 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -593,6 +596,20 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+ return image_embeds, uncond_image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -1053,6 +1070,7 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1131,6 +1149,7 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -1264,6 +1283,11 @@ def __call__(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ if ip_adapter_image is not None:
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
control_image = self.prepare_control_image(
@@ -1299,7 +1323,7 @@ def __call__(
else:
assert False
- # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
+ # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
@@ -1360,7 +1384,10 @@ def __call__(
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 7.1 Create tensor stating which controlnets to keep
+ # 7.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+ # 7.2 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
@@ -1423,6 +1450,7 @@ def __call__(
cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 695d961a5d6f..b84344fab85e 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -1688,7 +1688,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
if module_candidate is None or not isinstance(module_candidate, str):
continue
- candidate_file = os.path.join(component, module_candidate + ".py")
+ # We compute candidate file path on the Hub. Do not use `os.path.join`.
+ candidate_file = f"{component}/{module_candidate}.py"
if candidate_file in filenames:
custom_components[component] = module_candidate
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index fcdca9c9f08b..5706298a281a 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -55,7 +55,9 @@
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
- from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
+ from ...utils.dummy_torch_and_transformers_objects import (
+ StableDiffusionImageVariationPipeline,
+ )
_dummy_objects.update({"StableDiffusionImageVariationPipeline": StableDiffusionImageVariationPipeline})
else:
@@ -90,7 +92,9 @@
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
- from ...utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
+ from ...utils import (
+ dummy_torch_and_transformers_and_k_diffusion_objects,
+ )
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
@@ -137,18 +141,32 @@
StableDiffusionPipelineOutput,
StableDiffusionSafetyChecker,
)
- from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
+ from .pipeline_stable_diffusion_attend_and_excite import (
+ StableDiffusionAttendAndExcitePipeline,
+ )
from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
- from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline
+ from .pipeline_stable_diffusion_gligen_text_image import (
+ StableDiffusionGLIGENTextImagePipeline,
+ )
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
- from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
- from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline
- from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline
+ from .pipeline_stable_diffusion_inpaint_legacy import (
+ StableDiffusionInpaintPipelineLegacy,
+ )
+ from .pipeline_stable_diffusion_instruct_pix2pix import (
+ StableDiffusionInstructPix2PixPipeline,
+ )
+ from .pipeline_stable_diffusion_latent_upscale import (
+ StableDiffusionLatentUpscalePipeline,
+ )
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
- from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline
+ from .pipeline_stable_diffusion_model_editing import (
+ StableDiffusionModelEditingPipeline,
+ )
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
- from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
+ from .pipeline_stable_diffusion_paradigms import (
+ StableDiffusionParadigmsPipeline,
+ )
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .pipeline_stable_unclip import StableUnCLIPPipeline
@@ -160,9 +178,13 @@
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
- from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
+ from ...utils.dummy_torch_and_transformers_objects import (
+ StableDiffusionImageVariationPipeline,
+ )
else:
- from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
+ from .pipeline_stable_diffusion_image_variation import (
+ StableDiffusionImageVariationPipeline,
+ )
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")):
@@ -174,9 +196,13 @@
StableDiffusionPix2PixZeroPipeline,
)
else:
- from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline
+ from .pipeline_stable_diffusion_depth2img import (
+ StableDiffusionDepth2ImgPipeline,
+ )
from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
- from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline
+ from .pipeline_stable_diffusion_pix2pix_zero import (
+ StableDiffusionPix2PixZeroPipeline,
+ )
try:
if not (
@@ -189,7 +215,9 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
else:
- from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
+ from .pipeline_stable_diffusion_k_diffusion import (
+ StableDiffusionKDiffusionPipeline,
+ )
try:
if not (is_transformers_available() and is_onnx_available()):
@@ -197,11 +225,22 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_onnx_objects import *
else:
- from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
- from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
- from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline
- from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy
- from .pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
+ from .pipeline_onnx_stable_diffusion import (
+ OnnxStableDiffusionPipeline,
+ StableDiffusionOnnxPipeline,
+ )
+ from .pipeline_onnx_stable_diffusion_img2img import (
+ OnnxStableDiffusionImg2ImgPipeline,
+ )
+ from .pipeline_onnx_stable_diffusion_inpaint import (
+ OnnxStableDiffusionInpaintPipeline,
+ )
+ from .pipeline_onnx_stable_diffusion_inpaint_legacy import (
+ OnnxStableDiffusionInpaintPipelineLegacy,
+ )
+ from .pipeline_onnx_stable_diffusion_upscale import (
+ OnnxStableDiffusionUpscalePipeline,
+ )
try:
if not (is_transformers_available() and is_flax_available()):
@@ -210,8 +249,12 @@
from ...utils.dummy_flax_objects import *
else:
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
- from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
- from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
+ from .pipeline_flax_stable_diffusion_img2img import (
+ FlaxStableDiffusionImg2ImgPipeline,
+ )
+ from .pipeline_flax_stable_diffusion_inpaint import (
+ FlaxStableDiffusionInpaintPipeline,
+ )
from .pipeline_output import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
diff --git a/src/diffusers/pipelines/stable_video_diffusion/__init__.py b/src/diffusers/pipelines/stable_video_diffusion/__init__.py
new file mode 100644
index 000000000000..3bd4dc78966e
--- /dev/null
+++ b/src/diffusers/pipelines/stable_video_diffusion/__init__.py
@@ -0,0 +1,58 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ BaseOutput,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure.update(
+ {
+ "pipeline_stable_video_diffusion": [
+ "StableVideoDiffusionPipeline",
+ "StableVideoDiffusionPipelineOutput",
+ ],
+ }
+ )
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_stable_video_diffusion import (
+ StableVideoDiffusionPipeline,
+ StableVideoDiffusionPipelineOutput,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
new file mode 100644
index 000000000000..a82f5379e71a
--- /dev/null
+++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
@@ -0,0 +1,649 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
+from ...schedulers import EulerDiscreteScheduler
+from ...utils import BaseOutput, logging
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def tensor2vid(video: torch.Tensor, processor, output_type="np"):
+ # Based on:
+ # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
+
+ batch_size, channels, num_frames, height, width = video.shape
+ outputs = []
+ for batch_idx in range(batch_size):
+ batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+ batch_output = processor.postprocess(batch_vid, output_type)
+
+ outputs.append(batch_output)
+
+ return outputs
+
+
+@dataclass
+class StableVideoDiffusionPipelineOutput(BaseOutput):
+ r"""
+ Output class for zero-shot text-to-video pipeline.
+
+ Args:
+ frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+ num_channels)`.
+ """
+
+ frames: Union[List[PIL.Image.Image], np.ndarray]
+
+
+class StableVideoDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline to generate video from an input image using Stable Video Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
+ Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
+ unet ([`UNetSpatioTemporalConditionModel`]):
+ A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
+ scheduler ([`EulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images.
+ """
+
+ model_cpu_offload_seq = "image_encoder->unet->vae"
+ _callback_tensor_inputs = ["latents"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKLTemporalDecoder,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNetSpatioTemporalConditionModel,
+ scheduler: EulerDiscreteScheduler,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ image_encoder=image_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.image_processor.pil_to_numpy(image)
+ image = self.image_processor.numpy_to_pt(image)
+
+ # We normalize the image before resizing to match with the original implementation.
+ # Then we unnormalize it after resizing.
+ image = image * 2.0 - 1.0
+ image = _resize_with_antialiasing(image, (224, 224))
+ image = (image + 1.0) / 2.0
+
+ # Normalize the image with for CLIP input
+ image = self.feature_extractor(
+ images=image,
+ do_normalize=True,
+ do_center_crop=False,
+ do_resize=False,
+ do_rescale=False,
+ return_tensors="pt",
+ ).pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings = self.image_encoder(image).image_embeds
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
+
+ return image_embeddings
+
+ def _encode_vae_image(
+ self,
+ image: torch.Tensor,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ ):
+ image = image.to(device=device)
+ image_latents = self.vae.encode(image).latent_dist.mode()
+
+ if do_classifier_free_guidance:
+ negative_image_latents = torch.zeros_like(image_latents)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_latents = torch.cat([negative_image_latents, image_latents])
+
+ # duplicate image_latents for each generation per prompt, using mps friendly method
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
+
+ return image_latents
+
+ def _get_add_time_ids(
+ self,
+ fps,
+ motion_bucket_id,
+ noise_aug_strength,
+ dtype,
+ batch_size,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ ):
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
+
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
+
+ if do_classifier_free_guidance:
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
+
+ return add_time_ids
+
+ def decode_latents(self, latents, num_frames, decode_chunk_size=14):
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
+ latents = latents.flatten(0, 1)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys())
+
+ # decode decode_chunk_size frames at a time to avoid OOM
+ frames = []
+ for i in range(0, latents.shape[0], decode_chunk_size):
+ num_frames_in = latents[i : i + decode_chunk_size].shape[0]
+ decode_kwargs = {}
+ if accepts_num_frames:
+ # we only pass num_frames_in if it's expected
+ decode_kwargs["num_frames"] = num_frames_in
+
+ frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
+ frames.append(frame)
+ frames = torch.cat(frames, dim=0)
+
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ frames = frames.float()
+ return frames
+
+ def check_inputs(self, image, height, width):
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_frames,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size,
+ num_frames,
+ num_channels_latents // 2,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
+ height: int = 576,
+ width: int = 1024,
+ num_frames: Optional[int] = None,
+ num_inference_steps: int = 25,
+ min_guidance_scale: float = 1.0,
+ max_guidance_scale: float = 3.0,
+ fps: int = 7,
+ motion_bucket_id: int = 127,
+ noise_aug_strength: int = 0.02,
+ decode_chunk_size: Optional[int] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ return_dict: bool = True,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_frames (`int`, *optional*):
+ The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
+ num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
+ fps (`int`, *optional*, defaults to 7):
+ Frames per second. The rate at which the generated images shall be exported to a video after generation.
+ Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
+ motion_bucket_id (`int`, *optional*, defaults to 127):
+ The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
+ noise_aug_strength (`int`, *optional*, defaults to 0.02):
+ The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
+ decode_chunk_size (`int`, *optional*):
+ The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
+ between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
+ for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
+
+ Examples:
+
+ ```py
+ from diffusers import StableVideoDiffusionPipeline
+ from diffusers.utils import load_image, export_to_video
+
+ pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
+ pipe.to("cuda")
+
+ image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
+ image = image.resize((1024, 576))
+
+ frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
+ export_to_video(frames, "generated.mp4", fps=7)
+ ```
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width)
+
+ # 2. Define call parameters
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(image, list):
+ batch_size = len(image)
+ else:
+ batch_size = image.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = max_guidance_scale > 1.0
+
+ # 3. Encode input image
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
+
+ # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
+ # is why it is reduced here.
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
+ fps = fps - 1
+
+ # 4. Encode input image using VAE
+ image = self.image_processor.preprocess(image, height=height, width=width)
+ noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
+ image = image + noise_aug_strength * noise
+
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float32)
+
+ image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
+ image_latents = image_latents.to(image_embeddings.dtype)
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ # Repeat the image latents for each frame so we can concatenate them with the noise
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+
+ # 5. Get Added Time IDs
+ added_time_ids = self._get_add_time_ids(
+ fps,
+ motion_bucket_id,
+ noise_aug_strength,
+ image_embeddings.dtype,
+ batch_size,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ )
+ added_time_ids = added_time_ids.to(device)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_frames,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare guidance scale
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
+ guidance_scale = guidance_scale.to(device, latents.dtype)
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
+
+ self._guidance_scale = guidance_scale
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # Concatenate image_latents over channels dimention
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=image_embeddings,
+ added_time_ids=added_time_ids,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
+ else:
+ frames = latents
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return frames
+
+ return StableVideoDiffusionPipelineOutput(frames=frames)
+
+
+# resizing utils
+# TODO: clean up later
+def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
+ h, w = input.shape[-2:]
+ factors = (h / size[0], w / size[1])
+
+ # First, we have to determine sigma
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
+ sigmas = (
+ max((factors[0] - 1.0) / 2.0, 0.001),
+ max((factors[1] - 1.0) / 2.0, 0.001),
+ )
+
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
+
+ # Make sure it is odd
+ if (ks[0] % 2) == 0:
+ ks = ks[0] + 1, ks[1]
+
+ if (ks[1] % 2) == 0:
+ ks = ks[0], ks[1] + 1
+
+ input = _gaussian_blur2d(input, ks, sigmas)
+
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
+ return output
+
+
+def _compute_padding(kernel_size):
+ """Compute padding tuple."""
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+ if len(kernel_size) < 2:
+ raise AssertionError(kernel_size)
+ computed = [k - 1 for k in kernel_size]
+
+ # for even kernels we need to do asymmetric padding :(
+ out_padding = 2 * len(kernel_size) * [0]
+
+ for i in range(len(kernel_size)):
+ computed_tmp = computed[-(i + 1)]
+
+ pad_front = computed_tmp // 2
+ pad_rear = computed_tmp - pad_front
+
+ out_padding[2 * i + 0] = pad_front
+ out_padding[2 * i + 1] = pad_rear
+
+ return out_padding
+
+
+def _filter2d(input, kernel):
+ # prepare kernel
+ b, c, h, w = input.shape
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
+
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+ height, width = tmp_kernel.shape[-2:]
+
+ padding_shape: list[int] = _compute_padding([height, width])
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
+
+ # kernel and input tensor reshape to align element-wise or batch-wise params
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+ # convolve the tensor with the kernel.
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+ out = output.view(b, c, h, w)
+ return out
+
+
+def _gaussian(window_size: int, sigma):
+ if isinstance(sigma, float):
+ sigma = torch.tensor([[sigma]])
+
+ batch_size = sigma.shape[0]
+
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+ if window_size % 2 == 0:
+ x = x + 0.5
+
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+ return gauss / gauss.sum(-1, keepdim=True)
+
+
+def _gaussian_blur2d(input, kernel_size, sigma):
+ if isinstance(sigma, tuple):
+ sigma = torch.tensor([sigma], dtype=input.dtype)
+ else:
+ sigma = sigma.to(dtype=input.dtype)
+
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
+ bs = sigma.shape[0]
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
+ out_x = _filter2d(input, kernel_x[..., None, :])
+ out = _filter2d(out_x, kernel_y[..., None])
+
+ return out
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py
index 9304d5c7d818..8d8fdb92769b 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py
@@ -25,6 +25,7 @@
_import_structure["pipeline_text_to_video_synth"] = ["TextToVideoSDPipeline"]
_import_structure["pipeline_text_to_video_synth_img2img"] = ["VideoToVideoSDPipeline"]
_import_structure["pipeline_text_to_video_zero"] = ["TextToVideoZeroPipeline"]
+ _import_structure["pipeline_text_to_video_zero_sdxl"] = ["TextToVideoZeroSDXLPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -38,6 +39,7 @@
from .pipeline_text_to_video_synth import TextToVideoSDPipeline
from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline
from .pipeline_text_to_video_zero import TextToVideoZeroPipeline
+ from .pipeline_text_to_video_zero_sdxl import TextToVideoZeroSDXLPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
index 9751abec2c98..0f9ffbebdcf6 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -13,6 +13,7 @@
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import BaseOutput
+from diffusers.utils.torch_utils import randn_tensor
def rearrange_0(tensor, f):
@@ -135,7 +136,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
# Cross Frame Attention
if not is_cross_attention:
- video_length = key.size()[0] // self.batch_size
+ video_length = max(1, key.size()[0] // self.batch_size)
first_frame_index = [0] * video_length
# rearrange keys to have batch and frames in the 1st and 2nd dims respectively
@@ -339,7 +340,7 @@ def forward_loop(self, x_t0, t0, t1, generator):
x_t1:
Forward process applied to x_t0 from time t0 to t1.
"""
- eps = torch.randn(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)
+ eps = randn_tensor(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)
alpha_vec = torch.prod(self.scheduler.alphas[t0:t1])
x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps
return x_t1
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
new file mode 100644
index 000000000000..fd020841494c
--- /dev/null
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
@@ -0,0 +1,872 @@
+import copy
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from torch.nn.functional import grid_sample
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import BaseOutput
+from diffusers.utils.torch_utils import randn_tensor
+
+
+# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_0
+def rearrange_0(tensor, f):
+ F, C, H, W = tensor.size()
+ tensor = torch.permute(torch.reshape(tensor, (F // f, f, C, H, W)), (0, 2, 1, 3, 4))
+ return tensor
+
+
+# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_1
+def rearrange_1(tensor):
+ B, C, F, H, W = tensor.size()
+ return torch.reshape(torch.permute(tensor, (0, 2, 1, 3, 4)), (B * F, C, H, W))
+
+
+# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_3
+def rearrange_3(tensor, f):
+ F, D, C = tensor.size()
+ return torch.reshape(tensor, (F // f, f, D, C))
+
+
+# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_4
+def rearrange_4(tensor):
+ B, F, D, C = tensor.size()
+ return torch.reshape(tensor, (B * F, D, C))
+
+
+# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor
+class CrossFrameAttnProcessor:
+ """
+ Cross frame attention processor. Each frame attends the first frame.
+
+ Args:
+ batch_size: The number that represents actual batch size, other than the frames.
+ For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
+ 2, due to classifier-free guidance.
+ """
+
+ def __init__(self, batch_size=2):
+ self.batch_size = batch_size
+
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ query = attn.to_q(hidden_states)
+
+ is_cross_attention = encoder_hidden_states is not None
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ # Cross Frame Attention
+ if not is_cross_attention:
+ video_length = key.size()[0] // self.batch_size
+ first_frame_index = [0] * video_length
+
+ # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
+ key = rearrange_3(key, video_length)
+ key = key[:, first_frame_index]
+ # rearrange values to have batch and frames in the 1st and 2nd dims respectively
+ value = rearrange_3(value, video_length)
+ value = value[:, first_frame_index]
+
+ # rearrange back to original shape
+ key = rearrange_4(key)
+ value = rearrange_4(value)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor2_0
+class CrossFrameAttnProcessor2_0:
+ """
+ Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0.
+
+ Args:
+ batch_size: The number that represents actual batch size, other than the frames.
+ For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
+ 2, due to classifier-free guidance.
+ """
+
+ def __init__(self, batch_size=2):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ self.batch_size = batch_size
+
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ inner_dim = hidden_states.shape[-1]
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+
+ is_cross_attention = encoder_hidden_states is not None
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ # Cross Frame Attention
+ if not is_cross_attention:
+ video_length = max(1, key.size()[0] // self.batch_size)
+ first_frame_index = [0] * video_length
+
+ # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
+ key = rearrange_3(key, video_length)
+ key = key[:, first_frame_index]
+ # rearrange values to have batch and frames in the 1st and 2nd dims respectively
+ value = rearrange_3(value, video_length)
+ value = value[:, first_frame_index]
+
+ # rearrange back to original shape
+ key = rearrange_4(key)
+ value = rearrange_4(value)
+
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+@dataclass
+class TextToVideoSDXLPipelineOutput(BaseOutput):
+ """
+ Output class for zero-shot text-to-video pipeline.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.coords_grid
+def coords_grid(batch, ht, wd, device):
+ # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
+ coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
+ coords = torch.stack(coords[::-1], dim=0).float()
+ return coords[None].repeat(batch, 1, 1, 1)
+
+
+# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.warp_single_latent
+def warp_single_latent(latent, reference_flow):
+ """
+ Warp latent of a single frame with given flow
+
+ Args:
+ latent: latent code of a single frame
+ reference_flow: flow which to warp the latent with
+
+ Returns:
+ warped: warped latent
+ """
+ _, _, H, W = reference_flow.size()
+ _, _, h, w = latent.size()
+ coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype)
+
+ coords_t0 = coords0 + reference_flow
+ coords_t0[:, 0] /= W
+ coords_t0[:, 1] /= H
+
+ coords_t0 = coords_t0 * 2.0 - 1.0
+ coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear")
+ coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1))
+
+ warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection")
+ return warped
+
+
+# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.create_motion_field
+def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype):
+ """
+ Create translation motion field
+
+ Args:
+ motion_field_strength_x: motion strength along x-axis
+ motion_field_strength_y: motion strength along y-axis
+ frame_ids: indexes of the frames the latents of which are being processed.
+ This is needed when we perform chunk-by-chunk inference
+ device: device
+ dtype: dtype
+
+ Returns:
+
+ """
+ seq_length = len(frame_ids)
+ reference_flow = torch.zeros((seq_length, 2, 512, 512), device=device, dtype=dtype)
+ for fr_idx in range(seq_length):
+ reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (frame_ids[fr_idx])
+ reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (frame_ids[fr_idx])
+ return reference_flow
+
+
+# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.create_motion_field_and_warp_latents
+def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents):
+ """
+ Creates translation motion and warps the latents accordingly
+
+ Args:
+ motion_field_strength_x: motion strength along x-axis
+ motion_field_strength_y: motion strength along y-axis
+ frame_ids: indexes of the frames the latents of which are being processed.
+ This is needed when we perform chunk-by-chunk inference
+ latents: latent codes of frames
+
+ Returns:
+ warped_latents: warped latents
+ """
+ motion_field = create_motion_field(
+ motion_field_strength_x=motion_field_strength_x,
+ motion_field_strength_y=motion_field_strength_y,
+ frame_ids=frame_ids,
+ device=latents.device,
+ dtype=latents.dtype,
+ )
+ warped_latents = latents.clone().detach()
+ for i in range(len(warped_latents)):
+ warped_latents[i] = warp_single_latent(latents[i][None], motion_field[i][None])
+ return warped_latents
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class TextToVideoZeroSDXLPipeline(StableDiffusionXLPipeline):
+ r"""
+ Pipeline for zero-shot text-to-video generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
+ add_watermarker=add_watermarker,
+ )
+ processor = (
+ CrossFrameAttnProcessor2_0(batch_size=2)
+ if hasattr(F, "scaled_dot_product_attention")
+ else CrossFrameAttnProcessor(batch_size=2)
+ )
+ self.unet.set_attn_processor(processor)
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoZeroPipeline.forward_loop
+ def forward_loop(self, x_t0, t0, t1, generator):
+ """
+ Perform DDPM forward process from time t0 to t1. This is the same as adding noise with corresponding variance.
+
+ Args:
+ x_t0:
+ Latent code at time t0.
+ t0:
+ Timestep at t0.
+ t1:
+ Timestamp at t1.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+
+ Returns:
+ x_t1:
+ Forward process applied to x_t0 from time t0 to t1.
+ """
+ eps = randn_tensor(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)
+ alpha_vec = torch.prod(self.scheduler.alphas[t0:t1])
+ x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps
+ return x_t1
+
+ def backward_loop(
+ self,
+ latents,
+ timesteps,
+ prompt_embeds,
+ guidance_scale,
+ callback,
+ callback_steps,
+ num_warmup_steps,
+ extra_step_kwargs,
+ add_text_embeds,
+ add_time_ids,
+ cross_attention_kwargs=None,
+ guidance_rescale: float = 0.0,
+ ):
+ """
+ Perform backward process given list of time steps
+
+ Args:
+ latents:
+ Latents at time timesteps[0].
+ timesteps:
+ Time steps along which to perform backward process.
+ prompt_embeds:
+ Pre-generated text embeddings.
+ guidance_scale:
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ extra_step_kwargs:
+ Extra_step_kwargs.
+ cross_attention_kwargs:
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ num_warmup_steps:
+ number of warmup steps.
+
+ Returns:
+ latents: latents of backward process output at time timesteps[-1]
+ """
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+ num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order
+
+ with self.progress_bar(total=num_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+ return latents.clone().detach()
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ video_length: Optional[int] = 8,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ frame_ids: Optional[List[int]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ motion_field_strength_x: float = 12,
+ motion_field_strength_y: float = 12,
+ output_type: Optional[str] = "tensor",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ t0: int = 44,
+ t1: int = 47,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ video_length (`int`, *optional*, defaults to 8):
+ The number of generated video frames.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ frame_ids (`List[int]`, *optional*):
+ Indexes of the frames that are being generated. This is used when generating longer videos
+ chunk-by-chunk.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ motion_field_strength_x (`float`, *optional*, defaults to 12):
+ Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439),
+ Sect. 3.3.1.
+ motion_field_strength_y (`float`, *optional*, defaults to 12):
+ Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439),
+ Sect. 3.3.1.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `Ο` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ t0 (`int`, *optional*, defaults to 44):
+ Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the
+ [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+ t1 (`int`, *optional*, defaults to 47):
+ Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
+ [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+
+ Returns:
+ [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoSDXLPipelineOutput`] or
+ `tuple`: [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoSDXLPipelineOutput`]
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ assert video_length > 0
+ if frame_ids is None:
+ frame_ids = list(range(video_length))
+ assert len(frame_ids) == video_length
+
+ assert num_videos_per_prompt == 1
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ batch_size = (
+ 1 if isinstance(prompt, str) else len(prompt) if isinstance(prompt, list) else prompt_embeds.shape[0]
+ )
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # Perform the first backward process up to time T_1
+ x_1_t1 = self.backward_loop(
+ timesteps=timesteps[: -t1 - 1],
+ prompt_embeds=prompt_embeds,
+ latents=latents,
+ guidance_scale=guidance_scale,
+ callback=callback,
+ callback_steps=callback_steps,
+ extra_step_kwargs=extra_step_kwargs,
+ num_warmup_steps=num_warmup_steps,
+ add_text_embeds=add_text_embeds,
+ add_time_ids=add_time_ids,
+ )
+
+ scheduler_copy = copy.deepcopy(self.scheduler)
+
+ # Perform the second backward process up to time T_0
+ x_1_t0 = self.backward_loop(
+ timesteps=timesteps[-t1 - 1 : -t0 - 1],
+ prompt_embeds=prompt_embeds,
+ latents=x_1_t1,
+ guidance_scale=guidance_scale,
+ callback=callback,
+ callback_steps=callback_steps,
+ extra_step_kwargs=extra_step_kwargs,
+ num_warmup_steps=0,
+ add_text_embeds=add_text_embeds,
+ add_time_ids=add_time_ids,
+ )
+
+ # Propagate first frame latents at time T_0 to remaining frames
+ x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1)
+
+ # Add motion in latents at time T_0
+ x_2k_t0 = create_motion_field_and_warp_latents(
+ motion_field_strength_x=motion_field_strength_x,
+ motion_field_strength_y=motion_field_strength_y,
+ latents=x_2k_t0,
+ frame_ids=frame_ids[1:],
+ )
+
+ # Perform forward process up to time T_1
+ x_2k_t1 = self.forward_loop(
+ x_t0=x_2k_t0,
+ t0=timesteps[-t0 - 1].to(torch.long),
+ t1=timesteps[-t1 - 1].to(torch.long),
+ generator=generator,
+ )
+
+ # Perform backward process from time T_1 to 0
+ latents = torch.cat([x_1_t1, x_2k_t1])
+
+ self.scheduler = scheduler_copy
+ timesteps = timesteps[-t1 - 1 :]
+
+ b, l, d = prompt_embeds.size()
+ prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)
+
+ b, k = add_text_embeds.size()
+ add_text_embeds = add_text_embeds[:, None].repeat(1, video_length, 1).reshape(b * video_length, k)
+
+ b, k = add_time_ids.size()
+ add_time_ids = add_time_ids[:, None].repeat(1, video_length, 1).reshape(b * video_length, k)
+
+ # 7.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ x_1k_0 = self.backward_loop(
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ latents=latents,
+ guidance_scale=guidance_scale,
+ callback=callback,
+ callback_steps=callback_steps,
+ extra_step_kwargs=extra_step_kwargs,
+ num_warmup_steps=0,
+ add_text_embeds=add_text_embeds,
+ add_time_ids=add_time_ids,
+ )
+
+ latents = x_1k_0
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+ return TextToVideoSDXLPipelineOutput(images=image)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return TextToVideoSDXLPipelineOutput(images=image)
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index a99135300d92..6aa994676577 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -323,8 +323,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- sigma_min: float = in_sigmas[-1].item()
- sigma_max: float = in_sigmas[0].item()
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index b427f19e9e03..4b638547b38a 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -358,8 +358,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- sigma_min: float = in_sigmas[-1].item()
- sigma_max: float = in_sigmas[0].item()
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index bc8ee24a901c..e762c0ec8bba 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -358,8 +358,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- sigma_min: float = in_sigmas[-1].item()
- sigma_max: float = in_sigmas[0].item()
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 6fd4d3bbf7b6..2c0be3b842cc 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -357,8 +357,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- sigma_min: float = in_sigmas[-1].item()
- sigma_max: float = in_sigmas[0].item()
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 59d9af9f55b6..53dc2ae15432 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -144,7 +144,10 @@ def __init__(
prediction_type: str = "epsilon",
interpolation_type: str = "linear",
use_karras_sigmas: Optional[bool] = False,
+ sigma_min: Optional[float] = None,
+ sigma_max: Optional[float] = None,
timestep_spacing: str = "linspace",
+ timestep_type: str = "discrete", # can be "discrete" or "continuous"
steps_offset: int = 0,
):
if trained_betas is not None:
@@ -164,13 +167,22 @@ def __init__(
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
- sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
- self.sigmas = torch.from_numpy(sigmas)
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+
+ sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32)
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
# setable values
self.num_inference_steps = None
- timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
- self.timesteps = torch.from_numpy(timesteps)
+
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
+ if timestep_type == "continuous" and prediction_type == "v_prediction":
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
+ else:
+ self.timesteps = timesteps
+
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
self.is_scale_input_called = False
self.use_karras_sigmas = use_karras_sigmas
@@ -268,10 +280,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
- sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
- self.sigmas = torch.from_numpy(sigmas).to(device=device)
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
- self.timesteps = torch.from_numpy(timesteps).to(device=device)
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
+ if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
+ else:
+ self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
+
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
def _sigma_to_t(self, sigma, log_sigmas):
@@ -301,8 +318,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- sigma_min: float = in_sigmas[-1].item()
- sigma_max: float = in_sigmas[0].item()
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
@@ -412,7 +441,7 @@ def step(
elif self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output
elif self.config.prediction_type == "v_prediction":
- # * c_out + input * c_skip
+ # denoised = model_output * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py
index 980dbd1bf839..460299cf2ec1 100644
--- a/src/diffusers/schedulers/scheduling_heun_discrete.py
+++ b/src/diffusers/schedulers/scheduling_heun_discrete.py
@@ -303,8 +303,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- sigma_min: float = in_sigmas[-1].item()
- sigma_max: float = in_sigmas[0].item()
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
index e74dd868d835..aae5a15abca2 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -324,8 +324,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- sigma_min: float = in_sigmas[-1].item()
- sigma_max: float = in_sigmas[0].item()
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
index ac590e5713ca..3248520aa9a5 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -335,8 +335,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- sigma_min: float = in_sigmas[-1].item()
- sigma_max: float = in_sigmas[0].item()
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index a6d82de80b88..d778f37ec059 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -337,8 +337,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- sigma_min: float = in_sigmas[-1].item()
- sigma_max: float = in_sigmas[0].item()
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 360727ab2fc5..c19b15f2f483 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -32,6 +32,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoencoderTiny(metaclass=DummyObject):
_backends = ["torch"]
@@ -272,6 +287,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class UNetSpatioTemporalConditionModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class VQModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 3386a95eb7d4..b039cdc72ab6 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1172,6 +1172,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class StableVideoDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class TextToVideoSDPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1202,6 +1217,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class TextToVideoZeroSDXLPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class UnCLIPImageVariationPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/export_utils.py b/src/diffusers/utils/export_utils.py
index f7744f9d63eb..45aece18b8fd 100644
--- a/src/diffusers/utils/export_utils.py
+++ b/src/diffusers/utils/export_utils.py
@@ -3,7 +3,7 @@
import struct
import tempfile
from contextlib import contextmanager
-from typing import List
+from typing import List, Union
import numpy as np
import PIL.Image
@@ -115,7 +115,9 @@ def export_to_obj(mesh, output_obj_path: str = None):
f.writelines("\n".join(combined_data))
-def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
+def export_to_video(
+ video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
+) -> str:
if is_opencv_available():
import cv2
else:
@@ -123,9 +125,12 @@ def export_to_video(video_frames: List[np.ndarray], output_video_path: str = Non
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
+ if isinstance(video_frames[0], PIL.Image.Image):
+ video_frames = [np.array(frame) for frame in video_frames]
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
h, w, c = video_frames[0].shape
- video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
+ video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
for i in range(len(video_frames)):
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
video_writer.write(img)
diff --git a/tests/models/test_models_unet_spatiotemporal.py b/tests/models/test_models_unet_spatiotemporal.py
new file mode 100644
index 000000000000..fa07eaa736ba
--- /dev/null
+++ b/tests/models/test_models_unet_spatiotemporal.py
@@ -0,0 +1,289 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import unittest
+
+import torch
+
+from diffusers import UNetSpatioTemporalConditionModel
+from diffusers.utils import logging
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_all_close,
+ torch_device,
+)
+
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+logger = logging.get_logger(__name__)
+
+enable_full_determinism()
+
+
+class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = UNetSpatioTemporalConditionModel
+ main_input_name = "sample"
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 2
+ num_channels = 4
+ sizes = (32, 32)
+
+ noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
+ time_step = torch.tensor([10]).to(torch_device)
+ encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
+
+ return {
+ "sample": noise,
+ "timestep": time_step,
+ "encoder_hidden_states": encoder_hidden_states,
+ "added_time_ids": self._get_add_time_ids(),
+ }
+
+ @property
+ def input_shape(self):
+ return (2, 2, 4, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (4, 32, 32)
+
+ @property
+ def fps(self):
+ return 6
+
+ @property
+ def motion_bucket_id(self):
+ return 127
+
+ @property
+ def noise_aug_strength(self):
+ return 0.02
+
+ @property
+ def addition_time_embed_dim(self):
+ return 32
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "block_out_channels": (32, 64),
+ "down_block_types": (
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ "up_block_types": (
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ ),
+ "cross_attention_dim": 32,
+ "num_attention_heads": 8,
+ "out_channels": 4,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "sample_size": 32,
+ "projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3,
+ "addition_time_embed_dim": self.addition_time_embed_dim,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def _get_add_time_ids(self, do_classifier_free_guidance=True):
+ add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength]
+
+ passed_add_embed_dim = self.addition_time_embed_dim * len(add_time_ids)
+ expected_add_embed_dim = self.addition_time_embed_dim * 3
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], device=torch_device)
+ add_time_ids = add_time_ids.repeat(1, 1)
+ if do_classifier_free_guidance:
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
+
+ return add_time_ids
+
+ @unittest.skip("Number of Norm Groups is not configurable")
+ def test_forward_with_norm_groups(self):
+ pass
+
+ @unittest.skip("Deprecated functionality")
+ def test_model_attention_slicing(self):
+ pass
+
+ @unittest.skip("Not supported")
+ def test_model_with_use_linear_projection(self):
+ pass
+
+ @unittest.skip("Not supported")
+ def test_model_with_simple_projection(self):
+ pass
+
+ @unittest.skip("Not supported")
+ def test_model_with_class_embeddings_concat(self):
+ pass
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_enable_works(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+
+ model.enable_xformers_memory_efficient_attention()
+
+ assert (
+ model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
+ == "XFormersAttnProcessor"
+ ), "xformers is not enabled"
+
+ @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+ def test_gradient_checkpointing(self):
+ # enable deterministic behavior for gradient checkpointing
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ assert not model.is_gradient_checkpointing and model.training
+
+ out = model(**inputs_dict).sample
+ # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+ # we won't calculate the loss and rather backprop on out.sum()
+ model.zero_grad()
+
+ labels = torch.randn_like(out)
+ loss = (out - labels).mean()
+ loss.backward()
+
+ # re-instantiate the model now enabling gradient checkpointing
+ model_2 = self.model_class(**init_dict)
+ # clone model
+ model_2.load_state_dict(model.state_dict())
+ model_2.to(torch_device)
+ model_2.enable_gradient_checkpointing()
+
+ assert model_2.is_gradient_checkpointing and model_2.training
+
+ out_2 = model_2(**inputs_dict).sample
+ # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+ # we won't calculate the loss and rather backprop on out.sum()
+ model_2.zero_grad()
+ loss_2 = (out_2 - labels).mean()
+ loss_2.backward()
+
+ # compare the output and parameters gradients
+ self.assertTrue((loss - loss_2).abs() < 1e-5)
+ named_params = dict(model.named_parameters())
+ named_params_2 = dict(model_2.named_parameters())
+ for name, param in named_params.items():
+ self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+
+ def test_model_with_num_attention_heads_tuple(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["num_attention_heads"] = (8, 16)
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.sample
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ def test_model_with_cross_attention_dim_tuple(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["cross_attention_dim"] = (32, 32)
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.sample
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ def test_gradient_checkpointing_is_applied(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["num_attention_heads"] = (8, 16)
+
+ model_class_copy = copy.copy(self.model_class)
+
+ modules_with_gc_enabled = {}
+
+ # now monkey patch the following function:
+ # def _set_gradient_checkpointing(self, module, value=False):
+ # if hasattr(module, "gradient_checkpointing"):
+ # module.gradient_checkpointing = value
+
+ def _set_gradient_checkpointing_new(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+ modules_with_gc_enabled[module.__class__.__name__] = True
+
+ model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
+
+ model = model_class_copy(**init_dict)
+ model.enable_gradient_checkpointing()
+
+ EXPECTED_SET = {
+ "TransformerSpatioTemporalModel",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "UNetMidBlockSpatioTemporal",
+ }
+
+ assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
+ assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+
+ def test_pickle(self):
+ # enable deterministic behavior for gradient checkpointing
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["num_attention_heads"] = (8, 16)
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ with torch.no_grad():
+ sample = model(**inputs_dict).sample
+
+ sample_copy = copy.copy(sample)
+
+ assert (sample - sample_copy).abs().max() < 1e-4
diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py
index 83788b836a78..aa755e387b61 100644
--- a/tests/models/test_models_vae.py
+++ b/tests/models/test_models_vae.py
@@ -23,6 +23,7 @@
from diffusers import (
AsymmetricAutoencoderKL,
AutoencoderKL,
+ AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
StableDiffusionPipeline,
@@ -248,11 +249,31 @@ def test_output_pretrained(self):
)
elif torch_device == "cpu":
expected_output_slice = torch.tensor(
- [-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026]
+ [
+ -0.1352,
+ 0.0878,
+ 0.0419,
+ -0.0818,
+ -0.1069,
+ 0.0688,
+ -0.1458,
+ -0.4446,
+ -0.0026,
+ ]
)
else:
expected_output_slice = torch.tensor(
- [-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485]
+ [
+ -0.2421,
+ 0.4642,
+ 0.2507,
+ -0.0438,
+ 0.0682,
+ 0.3160,
+ -0.2018,
+ -0.0727,
+ 0.2485,
+ ]
)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
@@ -364,6 +385,93 @@ def test_ema_training(self):
...
+class AutoncoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLTemporalDecoder
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ @property
+ def dummy_input(self):
+ batch_size = 3
+ num_channels = 3
+ sizes = (32, 32)
+
+ image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+ num_frames = 3
+
+ return {"sample": image, "num_frames": num_frames}
+
+ @property
+ def input_shape(self):
+ return (3, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (3, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "block_out_channels": [32, 64],
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ "latent_channels": 4,
+ "layers_per_block": 2,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_forward_signature(self):
+ pass
+
+ def test_training(self):
+ pass
+
+ @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+ def test_gradient_checkpointing(self):
+ # enable deterministic behavior for gradient checkpointing
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ assert not model.is_gradient_checkpointing and model.training
+
+ out = model(**inputs_dict).sample
+ # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+ # we won't calculate the loss and rather backprop on out.sum()
+ model.zero_grad()
+
+ labels = torch.randn_like(out)
+ loss = (out - labels).mean()
+ loss.backward()
+
+ # re-instantiate the model now enabling gradient checkpointing
+ model_2 = self.model_class(**init_dict)
+ # clone model
+ model_2.load_state_dict(model.state_dict())
+ model_2.to(torch_device)
+ model_2.enable_gradient_checkpointing()
+
+ assert model_2.is_gradient_checkpointing and model_2.training
+
+ out_2 = model_2(**inputs_dict).sample
+ # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+ # we won't calculate the loss and rather backprop on out.sum()
+ model_2.zero_grad()
+ loss_2 = (out_2 - labels).mean()
+ loss_2.backward()
+
+ # compare the output and parameters gradients
+ self.assertTrue((loss - loss_2).abs() < 1e-5)
+ named_params = dict(model.named_parameters())
+ named_params_2 = dict(model_2.named_parameters())
+ for name, param in named_params.items():
+ if "post_quant_conv" in name:
+ continue
+
+ self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+
+
@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
def tearDown(self):
@@ -609,7 +717,10 @@ def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
@parameterized.expand([(13,), (16,), (27,)])
@require_torch_gpu
- @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
+ @unittest.skipIf(
+ not is_xformers_available(),
+ reason="xformers is not required when using PyTorch 2.0.",
+ )
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
model = self.get_sd_vae_model(fp16=True)
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
@@ -627,7 +738,10 @@ def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
- @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
+ @unittest.skipIf(
+ not is_xformers_available(),
+ reason="xformers is not required when using PyTorch 2.0.",
+ )
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
@@ -808,7 +922,10 @@ def test_stable_diffusion_decode(self, seed, expected_slice):
@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
- @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
+ @unittest.skipIf(
+ not is_xformers_available(),
+ reason="xformers is not required when using PyTorch 2.0.",
+ )
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
@@ -886,7 +1003,10 @@ def test_sd(self):
pipe.to(torch_device)
out = pipe(
- "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
+ "horse",
+ num_inference_steps=2,
+ output_type="pt",
+ generator=torch.Generator("cpu").manual_seed(0),
).images[0]
actual_output = out[:2, :2, :2].flatten().cpu()
@@ -916,7 +1036,8 @@ def test_encode_decode_f16(self):
actual_output = sample[0, :2, :2, :2].flatten().cpu()
expected_output = torch.tensor(
- [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
+ [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471],
+ dtype=torch.float16,
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)
@@ -926,17 +1047,24 @@ def test_sd_f16(self):
"openai/consistency-decoder", torch_dtype=torch.float16
) # TODO - update
pipe = StableDiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ vae=vae,
+ safety_checker=None,
)
pipe.to(torch_device)
out = pipe(
- "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
+ "horse",
+ num_inference_steps=2,
+ output_type="pt",
+ generator=torch.Generator("cpu").manual_seed(0),
).images[0]
actual_output = out[:2, :2, :2].flatten().cpu()
expected_output = torch.tensor(
- [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
+ [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035],
+ dtype=torch.float16,
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py
index a9140f3d5a31..7c3371c197d4 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py
@@ -132,6 +132,7 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
+ "image_encoder": None,
}
return components
@@ -248,6 +249,7 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
+ "image_encoder": None,
}
return components
@@ -342,6 +344,7 @@ def init_weights(m):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
+ "image_encoder": None,
}
return components
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index 88d2df1ec0f8..ba129e763c22 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -28,6 +28,7 @@
StableDiffusionXLControlNetPipeline,
UNet2DConditionModel,
)
+from diffusers.models.unet_2d_blocks import UNetMidBlock2D
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
@@ -817,3 +818,162 @@ def test_depth(self):
original_image = images[0, -3:, -3:, -1].flatten()
expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853])
assert np.allclose(original_image, expected_image, atol=1e-04)
+
+
+class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNetPipelineFastTests):
+ def test_controlnet_sdxl_guess(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe = sd_pipe.to(device)
+
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["guess_mode"] = True
+
+ output = sd_pipe(**inputs)
+ image_slice = output.images[0, -3:, -3:, -1]
+ expected_slice = np.array(
+ [0.6831671, 0.5702532, 0.5459845, 0.6299793, 0.58563006, 0.6033695, 0.4493941, 0.46132287, 0.5035841]
+ )
+
+ # make sure that it's equal
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4
+
+ def test_controlnet_sdxl_lcm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+
+ components = self.get_dummy_components(time_cond_proj_dim=256)
+ sd_pipe = StableDiffusionXLControlNetPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ output = sd_pipe(**inputs)
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.6850, 0.5135, 0.5545, 0.7033, 0.6617, 0.5971, 0.4165, 0.5480, 0.5070])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_conditioning_channels(self):
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ mid_block_type="UNetMidBlock2D",
+ # SD2-specific config below
+ attention_head_dim=(2, 4),
+ use_linear_projection=True,
+ addition_embed_type="text_time",
+ addition_time_embed_dim=8,
+ transformer_layers_per_block=(1, 2),
+ projection_class_embeddings_input_dim=80, # 6 * 8 + 32
+ cross_attention_dim=64,
+ time_cond_proj_dim=None,
+ )
+
+ controlnet = ControlNetModel.from_unet(unet, conditioning_channels=4)
+ assert type(controlnet.mid_block) == UNetMidBlock2D
+ assert controlnet.conditioning_channels == 4
+
+ def get_dummy_components(self, time_cond_proj_dim=None):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ mid_block_type="UNetMidBlock2D",
+ # SD2-specific config below
+ attention_head_dim=(2, 4),
+ use_linear_projection=True,
+ addition_embed_type="text_time",
+ addition_time_embed_dim=8,
+ transformer_layers_per_block=(1, 2),
+ projection_class_embeddings_input_dim=80, # 6 * 8 + 32
+ cross_attention_dim=64,
+ time_cond_proj_dim=time_cond_proj_dim,
+ )
+ torch.manual_seed(0)
+ controlnet = ControlNetModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ in_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ conditioning_embedding_out_channels=(16, 32),
+ mid_block_type="UNetMidBlock2D",
+ # SD2-specific config below
+ attention_head_dim=(2, 4),
+ use_linear_projection=True,
+ addition_embed_type="text_time",
+ addition_time_embed_dim=8,
+ transformer_layers_per_block=(1, 2),
+ projection_class_embeddings_input_dim=80, # 6 * 8 + 32
+ cross_attention_dim=64,
+ )
+ torch.manual_seed(0)
+ scheduler = EulerDiscreteScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ steps_offset=1,
+ beta_schedule="scaled_linear",
+ timestep_spacing="leading",
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ # SD2-specific config below
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "controlnet": controlnet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "feature_extractor": None,
+ "image_encoder": None,
+ }
+ return components
diff --git a/tests/pipelines/stable_video_diffusion/__init__.py b/tests/pipelines/stable_video_diffusion/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
new file mode 100644
index 000000000000..11978424368f
--- /dev/null
+++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
@@ -0,0 +1,523 @@
+import gc
+import random
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPVisionConfig,
+ CLIPVisionModelWithProjection,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKLTemporalDecoder,
+ EulerDiscreteScheduler,
+ StableVideoDiffusionPipeline,
+ UNetSpatioTemporalConditionModel,
+)
+from diffusers.utils import is_accelerate_available, is_accelerate_version, load_image, logging
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.testing_utils import (
+ CaptureLogger,
+ disable_full_determinism,
+ enable_full_determinism,
+ floats_tensor,
+ numpy_cosine_similarity_distance,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+def to_np(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.detach().cpu().numpy()
+
+ return tensor
+
+
+class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = StableVideoDiffusionPipeline
+ params = frozenset(["image"])
+ batch_params = frozenset(["image", "generator"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ ]
+ )
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ unet = UNetSpatioTemporalConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=8,
+ out_channels=4,
+ down_block_types=(
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ up_block_types=("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal"),
+ cross_attention_dim=32,
+ num_attention_heads=8,
+ projection_class_embeddings_input_dim=96,
+ addition_time_embed_dim=32,
+ )
+ scheduler = EulerDiscreteScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ interpolation_type="linear",
+ num_train_timesteps=1000,
+ prediction_type="v_prediction",
+ sigma_max=700.0,
+ sigma_min=0.002,
+ steps_offset=1,
+ timestep_spacing="leading",
+ timestep_type="continuous",
+ trained_betas=None,
+ use_karras_sigmas=True,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLTemporalDecoder(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ latent_channels=4,
+ )
+
+ torch.manual_seed(0)
+ config = CLIPVisionConfig(
+ hidden_size=32,
+ projection_dim=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ image_size=32,
+ intermediate_size=37,
+ patch_size=1,
+ )
+ image_encoder = CLIPVisionModelWithProjection(config)
+
+ torch.manual_seed(0)
+ feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
+ components = {
+ "unet": unet,
+ "image_encoder": image_encoder,
+ "scheduler": scheduler,
+ "vae": vae,
+ "feature_extractor": feature_extractor,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)).to(device)
+ inputs = {
+ "generator": generator,
+ "image": image,
+ "num_inference_steps": 2,
+ "output_type": "pt",
+ "min_guidance_scale": 1.0,
+ "max_guidance_scale": 2.5,
+ "num_frames": 2,
+ "height": 32,
+ "width": 32,
+ }
+ return inputs
+
+ @unittest.skip("Deprecated functionality")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("Batched inference works and outputs look correct, but the test is failing")
+ def test_inference_batch_single_identical(
+ self,
+ batch_size=2,
+ expected_max_diff=1e-4,
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for components in pipe.components.values():
+ if hasattr(components, "set_default_attn_processor"):
+ components.set_default_attn_processor()
+ pipe.to(torch_device)
+
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Reset generator in case it is has been used in self.get_dummy_inputs
+ inputs["generator"] = torch.Generator("cpu").manual_seed(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # batchify inputs
+ batched_inputs = {}
+ batched_inputs.update(inputs)
+
+ batched_inputs["generator"] = [torch.Generator("cpu").manual_seed(0) for i in range(batch_size)]
+ batched_inputs["image"] = torch.cat([inputs["image"]] * batch_size, dim=0)
+
+ output = pipe(**inputs).frames
+ output_batch = pipe(**batched_inputs).frames
+
+ assert len(output_batch) == batch_size
+
+ max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
+ assert max_diff < expected_max_diff
+
+ @unittest.skip("Test is similar to test_inference_batch_single_identical")
+ def test_inference_batch_consistent(self):
+ pass
+
+ def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ output = pipe(**self.get_dummy_inputs(generator_device)).frames[0]
+ output_tuple = pipe(**self.get_dummy_inputs(generator_device), return_dict=False)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ @unittest.skip("Test is currently failing")
+ def test_float16_inference(self, expected_max_diff=5e-2):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ components = self.get_dummy_components()
+ pipe_fp16 = self.pipeline_class(**components)
+ for component in pipe_fp16.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe_fp16.to(torch_device, torch.float16)
+ pipe_fp16.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs).frames[0]
+
+ fp16_inputs = self.get_dummy_inputs(torch_device)
+ output_fp16 = pipe_fp16(**fp16_inputs).frames[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
+ self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
+
+ @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ def test_save_load_float16(self, expected_max_diff=1e-2):
+ components = self.get_dummy_components()
+ for name, module in components.items():
+ if hasattr(module, "half"):
+ components[name] = module.to(torch_device).half()
+
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs).frames[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for name, component in pipe_loaded.components.items():
+ if hasattr(component, "dtype"):
+ self.assertTrue(
+ component.dtype == torch.float16,
+ f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_loaded = pipe_loaded(**inputs).frames[0]
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(
+ max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
+ )
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ if not hasattr(self.pipeline_class, "_optional_components"):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ # set all optional components to None
+ for optional_component in pipe._optional_components:
+ setattr(pipe, optional_component, None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output = pipe(**inputs).frames[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for optional_component in pipe._optional_components:
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(generator_device)
+ output_loaded = pipe_loaded(**inputs).frames[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ def test_save_load_local(self, expected_max_difference=9e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs).frames[0]
+
+ logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+ logger.setLevel(diffusers.logging.INFO)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+
+ with CaptureLogger(logger) as cap_logger:
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+
+ for name in pipe_loaded.components.keys():
+ if name not in pipe_loaded._optional_components:
+ assert name in str(cap_logger)
+
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_loaded = pipe_loaded(**inputs).frames[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ def test_to_device(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.to("cpu")
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cpu" for device in model_devices))
+
+ output_cpu = pipe(**self.get_dummy_inputs("cpu")).frames[0]
+ self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+ pipe.to("cuda")
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cuda" for device in model_devices))
+
+ output_cuda = pipe(**self.get_dummy_inputs("cuda")).frames[0]
+ self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ pipe.to(torch_dtype=torch.float16)
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
+ reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
+ )
+ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_offload = pipe(**inputs).frames[0]
+
+ pipe.enable_sequential_cpu_offload()
+
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_offload = pipe(**inputs).frames[0]
+
+ max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+ self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
+ reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
+ )
+ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_offload = pipe(**inputs).frames[0]
+
+ pipe.enable_model_cpu_offload()
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_offload = pipe(**inputs).frames[0]
+
+ max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+ self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
+ offloaded_modules = [
+ v
+ for k, v in pipe.components.items()
+ if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+ ]
+ (
+ self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
+ f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
+ )
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ disable_full_determinism()
+
+ expected_max_diff = 9e-4
+
+ if not self.test_xformers_attention:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_without_offload = pipe(**inputs).frames[0]
+ output_without_offload = (
+ output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
+ )
+
+ pipe.enable_xformers_memory_efficient_attention()
+ inputs = self.get_dummy_inputs(torch_device)
+ output_with_offload = pipe(**inputs).frames[0]
+ output_with_offload = (
+ output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
+ )
+
+ max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+ self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")
+
+ enable_full_determinism()
+
+
+@slow
+@require_torch_gpu
+class StableVideoDiffusionPipelineSlowTests(unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_sd_video(self):
+ pipe = StableVideoDiffusionPipeline.from_pretrained(
+ "stabilityai/stable-video-diffusion-img2vid",
+ variant="fp16",
+ torch_dtype=torch.float16,
+ )
+ pipe = pipe.to(torch_device)
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
+ )
+
+ generator = torch.Generator("cpu").manual_seed(0)
+ num_frames = 3
+
+ output = pipe(
+ image=image,
+ num_frames=num_frames,
+ generator=generator,
+ num_inference_steps=3,
+ output_type="np",
+ )
+
+ image = output.frames[0]
+ assert image.shape == (num_frames, 576, 1024, 3)
+
+ image_slice = image[0, -3:, -3:, -1]
+ expected_slice = np.array([0.8592, 0.8645, 0.8499, 0.8722, 0.8769, 0.8421, 0.8557, 0.8528, 0.8285])
+ assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
new file mode 100644
index 000000000000..54faa9de6d62
--- /dev/null
+++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
@@ -0,0 +1,405 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import inspect
+import io
+import re
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoZeroSDXLPipeline, UNet2DConditionModel
+from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version
+from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+def to_np(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.detach().cpu().numpy()
+
+ return tensor
+
+
+class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = TextToVideoZeroSDXLPipeline
+ params = TEXT_TO_IMAGE_PARAMS
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ generator_device = "cpu"
+
+ def get_dummy_components(self, seed=0):
+ torch.manual_seed(seed)
+ unet = UNet2DConditionModel(
+ block_out_channels=(2, 4),
+ layers_per_block=2,
+ sample_size=2,
+ norm_num_groups=2,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ # SD2-specific config below
+ attention_head_dim=(2, 4),
+ use_linear_projection=True,
+ addition_embed_type="text_time",
+ addition_time_embed_dim=8,
+ transformer_layers_per_block=(1, 2),
+ projection_class_embeddings_input_dim=80, # 6 * 8 + 32
+ cross_attention_dim=64,
+ )
+ scheduler = DDIMScheduler(
+ num_train_timesteps=1000,
+ beta_start=0.0001,
+ beta_end=0.02,
+ beta_schedule="linear",
+ trained_betas=None,
+ clip_sample=True,
+ set_alpha_to_one=True,
+ steps_offset=0,
+ prediction_type="epsilon",
+ thresholding=False,
+ dynamic_thresholding_ratio=0.995,
+ clip_sample_range=1.0,
+ sample_max_value=1.0,
+ timestep_spacing="leading",
+ rescale_betas_zero_snr=False,
+ )
+ torch.manual_seed(seed)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=128,
+ )
+ torch.manual_seed(seed)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ # SD2-specific config below
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "image_encoder": None,
+ "feature_extractor": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A panda dancing in Antarctica",
+ "generator": generator,
+ "num_inference_steps": 5,
+ "t0": 1,
+ "t1": 3,
+ "height": 64,
+ "width": 64,
+ "video_length": 3,
+ "output_type": "np",
+ }
+ return inputs
+
+ def get_generator(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ return generator
+
+ def test_text_to_video_zero_sdxl(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(self.generator_device)
+ result = pipe(**inputs).images
+
+ first_frame_slice = result[0, -3:, -3:, -1]
+ last_frame_slice = result[-1, -3:, -3:, 0]
+
+ expected_slice1 = np.array([0.48, 0.58, 0.53, 0.59, 0.50, 0.44, 0.60, 0.65, 0.52])
+ expected_slice2 = np.array([0.66, 0.49, 0.40, 0.70, 0.47, 0.51, 0.73, 0.65, 0.52])
+
+ assert np.abs(first_frame_slice.flatten() - expected_slice1).max() < 1e-2
+ assert np.abs(last_frame_slice.flatten() - expected_slice2).max() < 1e-2
+
+ @unittest.skip(
+ reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
+ )
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_cfg(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ if "guidance_scale" not in sig.parameters:
+ return
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(self.generator_device)
+
+ inputs["guidance_scale"] = 1.0
+ out_no_cfg = pipe(**inputs)[0]
+
+ inputs["guidance_scale"] = 7.5
+ out_cfg = pipe(**inputs)[0]
+
+ assert out_cfg.shape == out_no_cfg.shape
+
+ def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ output = pipe(**self.get_dummy_inputs(self.generator_device))[0]
+ output_tuple = pipe(**self.get_dummy_inputs(self.generator_device), return_dict=False)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ def test_float16_inference(self, expected_max_diff=5e-2):
+ components = self.get_dummy_components()
+ for name, module in components.items():
+ if hasattr(module, "half"):
+ components[name] = module.to(torch_device).half()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ components = self.get_dummy_components()
+ pipe_fp16 = self.pipeline_class(**components)
+ pipe_fp16.to(torch_device, torch.float16)
+ pipe_fp16.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(self.generator_device)
+ # # Reset generator in case it is used inside dummy inputs
+ if "generator" in inputs:
+ inputs["generator"] = self.get_generator(self.generator_device)
+
+ output = pipe(**inputs)[0]
+
+ fp16_inputs = self.get_dummy_inputs(self.generator_device)
+ # Reset generator in case it is used inside dummy inputs
+ if "generator" in fp16_inputs:
+ fp16_inputs["generator"] = self.get_generator(self.generator_device)
+
+ output_fp16 = pipe_fp16(**fp16_inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
+ self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
+
+ @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
+ reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
+ )
+ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(self.generator_device)
+ output_without_offload = pipe(**inputs)[0]
+
+ pipe.enable_model_cpu_offload()
+ inputs = self.get_dummy_inputs(self.generator_device)
+ output_with_offload = pipe(**inputs)[0]
+
+ max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+ self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
+
+ @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
+ def test_pipeline_call_signature(self):
+ pass
+
+ def test_progress_bar(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(self.generator_device)
+ with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+ _ = pipe(**inputs)
+ stderr = stderr.getvalue()
+ # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
+ # so we just match "5" in "#####| 1/5 [00:01<00:00]"
+ max_steps = re.search("/(.*?) ", stderr).group(1)
+ self.assertTrue(max_steps is not None and len(max_steps) > 0)
+ self.assertTrue(
+ f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
+ )
+
+ pipe.set_progress_bar_config(disable=True)
+ with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+ _ = pipe(**inputs)
+ self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
+
+ @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ def test_save_load_float16(self, expected_max_diff=1e-2):
+ components = self.get_dummy_components()
+ for name, module in components.items():
+ if hasattr(module, "half"):
+ components[name] = module.to(torch_device).half()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(self.generator_device)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for name, component in pipe_loaded.components.items():
+ if hasattr(component, "dtype"):
+ self.assertTrue(
+ component.dtype == torch.float16,
+ f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(self.generator_device)
+ output_loaded = pipe_loaded(**inputs)[0]
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(
+ max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
+ )
+
+ @unittest.skip(
+ reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
+ )
+ def test_save_load_local(self):
+ pass
+
+ @unittest.skip(
+ reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
+ )
+ def test_save_load_optional_components(self):
+ pass
+
+ @unittest.skip(
+ reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
+ )
+ def test_sequential_cpu_offload_forward_pass(self):
+ pass
+
+ @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ def test_to_device(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.to("cpu")
+ model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
+ self.assertTrue(all(device == "cpu" for device in model_devices))
+
+ output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
+ self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+ pipe.to("cuda")
+ model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
+ self.assertTrue(all(device == "cuda" for device in model_devices))
+
+ output_cuda = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
+ self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+
+ @unittest.skip(
+ reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ pass
+
+
+@nightly
+@require_torch_gpu
+class TextToVideoZeroSDXLPipelineSlowTests(unittest.TestCase):
+ def test_full_model(self):
+ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+ pipe = self.pipeline_class.from_pretrained(
+ model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+ )
+ pipe.enable_model_cpu_offload()
+ pipe.enable_vae_slicing()
+
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ generator = torch.Generator(device="cpu").manual_seed(0)
+
+ prompt = "A panda dancing in Antarctica"
+ result = pipe(prompt=prompt, generator=generator).images
+
+ first_frame_slice = result[0, -3:, -3:, -1]
+ last_frame_slice = result[-1, -3:, -3:, 0]
+
+ expected_slice1 = np.array([0.57, 0.57, 0.57, 0.57, 0.57, 0.56, 0.55, 0.56, 0.56])
+ expected_slice2 = np.array([0.54, 0.53, 0.53, 0.53, 0.53, 0.52, 0.53, 0.53, 0.53])
+
+ assert np.abs(first_frame_slice.flatten() - expected_slice1).max() < 1e-2
+ assert np.abs(last_frame_slice.flatten() - expected_slice2).max() < 1e-2
diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py
index fa885a0542eb..3249d7032bad 100644
--- a/tests/schedulers/test_scheduler_euler.py
+++ b/tests/schedulers/test_scheduler_euler.py
@@ -37,6 +37,14 @@ def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
+ def test_timestep_type(self):
+ timestep_types = ["discrete", "continuous"]
+ for timestep_type in timestep_types:
+ self.check_over_configs(timestep_type=timestep_type)
+
+ def test_karras_sigmas(self):
+ self.check_over_configs(use_karras_sigmas=True, sigma_min=0.02, sigma_max=700.0)
+
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py
index 8bc95b38cf34..08c5ad5c3a50 100755
--- a/tests/schedulers/test_schedulers.py
+++ b/tests/schedulers/test_schedulers.py
@@ -352,8 +352,8 @@ def check_over_configs(self, time_step=0, **config):
_ = scheduler.scale_model_input(sample, scaled_sigma_max)
_ = new_scheduler.scale_model_input(sample, scaled_sigma_max)
elif scheduler_class != VQDiffusionScheduler:
- _ = scheduler.scale_model_input(sample, 0)
- _ = new_scheduler.scale_model_input(sample, 0)
+ _ = scheduler.scale_model_input(sample, scheduler.timesteps[-1])
+ _ = new_scheduler.scale_model_input(sample, scheduler.timesteps[-1])
# Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):