From 43346adc1ffa9051fc71be9af33fd982ee14c383 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 9 Nov 2023 15:38:00 +0530 Subject: [PATCH 1/3] Install accelerate from PyPI in PR test runner (#5721) install acclerate from pypi --- .github/workflows/pr_tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index aaaea147f7ab..f7d9dde5258d 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -72,7 +72,7 @@ jobs: run: | apt-get update && apt-get install libsndfile1-dev libgl1 -y python -m pip install -e .[quality,test] - python -m pip install git+https://github.com/huggingface/accelerate.git + python -m pip install accelerate - name: Environment run: | @@ -115,7 +115,7 @@ jobs: run: | python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ --make-reports=tests_${{ matrix.config.report }} \ - examples/test_examples.py + examples/test_examples.py - name: Failure short reports if: ${{ failure() }} From 2fd46405cd4e845e65b102acc8849667ab508790 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Thu, 9 Nov 2023 03:21:41 -0800 Subject: [PATCH 2/3] consistency decoder (#5694) * consistency decoder * rename * Apply suggestions from code review Co-authored-by: Sayak Paul * Update src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py * uP * Apply suggestions from code review * uP * uP * uP --------- Co-authored-by: Patrick von Platen Co-authored-by: Sayak Paul --- docs/source/en/_toctree.yml | 4 + .../en/api/models/consistency_decoder_vae.md | 18 + .../en/api/schedulers/consistency_decoder.md | 9 + scripts/convert_consistency_decoder.py | 1128 +++++++++++++++++ src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/autoencoder_kl.py | 4 +- .../models/consistency_decoder_vae.py | 365 ++++++ src/diffusers/models/unet_2d.py | 4 + .../alt_diffusion/pipeline_alt_diffusion.py | 4 +- .../pipeline_alt_diffusion_img2img.py | 4 +- .../pipelines/consistency_models/__init__.py | 4 +- .../pipeline_consistency_models.py | 14 + .../controlnet/pipeline_controlnet.py | 4 +- .../controlnet/pipeline_controlnet_img2img.py | 4 +- .../controlnet/pipeline_controlnet_inpaint.py | 4 +- .../pipeline_stable_diffusion.py | 4 +- .../pipeline_stable_diffusion_img2img.py | 4 +- .../pipeline_stable_diffusion_inpaint.py | 4 +- src/diffusers/schedulers/__init__.py | 2 + .../scheduling_consistency_decoder.py | 180 +++ src/diffusers/utils/dummy_pt_objects.py | 15 + tests/models/test_modeling_common.py | 118 +- tests/models/test_models_vae.py | 175 ++- 24 files changed, 2038 insertions(+), 38 deletions(-) create mode 100644 docs/source/en/api/models/consistency_decoder_vae.md create mode 100644 docs/source/en/api/schedulers/consistency_decoder.md create mode 100644 scripts/convert_consistency_decoder.py create mode 100644 src/diffusers/models/consistency_decoder_vae.py create mode 100644 src/diffusers/schedulers/scheduling_consistency_decoder.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3626db3f7b58..efd1fac74238 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -200,6 +200,8 @@ title: AsymmetricAutoencoderKL - local: api/models/autoencoder_tiny title: Tiny AutoEncoder + - local: api/models/consistency_decoder_vae + title: ConsistencyDecoderVAE - local: api/models/transformer2d title: Transformer2D - local: api/models/transformer_temporal @@ -344,6 +346,8 @@ title: Overview - local: 
api/schedulers/cm_stochastic_iterative
        title: CMStochasticIterativeScheduler
+      - local: api/schedulers/consistency_decoder
+        title: ConsistencyDecoderScheduler
       - local: api/schedulers/ddim_inverse
        title: DDIMInverseScheduler
       - local: api/schedulers/ddim
diff --git a/docs/source/en/api/models/consistency_decoder_vae.md b/docs/source/en/api/models/consistency_decoder_vae.md
new file mode 100644
index 000000000000..b45f7fa059dc
--- /dev/null
+++ b/docs/source/en/api/models/consistency_decoder_vae.md
@@ -0,0 +1,18 @@
+# Consistency Decoder
+
+The consistency decoder can be used to decode the latents from the denoising UNet in the [`StableDiffusionPipeline`]. This decoder was introduced in the [DALL-E 3 technical report](https://openai.com/dall-e-3).
+
+The original codebase can be found at [openai/consistencydecoder](https://github.com/openai/consistencydecoder).
+
+<Tip>
+
+Inference is only supported for 2 iterations as of now.
+
+</Tip>
+
+The pipeline could not have been contributed without the help of [madebyollin](https://github.com/madebyollin) and [mrsteyk](https://github.com/mrsteyk) from [this issue](https://github.com/openai/consistencydecoder/issues/1).
+
+## ConsistencyDecoderVAE
+[[autodoc]] ConsistencyDecoderVAE
+    - all
+    - decode
diff --git a/docs/source/en/api/schedulers/consistency_decoder.md b/docs/source/en/api/schedulers/consistency_decoder.md
new file mode 100644
index 000000000000..6c937b913279
--- /dev/null
+++ b/docs/source/en/api/schedulers/consistency_decoder.md
@@ -0,0 +1,9 @@
+# ConsistencyDecoderScheduler
+
+This scheduler is used by the [`ConsistencyDecoderVAE`] and was introduced in [DALL-E 3](https://openai.com/dall-e-3).
+
+The original codebase can be found at [openai/consistency_models](https://github.com/openai/consistency_models).
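+
+Below is a minimal sketch of the calling pattern that [`ConsistencyDecoderVAE`] uses internally in its `decode` method to drive this scheduler. The zero tensor stands in for the consistency UNet's prediction and the `(1, 3, 256, 256)` sample shape is arbitrary; the snippet only illustrates the scheduler API, not a full decode.
+
+```py
+import torch
+
+from diffusers.schedulers import ConsistencyDecoderScheduler
+
+scheduler = ConsistencyDecoderScheduler()
+scheduler.set_timesteps(2, device="cpu")
+
+generator = torch.Generator("cpu").manual_seed(0)
+
+# Start from pure noise scaled by the scheduler's initial sigma.
+x_t = scheduler.init_noise_sigma * torch.randn((1, 3, 256, 256), generator=generator)
+
+for t in scheduler.timesteps:
+    model_input = scheduler.scale_model_input(x_t, t)
+    # In `ConsistencyDecoderVAE.decode` this is the consistency UNet's prediction;
+    # a zero tensor is used here purely as a placeholder.
+    model_output = torch.zeros_like(model_input)
+    x_t = scheduler.step(model_output, t, x_t, generator).prev_sample
+```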
+ + +## ConsistencyDecoderScheduler +[[autodoc]] schedulers.scheduling_consistency_decoder.ConsistencyDecoderScheduler \ No newline at end of file diff --git a/scripts/convert_consistency_decoder.py b/scripts/convert_consistency_decoder.py new file mode 100644 index 000000000000..8e6da07d8c6c --- /dev/null +++ b/scripts/convert_consistency_decoder.py @@ -0,0 +1,1128 @@ +import hashlib +import math +import os +import urllib +import warnings +from argparse import ArgumentParser + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import load_file as stl +from tqdm import tqdm + +from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel +from diffusers.models.embeddings import TimestepEmbedding +from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D +from diffusers.models.vae import Encoder + + +args = ArgumentParser() +args.add_argument("--save_pretrained", required=False, default=None, type=str) +args.add_argument("--test_image", required=True, type=str) +args = args.parse_args() + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895 """ + res = arr[timesteps].float() + dims_to_append = len(broadcast_shape) - len(res.shape) + return res[(...,) + (None,) * dims_to_append] + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L45 + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas) + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +class ConsistencyDecoder: + def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")): + self.n_distilled_steps = 64 + download_target = _download( + "https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt", + download_root, + ) + self.ckpt = torch.jit.load(download_target).to(device) + self.device = device + sigma_data = 
0.5 + betas = betas_for_alpha_bar(1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).to(device) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod) + sigmas = torch.sqrt(1.0 / alphas_cumprod - 1) + self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2) + self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5 + self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5 + + @staticmethod + def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True): + with torch.no_grad(): + space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor") + rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space + if truncate_start: + rounded_timesteps[rounded_timesteps == total_timesteps] -= space + else: + rounded_timesteps[rounded_timesteps == total_timesteps] -= space + rounded_timesteps[rounded_timesteps == 0] += space + return rounded_timesteps + + @staticmethod + def ldm_transform_latent(z, extra_scale_factor=1): + channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294] + channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034] + + if len(z.shape) != 4: + raise ValueError() + + z = z * 0.18215 + channels = [z[:, i] for i in range(z.shape[1])] + + channels = [extra_scale_factor * (c - channel_means[i]) / channel_stds[i] for i, c in enumerate(channels)] + return torch.stack(channels, dim=1) + + @torch.no_grad() + def __call__( + self, + features: torch.Tensor, + schedule=[1.0, 0.5], + generator=None, + ): + features = self.ldm_transform_latent(features) + ts = self.round_timesteps( + torch.arange(0, 1024), + 1024, + self.n_distilled_steps, + truncate_start=False, + ) + shape = ( + features.size(0), + 3, + 8 * features.size(2), + 8 * features.size(3), + ) + x_start = torch.zeros(shape, device=features.device, dtype=features.dtype) + schedule_timesteps = [int((1024 - 1) * s) for s in schedule] + for i in schedule_timesteps: + t = ts[i].item() + t_ = torch.tensor([t] * features.shape[0]).to(self.device) + # noise = torch.randn_like(x_start) + noise = torch.randn(x_start.shape, dtype=x_start.dtype, generator=generator).to(device=x_start.device) + x_start = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape) * noise + ) + c_in = _extract_into_tensor(self.c_in, t_, x_start.shape) + + import torch.nn.functional as F + + from diffusers import UNet2DModel + + if isinstance(self.ckpt, UNet2DModel): + input = torch.concat([c_in * x_start, F.upsample_nearest(features, scale_factor=8)], dim=1) + model_output = self.ckpt(input, t_).sample + else: + model_output = self.ckpt(c_in * x_start, t_, features=features) + + B, C = x_start.shape[:2] + model_output, _ = torch.split(model_output, C, dim=1) + pred_xstart = ( + _extract_into_tensor(self.c_out, t_, x_start.shape) * model_output + + _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start + ).clamp(-1, 1) + x_start = pred_xstart + return x_start + + +def save_image(image, name): + import numpy as np + from PIL import Image + + image = image[0].cpu().numpy() + image = (image + 1.0) * 127.5 + image = image.clip(0, 255).astype(np.uint8) + image = Image.fromarray(image.transpose(1, 2, 0)) + 
image.save(name) + + +def load_image(uri, size=None, center_crop=False): + import numpy as np + from PIL import Image + + image = Image.open(uri) + if center_crop: + image = image.crop( + ( + (image.width - min(image.width, image.height)) // 2, + (image.height - min(image.width, image.height)) // 2, + (image.width + min(image.width, image.height)) // 2, + (image.height + min(image.width, image.height)) // 2, + ) + ) + if size is not None: + image = image.resize(size) + image = torch.tensor(np.array(image).transpose(2, 0, 1)).unsqueeze(0).float() + image = image / 127.5 - 1.0 + return image + + +class TimestepEmbedding_(nn.Module): + def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None: + super().__init__() + self.emb = nn.Embedding(n_time, n_emb) + self.f_1 = nn.Linear(n_emb, n_out) + self.f_2 = nn.Linear(n_out, n_out) + + def forward(self, x) -> torch.Tensor: + x = self.emb(x) + x = self.f_1(x) + x = F.silu(x) + return self.f_2(x) + + +class ImageEmbedding(nn.Module): + def __init__(self, in_channels=7, out_channels=320) -> None: + super().__init__() + self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, x) -> torch.Tensor: + return self.f(x) + + +class ImageUnembedding(nn.Module): + def __init__(self, in_channels=320, out_channels=6) -> None: + super().__init__() + self.gn = nn.GroupNorm(32, in_channels) + self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, x) -> torch.Tensor: + return self.f(F.silu(self.gn(x))) + + +class ConvResblock(nn.Module): + def __init__(self, in_features=320, out_features=320) -> None: + super().__init__() + self.f_t = nn.Linear(1280, out_features * 2) + + self.gn_1 = nn.GroupNorm(32, in_features) + self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1) + + self.gn_2 = nn.GroupNorm(32, out_features) + self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1) + + skip_conv = in_features != out_features + self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0) if skip_conv else nn.Identity() + + def forward(self, x, t): + x_skip = x + t = self.f_t(F.silu(t)) + t = t.chunk(2, dim=1) + t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1 + t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3) + + gn_1 = F.silu(self.gn_1(x)) + f_1 = self.f_1(gn_1) + + gn_2 = self.gn_2(f_1) + + return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2)) + + +# Also ConvResblock +class Downsample(nn.Module): + def __init__(self, in_channels=320) -> None: + super().__init__() + self.f_t = nn.Linear(1280, in_channels * 2) + + self.gn_1 = nn.GroupNorm(32, in_channels) + self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + self.gn_2 = nn.GroupNorm(32, in_channels) + + self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, x, t) -> torch.Tensor: + x_skip = x + + t = self.f_t(F.silu(t)) + t_1, t_2 = t.chunk(2, dim=1) + t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1 + t_2 = t_2.unsqueeze(2).unsqueeze(3) + + gn_1 = F.silu(self.gn_1(x)) + avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None) + + f_1 = self.f_1(avg_pool2d) + gn_2 = self.gn_2(f_1) + + f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2))) + + return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None) + + +# Also ConvResblock +class Upsample(nn.Module): + def __init__(self, in_channels=1024) -> None: + super().__init__() + self.f_t = nn.Linear(1280, in_channels * 2) + + self.gn_1 = nn.GroupNorm(32, in_channels) + self.f_1 = 
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + self.gn_2 = nn.GroupNorm(32, in_channels) + + self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, x, t) -> torch.Tensor: + x_skip = x + + t = self.f_t(F.silu(t)) + t_1, t_2 = t.chunk(2, dim=1) + t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1 + t_2 = t_2.unsqueeze(2).unsqueeze(3) + + gn_1 = F.silu(self.gn_1(x)) + upsample = F.upsample_nearest(gn_1, scale_factor=2) + f_1 = self.f_1(upsample) + gn_2 = self.gn_2(f_1) + + f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2))) + + return f_2 + F.upsample_nearest(x_skip, scale_factor=2) + + +class ConvUNetVAE(nn.Module): + def __init__(self) -> None: + super().__init__() + self.embed_image = ImageEmbedding() + self.embed_time = TimestepEmbedding_() + + down_0 = nn.ModuleList( + [ + ConvResblock(320, 320), + ConvResblock(320, 320), + ConvResblock(320, 320), + Downsample(320), + ] + ) + down_1 = nn.ModuleList( + [ + ConvResblock(320, 640), + ConvResblock(640, 640), + ConvResblock(640, 640), + Downsample(640), + ] + ) + down_2 = nn.ModuleList( + [ + ConvResblock(640, 1024), + ConvResblock(1024, 1024), + ConvResblock(1024, 1024), + Downsample(1024), + ] + ) + down_3 = nn.ModuleList( + [ + ConvResblock(1024, 1024), + ConvResblock(1024, 1024), + ConvResblock(1024, 1024), + ] + ) + self.down = nn.ModuleList( + [ + down_0, + down_1, + down_2, + down_3, + ] + ) + + self.mid = nn.ModuleList( + [ + ConvResblock(1024, 1024), + ConvResblock(1024, 1024), + ] + ) + + up_3 = nn.ModuleList( + [ + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 * 2, 1024), + Upsample(1024), + ] + ) + up_2 = nn.ModuleList( + [ + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 + 640, 1024), + Upsample(1024), + ] + ) + up_1 = nn.ModuleList( + [ + ConvResblock(1024 + 640, 640), + ConvResblock(640 * 2, 640), + ConvResblock(640 * 2, 640), + ConvResblock(320 + 640, 640), + Upsample(640), + ] + ) + up_0 = nn.ModuleList( + [ + ConvResblock(320 + 640, 320), + ConvResblock(320 * 2, 320), + ConvResblock(320 * 2, 320), + ConvResblock(320 * 2, 320), + ] + ) + self.up = nn.ModuleList( + [ + up_0, + up_1, + up_2, + up_3, + ] + ) + + self.output = ImageUnembedding() + + def forward(self, x, t, features) -> torch.Tensor: + converted = hasattr(self, "converted") and self.converted + + x = torch.cat([x, F.upsample_nearest(features, scale_factor=8)], dim=1) + + if converted: + t = self.time_embedding(self.time_proj(t)) + else: + t = self.embed_time(t) + + x = self.embed_image(x) + + skips = [x] + for i, down in enumerate(self.down): + if converted and i in [0, 1, 2, 3]: + x, skips_ = down(x, t) + for skip in skips_: + skips.append(skip) + else: + for block in down: + x = block(x, t) + skips.append(x) + print(x.float().abs().sum()) + + if converted: + x = self.mid(x, t) + else: + for i in range(2): + x = self.mid[i](x, t) + print(x.float().abs().sum()) + + for i, up in enumerate(self.up[::-1]): + if converted and i in [0, 1, 2, 3]: + skip_4 = skips.pop() + skip_3 = skips.pop() + skip_2 = skips.pop() + skip_1 = skips.pop() + skips_ = (skip_1, skip_2, skip_3, skip_4) + x = up(x, skips_, t) + else: + for block in up: + if isinstance(block, ConvResblock): + x = torch.concat([x, skips.pop()], dim=1) + x = block(x, t) + + return self.output(x) + + +def rename_state_dict_key(k): + k = k.replace("blocks.", "") + for i in range(5): + k = k.replace(f"down_{i}_", f"down.{i}.") + k = 
k.replace(f"conv_{i}.", f"{i}.") + k = k.replace(f"up_{i}_", f"up.{i}.") + k = k.replace(f"mid_{i}", f"mid.{i}") + k = k.replace("upsamp.", "4.") + k = k.replace("downsamp.", "3.") + k = k.replace("f_t.w", "f_t.weight").replace("f_t.b", "f_t.bias") + k = k.replace("f_1.w", "f_1.weight").replace("f_1.b", "f_1.bias") + k = k.replace("f_2.w", "f_2.weight").replace("f_2.b", "f_2.bias") + k = k.replace("f_s.w", "f_s.weight").replace("f_s.b", "f_s.bias") + k = k.replace("f.w", "f.weight").replace("f.b", "f.bias") + k = k.replace("gn_1.g", "gn_1.weight").replace("gn_1.b", "gn_1.bias") + k = k.replace("gn_2.g", "gn_2.weight").replace("gn_2.b", "gn_2.bias") + k = k.replace("gn.g", "gn.weight").replace("gn.b", "gn.bias") + return k + + +def rename_state_dict(sd, embedding): + sd = {rename_state_dict_key(k): v for k, v in sd.items()} + sd["embed_time.emb.weight"] = embedding["weight"] + return sd + + +# encode with stable diffusion vae +pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +pipe.vae.cuda() + +# construct original decoder with jitted model +decoder_consistency = ConsistencyDecoder(device="cuda:0") + +# construct UNet code, overwrite the decoder with conv_unet_vae +model = ConvUNetVAE() +model.load_state_dict( + rename_state_dict( + stl("consistency_decoder.safetensors"), + stl("embedding.safetensors"), + ) +) +model = model.cuda() + +decoder_consistency.ckpt = model + +image = load_image(args.test_image, size=(256, 256), center_crop=True) +latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample() + +# decode with gan +sample_gan = pipe.vae.decode(latent).sample.detach() +save_image(sample_gan, "gan.png") + +# decode with conv_unet_vae +sample_consistency_orig = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0)) +save_image(sample_consistency_orig, "con_orig.png") + + +########### conversion + +print("CONVERSION") + +print("DOWN BLOCK ONE") + +block_one_sd_orig = model.down[0].state_dict() +block_one_sd_new = {} + +for i in range(3): + block_one_sd_new[f"resnets.{i}.norm1.weight"] = block_one_sd_orig.pop(f"{i}.gn_1.weight") + block_one_sd_new[f"resnets.{i}.norm1.bias"] = block_one_sd_orig.pop(f"{i}.gn_1.bias") + block_one_sd_new[f"resnets.{i}.conv1.weight"] = block_one_sd_orig.pop(f"{i}.f_1.weight") + block_one_sd_new[f"resnets.{i}.conv1.bias"] = block_one_sd_orig.pop(f"{i}.f_1.bias") + block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_one_sd_orig.pop(f"{i}.f_t.weight") + block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_one_sd_orig.pop(f"{i}.f_t.bias") + block_one_sd_new[f"resnets.{i}.norm2.weight"] = block_one_sd_orig.pop(f"{i}.gn_2.weight") + block_one_sd_new[f"resnets.{i}.norm2.bias"] = block_one_sd_orig.pop(f"{i}.gn_2.bias") + block_one_sd_new[f"resnets.{i}.conv2.weight"] = block_one_sd_orig.pop(f"{i}.f_2.weight") + block_one_sd_new[f"resnets.{i}.conv2.bias"] = block_one_sd_orig.pop(f"{i}.f_2.bias") + +block_one_sd_new["downsamplers.0.norm1.weight"] = block_one_sd_orig.pop("3.gn_1.weight") +block_one_sd_new["downsamplers.0.norm1.bias"] = block_one_sd_orig.pop("3.gn_1.bias") +block_one_sd_new["downsamplers.0.conv1.weight"] = block_one_sd_orig.pop("3.f_1.weight") +block_one_sd_new["downsamplers.0.conv1.bias"] = block_one_sd_orig.pop("3.f_1.bias") +block_one_sd_new["downsamplers.0.time_emb_proj.weight"] = block_one_sd_orig.pop("3.f_t.weight") +block_one_sd_new["downsamplers.0.time_emb_proj.bias"] = block_one_sd_orig.pop("3.f_t.bias") 
+block_one_sd_new["downsamplers.0.norm2.weight"] = block_one_sd_orig.pop("3.gn_2.weight") +block_one_sd_new["downsamplers.0.norm2.bias"] = block_one_sd_orig.pop("3.gn_2.bias") +block_one_sd_new["downsamplers.0.conv2.weight"] = block_one_sd_orig.pop("3.f_2.weight") +block_one_sd_new["downsamplers.0.conv2.bias"] = block_one_sd_orig.pop("3.f_2.bias") + +assert len(block_one_sd_orig) == 0 + +block_one = ResnetDownsampleBlock2D( + in_channels=320, + out_channels=320, + temb_channels=1280, + num_layers=3, + add_downsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +block_one.load_state_dict(block_one_sd_new) + +print("DOWN BLOCK TWO") + +block_two_sd_orig = model.down[1].state_dict() +block_two_sd_new = {} + +for i in range(3): + block_two_sd_new[f"resnets.{i}.norm1.weight"] = block_two_sd_orig.pop(f"{i}.gn_1.weight") + block_two_sd_new[f"resnets.{i}.norm1.bias"] = block_two_sd_orig.pop(f"{i}.gn_1.bias") + block_two_sd_new[f"resnets.{i}.conv1.weight"] = block_two_sd_orig.pop(f"{i}.f_1.weight") + block_two_sd_new[f"resnets.{i}.conv1.bias"] = block_two_sd_orig.pop(f"{i}.f_1.bias") + block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_two_sd_orig.pop(f"{i}.f_t.weight") + block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_two_sd_orig.pop(f"{i}.f_t.bias") + block_two_sd_new[f"resnets.{i}.norm2.weight"] = block_two_sd_orig.pop(f"{i}.gn_2.weight") + block_two_sd_new[f"resnets.{i}.norm2.bias"] = block_two_sd_orig.pop(f"{i}.gn_2.bias") + block_two_sd_new[f"resnets.{i}.conv2.weight"] = block_two_sd_orig.pop(f"{i}.f_2.weight") + block_two_sd_new[f"resnets.{i}.conv2.bias"] = block_two_sd_orig.pop(f"{i}.f_2.bias") + + if i == 0: + block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_two_sd_orig.pop(f"{i}.f_s.weight") + block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_two_sd_orig.pop(f"{i}.f_s.bias") + +block_two_sd_new["downsamplers.0.norm1.weight"] = block_two_sd_orig.pop("3.gn_1.weight") +block_two_sd_new["downsamplers.0.norm1.bias"] = block_two_sd_orig.pop("3.gn_1.bias") +block_two_sd_new["downsamplers.0.conv1.weight"] = block_two_sd_orig.pop("3.f_1.weight") +block_two_sd_new["downsamplers.0.conv1.bias"] = block_two_sd_orig.pop("3.f_1.bias") +block_two_sd_new["downsamplers.0.time_emb_proj.weight"] = block_two_sd_orig.pop("3.f_t.weight") +block_two_sd_new["downsamplers.0.time_emb_proj.bias"] = block_two_sd_orig.pop("3.f_t.bias") +block_two_sd_new["downsamplers.0.norm2.weight"] = block_two_sd_orig.pop("3.gn_2.weight") +block_two_sd_new["downsamplers.0.norm2.bias"] = block_two_sd_orig.pop("3.gn_2.bias") +block_two_sd_new["downsamplers.0.conv2.weight"] = block_two_sd_orig.pop("3.f_2.weight") +block_two_sd_new["downsamplers.0.conv2.bias"] = block_two_sd_orig.pop("3.f_2.bias") + +assert len(block_two_sd_orig) == 0 + +block_two = ResnetDownsampleBlock2D( + in_channels=320, + out_channels=640, + temb_channels=1280, + num_layers=3, + add_downsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +block_two.load_state_dict(block_two_sd_new) + +print("DOWN BLOCK THREE") + +block_three_sd_orig = model.down[2].state_dict() +block_three_sd_new = {} + +for i in range(3): + block_three_sd_new[f"resnets.{i}.norm1.weight"] = block_three_sd_orig.pop(f"{i}.gn_1.weight") + block_three_sd_new[f"resnets.{i}.norm1.bias"] = block_three_sd_orig.pop(f"{i}.gn_1.bias") + block_three_sd_new[f"resnets.{i}.conv1.weight"] = block_three_sd_orig.pop(f"{i}.f_1.weight") + block_three_sd_new[f"resnets.{i}.conv1.bias"] = 
block_three_sd_orig.pop(f"{i}.f_1.bias") + block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_three_sd_orig.pop(f"{i}.f_t.weight") + block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_three_sd_orig.pop(f"{i}.f_t.bias") + block_three_sd_new[f"resnets.{i}.norm2.weight"] = block_three_sd_orig.pop(f"{i}.gn_2.weight") + block_three_sd_new[f"resnets.{i}.norm2.bias"] = block_three_sd_orig.pop(f"{i}.gn_2.bias") + block_three_sd_new[f"resnets.{i}.conv2.weight"] = block_three_sd_orig.pop(f"{i}.f_2.weight") + block_three_sd_new[f"resnets.{i}.conv2.bias"] = block_three_sd_orig.pop(f"{i}.f_2.bias") + + if i == 0: + block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_three_sd_orig.pop(f"{i}.f_s.weight") + block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_three_sd_orig.pop(f"{i}.f_s.bias") + +block_three_sd_new["downsamplers.0.norm1.weight"] = block_three_sd_orig.pop("3.gn_1.weight") +block_three_sd_new["downsamplers.0.norm1.bias"] = block_three_sd_orig.pop("3.gn_1.bias") +block_three_sd_new["downsamplers.0.conv1.weight"] = block_three_sd_orig.pop("3.f_1.weight") +block_three_sd_new["downsamplers.0.conv1.bias"] = block_three_sd_orig.pop("3.f_1.bias") +block_three_sd_new["downsamplers.0.time_emb_proj.weight"] = block_three_sd_orig.pop("3.f_t.weight") +block_three_sd_new["downsamplers.0.time_emb_proj.bias"] = block_three_sd_orig.pop("3.f_t.bias") +block_three_sd_new["downsamplers.0.norm2.weight"] = block_three_sd_orig.pop("3.gn_2.weight") +block_three_sd_new["downsamplers.0.norm2.bias"] = block_three_sd_orig.pop("3.gn_2.bias") +block_three_sd_new["downsamplers.0.conv2.weight"] = block_three_sd_orig.pop("3.f_2.weight") +block_three_sd_new["downsamplers.0.conv2.bias"] = block_three_sd_orig.pop("3.f_2.bias") + +assert len(block_three_sd_orig) == 0 + +block_three = ResnetDownsampleBlock2D( + in_channels=640, + out_channels=1024, + temb_channels=1280, + num_layers=3, + add_downsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +block_three.load_state_dict(block_three_sd_new) + +print("DOWN BLOCK FOUR") + +block_four_sd_orig = model.down[3].state_dict() +block_four_sd_new = {} + +for i in range(3): + block_four_sd_new[f"resnets.{i}.norm1.weight"] = block_four_sd_orig.pop(f"{i}.gn_1.weight") + block_four_sd_new[f"resnets.{i}.norm1.bias"] = block_four_sd_orig.pop(f"{i}.gn_1.bias") + block_four_sd_new[f"resnets.{i}.conv1.weight"] = block_four_sd_orig.pop(f"{i}.f_1.weight") + block_four_sd_new[f"resnets.{i}.conv1.bias"] = block_four_sd_orig.pop(f"{i}.f_1.bias") + block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_four_sd_orig.pop(f"{i}.f_t.weight") + block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_four_sd_orig.pop(f"{i}.f_t.bias") + block_four_sd_new[f"resnets.{i}.norm2.weight"] = block_four_sd_orig.pop(f"{i}.gn_2.weight") + block_four_sd_new[f"resnets.{i}.norm2.bias"] = block_four_sd_orig.pop(f"{i}.gn_2.bias") + block_four_sd_new[f"resnets.{i}.conv2.weight"] = block_four_sd_orig.pop(f"{i}.f_2.weight") + block_four_sd_new[f"resnets.{i}.conv2.bias"] = block_four_sd_orig.pop(f"{i}.f_2.bias") + +assert len(block_four_sd_orig) == 0 + +block_four = ResnetDownsampleBlock2D( + in_channels=1024, + out_channels=1024, + temb_channels=1280, + num_layers=3, + add_downsample=False, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +block_four.load_state_dict(block_four_sd_new) + + +print("MID BLOCK 1") + +mid_block_one_sd_orig = model.mid.state_dict() +mid_block_one_sd_new = {} + +for i in range(2): + 
mid_block_one_sd_new[f"resnets.{i}.norm1.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.weight") + mid_block_one_sd_new[f"resnets.{i}.norm1.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.bias") + mid_block_one_sd_new[f"resnets.{i}.conv1.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_1.weight") + mid_block_one_sd_new[f"resnets.{i}.conv1.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_1.bias") + mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_t.weight") + mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_t.bias") + mid_block_one_sd_new[f"resnets.{i}.norm2.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.weight") + mid_block_one_sd_new[f"resnets.{i}.norm2.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.bias") + mid_block_one_sd_new[f"resnets.{i}.conv2.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_2.weight") + mid_block_one_sd_new[f"resnets.{i}.conv2.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_2.bias") + +assert len(mid_block_one_sd_orig) == 0 + +mid_block_one = UNetMidBlock2D( + in_channels=1024, + temb_channels=1280, + num_layers=1, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, + add_attention=False, +) + +mid_block_one.load_state_dict(mid_block_one_sd_new) + +print("UP BLOCK ONE") + +up_block_one_sd_orig = model.up[-1].state_dict() +up_block_one_sd_new = {} + +for i in range(4): + up_block_one_sd_new[f"resnets.{i}.norm1.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_1.weight") + up_block_one_sd_new[f"resnets.{i}.norm1.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_1.bias") + up_block_one_sd_new[f"resnets.{i}.conv1.weight"] = up_block_one_sd_orig.pop(f"{i}.f_1.weight") + up_block_one_sd_new[f"resnets.{i}.conv1.bias"] = up_block_one_sd_orig.pop(f"{i}.f_1.bias") + up_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_one_sd_orig.pop(f"{i}.f_t.weight") + up_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_one_sd_orig.pop(f"{i}.f_t.bias") + up_block_one_sd_new[f"resnets.{i}.norm2.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_2.weight") + up_block_one_sd_new[f"resnets.{i}.norm2.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_2.bias") + up_block_one_sd_new[f"resnets.{i}.conv2.weight"] = up_block_one_sd_orig.pop(f"{i}.f_2.weight") + up_block_one_sd_new[f"resnets.{i}.conv2.bias"] = up_block_one_sd_orig.pop(f"{i}.f_2.bias") + up_block_one_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_one_sd_orig.pop(f"{i}.f_s.weight") + up_block_one_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_one_sd_orig.pop(f"{i}.f_s.bias") + +up_block_one_sd_new["upsamplers.0.norm1.weight"] = up_block_one_sd_orig.pop("4.gn_1.weight") +up_block_one_sd_new["upsamplers.0.norm1.bias"] = up_block_one_sd_orig.pop("4.gn_1.bias") +up_block_one_sd_new["upsamplers.0.conv1.weight"] = up_block_one_sd_orig.pop("4.f_1.weight") +up_block_one_sd_new["upsamplers.0.conv1.bias"] = up_block_one_sd_orig.pop("4.f_1.bias") +up_block_one_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_one_sd_orig.pop("4.f_t.weight") +up_block_one_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_one_sd_orig.pop("4.f_t.bias") +up_block_one_sd_new["upsamplers.0.norm2.weight"] = up_block_one_sd_orig.pop("4.gn_2.weight") +up_block_one_sd_new["upsamplers.0.norm2.bias"] = up_block_one_sd_orig.pop("4.gn_2.bias") +up_block_one_sd_new["upsamplers.0.conv2.weight"] = up_block_one_sd_orig.pop("4.f_2.weight") +up_block_one_sd_new["upsamplers.0.conv2.bias"] = up_block_one_sd_orig.pop("4.f_2.bias") + +assert len(up_block_one_sd_orig) == 0 + +up_block_one 
= ResnetUpsampleBlock2D( + in_channels=1024, + prev_output_channel=1024, + out_channels=1024, + temb_channels=1280, + num_layers=4, + add_upsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +up_block_one.load_state_dict(up_block_one_sd_new) + +print("UP BLOCK TWO") + +up_block_two_sd_orig = model.up[-2].state_dict() +up_block_two_sd_new = {} + +for i in range(4): + up_block_two_sd_new[f"resnets.{i}.norm1.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_1.weight") + up_block_two_sd_new[f"resnets.{i}.norm1.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_1.bias") + up_block_two_sd_new[f"resnets.{i}.conv1.weight"] = up_block_two_sd_orig.pop(f"{i}.f_1.weight") + up_block_two_sd_new[f"resnets.{i}.conv1.bias"] = up_block_two_sd_orig.pop(f"{i}.f_1.bias") + up_block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_two_sd_orig.pop(f"{i}.f_t.weight") + up_block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_two_sd_orig.pop(f"{i}.f_t.bias") + up_block_two_sd_new[f"resnets.{i}.norm2.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_2.weight") + up_block_two_sd_new[f"resnets.{i}.norm2.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_2.bias") + up_block_two_sd_new[f"resnets.{i}.conv2.weight"] = up_block_two_sd_orig.pop(f"{i}.f_2.weight") + up_block_two_sd_new[f"resnets.{i}.conv2.bias"] = up_block_two_sd_orig.pop(f"{i}.f_2.bias") + up_block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_two_sd_orig.pop(f"{i}.f_s.weight") + up_block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_two_sd_orig.pop(f"{i}.f_s.bias") + +up_block_two_sd_new["upsamplers.0.norm1.weight"] = up_block_two_sd_orig.pop("4.gn_1.weight") +up_block_two_sd_new["upsamplers.0.norm1.bias"] = up_block_two_sd_orig.pop("4.gn_1.bias") +up_block_two_sd_new["upsamplers.0.conv1.weight"] = up_block_two_sd_orig.pop("4.f_1.weight") +up_block_two_sd_new["upsamplers.0.conv1.bias"] = up_block_two_sd_orig.pop("4.f_1.bias") +up_block_two_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_two_sd_orig.pop("4.f_t.weight") +up_block_two_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_two_sd_orig.pop("4.f_t.bias") +up_block_two_sd_new["upsamplers.0.norm2.weight"] = up_block_two_sd_orig.pop("4.gn_2.weight") +up_block_two_sd_new["upsamplers.0.norm2.bias"] = up_block_two_sd_orig.pop("4.gn_2.bias") +up_block_two_sd_new["upsamplers.0.conv2.weight"] = up_block_two_sd_orig.pop("4.f_2.weight") +up_block_two_sd_new["upsamplers.0.conv2.bias"] = up_block_two_sd_orig.pop("4.f_2.bias") + +assert len(up_block_two_sd_orig) == 0 + +up_block_two = ResnetUpsampleBlock2D( + in_channels=640, + prev_output_channel=1024, + out_channels=1024, + temb_channels=1280, + num_layers=4, + add_upsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +up_block_two.load_state_dict(up_block_two_sd_new) + +print("UP BLOCK THREE") + +up_block_three_sd_orig = model.up[-3].state_dict() +up_block_three_sd_new = {} + +for i in range(4): + up_block_three_sd_new[f"resnets.{i}.norm1.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_1.weight") + up_block_three_sd_new[f"resnets.{i}.norm1.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_1.bias") + up_block_three_sd_new[f"resnets.{i}.conv1.weight"] = up_block_three_sd_orig.pop(f"{i}.f_1.weight") + up_block_three_sd_new[f"resnets.{i}.conv1.bias"] = up_block_three_sd_orig.pop(f"{i}.f_1.bias") + up_block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_three_sd_orig.pop(f"{i}.f_t.weight") + up_block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = 
up_block_three_sd_orig.pop(f"{i}.f_t.bias") + up_block_three_sd_new[f"resnets.{i}.norm2.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_2.weight") + up_block_three_sd_new[f"resnets.{i}.norm2.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_2.bias") + up_block_three_sd_new[f"resnets.{i}.conv2.weight"] = up_block_three_sd_orig.pop(f"{i}.f_2.weight") + up_block_three_sd_new[f"resnets.{i}.conv2.bias"] = up_block_three_sd_orig.pop(f"{i}.f_2.bias") + up_block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_three_sd_orig.pop(f"{i}.f_s.weight") + up_block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_three_sd_orig.pop(f"{i}.f_s.bias") + +up_block_three_sd_new["upsamplers.0.norm1.weight"] = up_block_three_sd_orig.pop("4.gn_1.weight") +up_block_three_sd_new["upsamplers.0.norm1.bias"] = up_block_three_sd_orig.pop("4.gn_1.bias") +up_block_three_sd_new["upsamplers.0.conv1.weight"] = up_block_three_sd_orig.pop("4.f_1.weight") +up_block_three_sd_new["upsamplers.0.conv1.bias"] = up_block_three_sd_orig.pop("4.f_1.bias") +up_block_three_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_three_sd_orig.pop("4.f_t.weight") +up_block_three_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_three_sd_orig.pop("4.f_t.bias") +up_block_three_sd_new["upsamplers.0.norm2.weight"] = up_block_three_sd_orig.pop("4.gn_2.weight") +up_block_three_sd_new["upsamplers.0.norm2.bias"] = up_block_three_sd_orig.pop("4.gn_2.bias") +up_block_three_sd_new["upsamplers.0.conv2.weight"] = up_block_three_sd_orig.pop("4.f_2.weight") +up_block_three_sd_new["upsamplers.0.conv2.bias"] = up_block_three_sd_orig.pop("4.f_2.bias") + +assert len(up_block_three_sd_orig) == 0 + +up_block_three = ResnetUpsampleBlock2D( + in_channels=320, + prev_output_channel=1024, + out_channels=640, + temb_channels=1280, + num_layers=4, + add_upsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +up_block_three.load_state_dict(up_block_three_sd_new) + +print("UP BLOCK FOUR") + +up_block_four_sd_orig = model.up[-4].state_dict() +up_block_four_sd_new = {} + +for i in range(4): + up_block_four_sd_new[f"resnets.{i}.norm1.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_1.weight") + up_block_four_sd_new[f"resnets.{i}.norm1.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_1.bias") + up_block_four_sd_new[f"resnets.{i}.conv1.weight"] = up_block_four_sd_orig.pop(f"{i}.f_1.weight") + up_block_four_sd_new[f"resnets.{i}.conv1.bias"] = up_block_four_sd_orig.pop(f"{i}.f_1.bias") + up_block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_four_sd_orig.pop(f"{i}.f_t.weight") + up_block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_four_sd_orig.pop(f"{i}.f_t.bias") + up_block_four_sd_new[f"resnets.{i}.norm2.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_2.weight") + up_block_four_sd_new[f"resnets.{i}.norm2.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_2.bias") + up_block_four_sd_new[f"resnets.{i}.conv2.weight"] = up_block_four_sd_orig.pop(f"{i}.f_2.weight") + up_block_four_sd_new[f"resnets.{i}.conv2.bias"] = up_block_four_sd_orig.pop(f"{i}.f_2.bias") + up_block_four_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_four_sd_orig.pop(f"{i}.f_s.weight") + up_block_four_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_four_sd_orig.pop(f"{i}.f_s.bias") + +assert len(up_block_four_sd_orig) == 0 + +up_block_four = ResnetUpsampleBlock2D( + in_channels=320, + prev_output_channel=640, + out_channels=320, + temb_channels=1280, + num_layers=4, + add_upsample=False, + resnet_time_scale_shift="scale_shift", + 
resnet_eps=1e-5, +) + +up_block_four.load_state_dict(up_block_four_sd_new) + +print("initial projection (conv_in)") + +conv_in_sd_orig = model.embed_image.state_dict() +conv_in_sd_new = {} + +conv_in_sd_new["weight"] = conv_in_sd_orig.pop("f.weight") +conv_in_sd_new["bias"] = conv_in_sd_orig.pop("f.bias") + +assert len(conv_in_sd_orig) == 0 + +block_out_channels = [320, 640, 1024, 1024] + +in_channels = 7 +conv_in_kernel = 3 +conv_in_padding = (conv_in_kernel - 1) // 2 +conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding) + +conv_in.load_state_dict(conv_in_sd_new) + +print("out projection (conv_out) (conv_norm_out)") +out_channels = 6 +norm_num_groups = 32 +norm_eps = 1e-5 +act_fn = "silu" +conv_out_kernel = 3 +conv_out_padding = (conv_out_kernel - 1) // 2 +conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) +# uses torch.functional in orig +# conv_act = get_activation(act_fn) +conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding) + +conv_norm_out.load_state_dict(model.output.gn.state_dict()) +conv_out.load_state_dict(model.output.f.state_dict()) + +print("timestep projection (time_proj) (time_embedding)") + +f1_sd = model.embed_time.f_1.state_dict() +f2_sd = model.embed_time.f_2.state_dict() + +time_embedding_sd = { + "linear_1.weight": f1_sd.pop("weight"), + "linear_1.bias": f1_sd.pop("bias"), + "linear_2.weight": f2_sd.pop("weight"), + "linear_2.bias": f2_sd.pop("bias"), +} + +assert len(f1_sd) == 0 +assert len(f2_sd) == 0 + +time_embedding_type = "learned" +num_train_timesteps = 1024 +time_embedding_dim = 1280 + +time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0]) +timestep_input_dim = block_out_channels[0] + +time_embedding = TimestepEmbedding(timestep_input_dim, time_embedding_dim) + +time_proj.load_state_dict(model.embed_time.emb.state_dict()) +time_embedding.load_state_dict(time_embedding_sd) + +print("CONVERT") + +time_embedding.to("cuda") +time_proj.to("cuda") +conv_in.to("cuda") + +block_one.to("cuda") +block_two.to("cuda") +block_three.to("cuda") +block_four.to("cuda") + +mid_block_one.to("cuda") + +up_block_one.to("cuda") +up_block_two.to("cuda") +up_block_three.to("cuda") +up_block_four.to("cuda") + +conv_norm_out.to("cuda") +conv_out.to("cuda") + +model.time_proj = time_proj +model.time_embedding = time_embedding +model.embed_image = conv_in + +model.down[0] = block_one +model.down[1] = block_two +model.down[2] = block_three +model.down[3] = block_four + +model.mid = mid_block_one + +model.up[-1] = up_block_one +model.up[-2] = up_block_two +model.up[-3] = up_block_three +model.up[-4] = up_block_four + +model.output.gn = conv_norm_out +model.output.f = conv_out + +model.converted = True + +sample_consistency_new = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0)) +save_image(sample_consistency_new, "con_new.png") + +assert (sample_consistency_orig == sample_consistency_new).all() + +print("making unet") + +unet = UNet2DModel( + in_channels=in_channels, + out_channels=out_channels, + down_block_types=( + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ), + up_block_types=( + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ), + block_out_channels=block_out_channels, + layers_per_block=3, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, 
+ resnet_time_scale_shift="scale_shift", + time_embedding_type="learned", + num_train_timesteps=num_train_timesteps, + add_attention=False, +) + +unet_state_dict = {} + + +def add_state_dict(prefix, mod): + for k, v in mod.state_dict().items(): + unet_state_dict[f"{prefix}.{k}"] = v + + +add_state_dict("conv_in", conv_in) +add_state_dict("time_proj", time_proj) +add_state_dict("time_embedding", time_embedding) +add_state_dict("down_blocks.0", block_one) +add_state_dict("down_blocks.1", block_two) +add_state_dict("down_blocks.2", block_three) +add_state_dict("down_blocks.3", block_four) +add_state_dict("mid_block", mid_block_one) +add_state_dict("up_blocks.0", up_block_one) +add_state_dict("up_blocks.1", up_block_two) +add_state_dict("up_blocks.2", up_block_three) +add_state_dict("up_blocks.3", up_block_four) +add_state_dict("conv_norm_out", conv_norm_out) +add_state_dict("conv_out", conv_out) + +unet.load_state_dict(unet_state_dict) + +print("running with diffusers unet") + +unet.to("cuda") + +decoder_consistency.ckpt = unet + +sample_consistency_new_2 = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0)) +save_image(sample_consistency_new_2, "con_new_2.png") + +assert (sample_consistency_orig == sample_consistency_new_2).all() + +print("running with diffusers model") + +Encoder.old_constructor = Encoder.__init__ + + +def new_constructor(self, **kwargs): + self.old_constructor(**kwargs) + self.constructor_arguments = kwargs + + +Encoder.__init__ = new_constructor + + +vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae") +consistency_vae = ConsistencyDecoderVAE( + encoder_args=vae.encoder.constructor_arguments, + decoder_args=unet.config, + scaling_factor=vae.config.scaling_factor, + block_out_channels=vae.config.block_out_channels, + latent_channels=vae.config.latent_channels, +) +consistency_vae.encoder.load_state_dict(vae.encoder.state_dict()) +consistency_vae.quant_conv.load_state_dict(vae.quant_conv.state_dict()) +consistency_vae.decoder_unet.load_state_dict(unet.state_dict()) + +consistency_vae.to(dtype=torch.float16, device="cuda") + +sample_consistency_new_3 = consistency_vae.decode( + 0.18215 * latent, generator=torch.Generator("cpu").manual_seed(0) +).sample + +print("max difference") +print((sample_consistency_orig - sample_consistency_new_3).abs().max()) +print("total difference") +print((sample_consistency_orig - sample_consistency_new_3).abs().sum()) +# assert (sample_consistency_orig == sample_consistency_new_3).all() + +print("running with diffusers pipeline") + +pipe = DiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16 +) +pipe.to("cuda") + +pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("horse.png") + + +if args.save_pretrained is not None: + consistency_vae.save_pretrained(args.save_pretrained) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4291e911ac74..b749355fe106 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -77,6 +77,7 @@ "AsymmetricAutoencoderKL", "AutoencoderKL", "AutoencoderTiny", + "ConsistencyDecoderVAE", "ControlNetModel", "ModelMixin", "MotionAdapter", @@ -443,6 +444,7 @@ AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny, + ConsistencyDecoderVAE, ControlNetModel, ModelMixin, MotionAdapter, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f807353312d1..d45f56d43c32 100644 --- a/src/diffusers/models/__init__.py 
+++ b/src/diffusers/models/__init__.py @@ -24,6 +24,7 @@ _import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoder_tiny"] = ["AutoencoderTiny"] + _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["controlnet"] = ["ControlNetModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["modeling_utils"] = ["ModelMixin"] @@ -50,6 +51,7 @@ from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL from .autoencoder_tiny import AutoencoderTiny + from .consistency_decoder_vae import ConsistencyDecoderVAE from .controlnet import ControlNetModel from .dual_transformer_2d import DualTransformer2DModel from .modeling_utils import ModelMixin diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 80d2cccd536d..ac616530a66a 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -294,7 +294,9 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decod return DecoderOutput(sample=dec) @apply_forward_hook - def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: """ Decode a batch of images. diff --git a/src/diffusers/models/consistency_decoder_vae.py b/src/diffusers/models/consistency_decoder_vae.py new file mode 100644 index 000000000000..a2bcf1b8b78e --- /dev/null +++ b/src/diffusers/models/consistency_decoder_vae.py @@ -0,0 +1,365 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..schedulers import ConsistencyDecoderScheduler +from ..utils import BaseOutput +from ..utils.accelerate_utils import apply_forward_hook +from ..utils.torch_utils import randn_tensor +from .attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from .modeling_utils import ModelMixin +from .unet_2d import UNet2DModel +from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder + + +@dataclass +class ConsistencyDecoderVAEOutput(BaseOutput): + """ + Output of encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. 
+ """ + + latent_dist: "DiagonalGaussianDistribution" + + +class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): + r""" + The consistency decoder used with DALL-E 3. + + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline, ConsistencyDecoderVAE + + >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=pipe.torch_dtype) + >>> pipe = StableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16 + ... ).to("cuda") + + >>> pipe("horse", generator=torch.manual_seed(0)).images + ``` + """ + + @register_to_config + def __init__(self, encoder_args, decoder_args, scaling_factor, block_out_channels, latent_channels): + super().__init__() + self.encoder = Encoder(**encoder_args) + self.decoder_unet = UNet2DModel(**decoder_args) + self.decoder_scheduler = ConsistencyDecoderScheduler() + self.register_buffer( + "means", + torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None], + persistent=False, + ) + self.register_buffer( + "stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + + self.use_slicing = False + self.use_tiling = False + + # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain + tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple` + is returned. 
+ """ + if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.tiled_encode(x, return_dict=return_dict) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return ConsistencyDecoderVAEOutput(latent_dist=posterior) + + @apply_forward_hook + def decode( + self, + z: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + num_inference_steps=2, + ) -> Union[DecoderOutput, torch.FloatTensor]: + z = (z * self.config.scaling_factor - self.means) / self.stds + + scale_factor = 2 ** (len(self.config.block_out_channels) - 1) + z = F.interpolate(z, mode="nearest", scale_factor=scale_factor) + + batch_size, _, height, width = z.shape + + self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device) + + x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor( + (batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device + ) + + for t in self.decoder_scheduler.timesteps: + model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1) + model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :] + prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample + x_t = prev_sample + + x_0 = x_t + + if not return_dict: + return (x_0,) + + return DecoderOutput(sample=x_0) + + # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v + def blend_v(self, a, b, blend_extent): + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h + def blend_h(self, a, b, blend_extent): + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a + plain tuple. + + Returns: + [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`: + If return_dict is True, a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, + otherwise a plain `tuple` is returned. 
+ """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return ConsistencyDecoderVAEOutput(latent_dist=posterior) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, generator=generator).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 38e26422e2a7..0531d8aae783 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -117,6 +117,7 @@ def __init__( add_attention: bool = True, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, + num_train_timesteps: Optional[int] = None, ): super().__init__() @@ -144,6 +145,9 @@ def __init__( elif time_embedding_type == "positional": self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] + elif time_embedding_type == "learned": + self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0]) + timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index a73dc22a146c..a9f0de22f63f 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -852,7 +852,9 @@ def __call__( callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] image, has_nsfw_concept = 
self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index ea4a3128dee3..74d5a15a98d3 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -893,7 +893,9 @@ def __call__( callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/consistency_models/__init__.py b/src/diffusers/pipelines/consistency_models/__init__.py index 053a3666263f..162d91c010ac 100644 --- a/src/diffusers/pipelines/consistency_models/__init__.py +++ b/src/diffusers/pipelines/consistency_models/__init__.py @@ -6,7 +6,9 @@ ) -_import_structure = {"pipeline_consistency_models": ["ConsistencyModelPipeline"]} +_import_structure = { + "pipeline_consistency_models": ["ConsistencyModelPipeline"], +} if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .pipeline_consistency_models import ConsistencyModelPipeline diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index de1b1fd93c7f..6465250a762a 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -1,3 +1,17 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from typing import Callable, List, Optional, Union import torch diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 6944d9331253..04ca51b19f05 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -1058,7 +1058,9 @@ def __call__( torch.cuda.empty_cache() if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index b692d936c5b7..8683d09e118d 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -1138,7 +1138,9 @@ def __call__( torch.cuda.empty_cache() if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 3e0b07bf1694..399cfdcf9c2c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -1405,7 +1405,9 @@ def __call__( torch.cuda.empty_cache() if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f08100b12358..2a79927de69c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -838,7 +838,9 @@ def __call__( callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 583e6046b2e1..a01ace26b3f0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -885,7 +885,9 @@ def __call__( callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, 
return_dict=False)[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index ca7d62fd5077..bd3ccce70556 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1159,7 +1159,9 @@ def __call__( init_image = self._encode_vae_image(init_image, generator=generator) mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) condition_kwargs = {"image": init_image_condition, "mask": mask_condition} - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0] + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs + )[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 85fd9d25e5da..5e5102e589d4 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -38,6 +38,7 @@ _dummy_modules.update(get_objects_from_module(dummy_pt_objects)) else: + _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] _import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"] @@ -128,6 +129,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler diff --git a/src/diffusers/schedulers/scheduling_consistency_decoder.py b/src/diffusers/schedulers/scheduling_consistency_decoder.py new file mode 100644 index 000000000000..69ca8a1737ec --- /dev/null +++ b/src/diffusers/schedulers/scheduling_consistency_decoder.py @@ -0,0 +1,180 @@ +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +@dataclass +class ConsistencyDecoderSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin): + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1024, + sigma_data: float = 0.5, + ): + betas = betas_for_alpha_bar(num_train_timesteps) + + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + + sigmas = torch.sqrt(1.0 / alphas_cumprod - 1) + + sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod) + + self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2) + self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5 + self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5 + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + ): + if num_inference_steps != 2: + raise ValueError("Currently more than 2 inference steps are not supported.") + + self.timesteps = torch.tensor([1008, 512], dtype=torch.long, device=device) + self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device) + self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device) + self.c_skip = self.c_skip.to(device) + self.c_out = self.c_out.to(device) + self.c_in = self.c_in.to(device) + + @property + def init_noise_sigma(self): + return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]] + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample * self.c_in[timestep] + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + timestep (`float`): + The current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a + [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise + a tuple is returned where the first element is the sample tensor. + """ + x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample + + timestep_idx = torch.where(self.timesteps == timestep)[0] + + if timestep_idx == len(self.timesteps) - 1: + prev_sample = x_0 + else: + noise = randn_tensor(x_0.shape, generator=generator, dtype=x_0.dtype, device=x_0.device) + prev_sample = ( + self.sqrt_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * x_0 + + self.sqrt_one_minus_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * noise + ) + + if not return_dict: + return (prev_sample,) + + return ConsistencyDecoderSchedulerOutput(prev_sample=prev_sample) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index d6d74a89cafb..090b1081fdaf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -47,6 +47,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ConsistencyDecoderVAE(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ControlNetModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index e4ecb59121f7..961147839461 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -196,11 +196,15 @@ def test_forward_with_norm_groups(self): class ModelTesterMixin: main_input_name = None # overwrite in model specific tester class base_precision = 1e-3 + forward_requires_fresh_args = False def test_from_save_pretrained(self, expected_max_diff=5e-5): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + if self.forward_requires_fresh_args: + model = self.model_class(**self.init_dict) + else: + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) - model = self.model_class(**init_dict) if hasattr(model, "set_default_attn_processor"): model.set_default_attn_processor() model.to(torch_device) @@ -214,11 +218,18 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5): new_model.to(torch_device) with torch.no_grad(): - image = model(**inputs_dict) + if self.forward_requires_fresh_args: + image = model(**self.inputs_dict(0)) + else: + image = model(**inputs_dict) + if isinstance(image, dict): image = 
image.to_tuple()[0] - new_image = new_model(**inputs_dict) + if self.forward_requires_fresh_args: + new_image = new_model(**self.inputs_dict(0)) + else: + new_image = new_model(**inputs_dict) if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] @@ -275,8 +286,11 @@ def test_getattr_is_correct(self): ) def test_set_xformers_attn_processor_for_determinism(self): torch.use_deterministic_algorithms(False) - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + if self.forward_requires_fresh_args: + model = self.model_class(**self.init_dict) + else: + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) model.to(torch_device) if not hasattr(model, "set_attn_processor"): @@ -286,17 +300,26 @@ def test_set_xformers_attn_processor_for_determinism(self): model.set_default_attn_processor() assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values()) with torch.no_grad(): - output = model(**inputs_dict)[0] + if self.forward_requires_fresh_args: + output = model(**self.inputs_dict(0))[0] + else: + output = model(**inputs_dict)[0] model.enable_xformers_memory_efficient_attention() assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values()) with torch.no_grad(): - output_2 = model(**inputs_dict)[0] + if self.forward_requires_fresh_args: + output_2 = model(**self.inputs_dict(0))[0] + else: + output_2 = model(**inputs_dict)[0] model.set_attn_processor(XFormersAttnProcessor()) assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values()) with torch.no_grad(): - output_3 = model(**inputs_dict)[0] + if self.forward_requires_fresh_args: + output_3 = model(**self.inputs_dict(0))[0] + else: + output_3 = model(**inputs_dict)[0] torch.use_deterministic_algorithms(True) @@ -307,8 +330,12 @@ def test_set_xformers_attn_processor_for_determinism(self): @require_torch_gpu def test_set_attn_processor_for_determinism(self): torch.use_deterministic_algorithms(False) - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + if self.forward_requires_fresh_args: + model = self.model_class(**self.init_dict) + else: + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) if not hasattr(model, "set_attn_processor"): @@ -317,22 +344,34 @@ def test_set_attn_processor_for_determinism(self): assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values()) with torch.no_grad(): - output_1 = model(**inputs_dict)[0] + if self.forward_requires_fresh_args: + output_1 = model(**self.inputs_dict(0))[0] + else: + output_1 = model(**inputs_dict)[0] model.set_default_attn_processor() assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values()) with torch.no_grad(): - output_2 = model(**inputs_dict)[0] + if self.forward_requires_fresh_args: + output_2 = model(**self.inputs_dict(0))[0] + else: + output_2 = model(**inputs_dict)[0] model.set_attn_processor(AttnProcessor2_0()) assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values()) with torch.no_grad(): - output_4 = model(**inputs_dict)[0] + if self.forward_requires_fresh_args: + output_4 = model(**self.inputs_dict(0))[0] + else: + output_4 = model(**inputs_dict)[0] model.set_attn_processor(AttnProcessor()) assert all(type(proc) == AttnProcessor for proc in 
model.attn_processors.values()) with torch.no_grad(): - output_5 = model(**inputs_dict)[0] + if self.forward_requires_fresh_args: + output_5 = model(**self.inputs_dict(0))[0] + else: + output_5 = model(**inputs_dict)[0] torch.use_deterministic_algorithms(True) @@ -342,9 +381,12 @@ def test_set_attn_processor_for_determinism(self): assert torch.allclose(output_2, output_5, atol=self.base_precision) def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + if self.forward_requires_fresh_args: + model = self.model_class(**self.init_dict) + else: + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) - model = self.model_class(**init_dict) if hasattr(model, "set_default_attn_processor"): model.set_default_attn_processor() @@ -367,11 +409,17 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): new_model.to(torch_device) with torch.no_grad(): - image = model(**inputs_dict) + if self.forward_requires_fresh_args: + image = model(**self.inputs_dict(0)) + else: + image = model(**inputs_dict) if isinstance(image, dict): image = image.to_tuple()[0] - new_image = new_model(**inputs_dict) + if self.forward_requires_fresh_args: + new_image = new_model(**self.inputs_dict(0)) + else: + new_image = new_model(**inputs_dict) if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] @@ -405,17 +453,26 @@ def test_from_save_pretrained_dtype(self): assert new_model.dtype == dtype def test_determinism(self, expected_max_diff=1e-5): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + if self.forward_requires_fresh_args: + model = self.model_class(**self.init_dict) + else: + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) model.to(torch_device) model.eval() with torch.no_grad(): - first = model(**inputs_dict) + if self.forward_requires_fresh_args: + first = model(**self.inputs_dict(0)) + else: + first = model(**inputs_dict) if isinstance(first, dict): first = first.to_tuple()[0] - second = model(**inputs_dict) + if self.forward_requires_fresh_args: + second = model(**self.inputs_dict(0)) + else: + second = model(**inputs_dict) if isinstance(second, dict): second = second.to_tuple()[0] @@ -548,15 +605,22 @@ def recursive_check(tuple_object, dict_object): ), ) - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + if self.forward_requires_fresh_args: + model = self.model_class(**self.init_dict) + else: + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) - model = self.model_class(**init_dict) model.to(torch_device) model.eval() with torch.no_grad(): - outputs_dict = model(**inputs_dict) - outputs_tuple = model(**inputs_dict, return_dict=False) + if self.forward_requires_fresh_args: + outputs_dict = model(**self.inputs_dict(0)) + outputs_tuple = model(**self.inputs_dict(0), return_dict=False) + else: + outputs_dict = model(**inputs_dict) + outputs_tuple = model(**inputs_dict, return_dict=False) recursive_check(outputs_tuple, outputs_dict) diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index fe2bcdb0af35..1f5b847dd157 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -16,11 +16,19 @@ import gc import unittest +import numpy as np import torch from parameterized import 
parameterized -from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny +from diffusers import ( + AsymmetricAutoencoderKL, + AutoencoderKL, + AutoencoderTiny, + ConsistencyDecoderVAE, + StableDiffusionPipeline, +) from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.loading_utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, floats_tensor, @@ -30,6 +38,7 @@ torch_all_close, torch_device, ) +from diffusers.utils.torch_utils import randn_tensor from .test_modeling_common import ModelTesterMixin, UNetTesterMixin @@ -269,6 +278,79 @@ def test_outputs_equivalence(self): pass +class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase): + model_class = ConsistencyDecoderVAE + main_input_name = "sample" + base_precision = 1e-2 + forward_requires_fresh_args = True + + def inputs_dict(self, seed=None): + generator = torch.Generator("cpu") + if seed is not None: + generator.manual_seed(0) + image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device)) + + return {"sample": image, "generator": generator} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + @property + def init_dict(self): + return { + "encoder_args": { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 4, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + }, + "decoder_args": { + "act_fn": "silu", + "add_attention": False, + "block_out_channels": [32, 64], + "down_block_types": [ + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ], + "downsample_padding": 1, + "downsample_type": "conv", + "dropout": 0.0, + "in_channels": 7, + "layers_per_block": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_train_timesteps": 1024, + "out_channels": 6, + "resnet_time_scale_shift": "scale_shift", + "time_embedding_type": "learned", + "up_block_types": [ + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], + "upsample_type": "conv", + }, + "scaling_factor": 1, + "block_out_channels": [32, 64], + "latent_channels": 4, + } + + def prepare_init_args_and_inputs_for_common(self): + return self.init_dict, self.inputs_dict() + + @unittest.skip + def test_training(self): + ... + + @unittest.skip + def test_ema_training(self): + ... 
+ + @slow class AutoencoderTinyIntegrationTests(unittest.TestCase): def tearDown(self): @@ -721,3 +803,94 @@ def test_stable_diffusion_encode_sample(self, seed, expected_slice): tolerance = 3e-3 if torch_device != "mps" else 1e-2 assert torch_all_close(output_slice, expected_output_slice, atol=tolerance) + + +@slow +class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_encode_decode(self): + vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update + vae.to(torch_device) + + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ).resize((256, 256)) + image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[ + None, :, :, : + ].cuda() + + latent = vae.encode(image).latent_dist.mean + + sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample + + actual_output = sample[0, :2, :2, :2].flatten().cpu() + expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024]) + + assert torch_all_close(actual_output, expected_output, atol=5e-3) + + def test_sd(self): + vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None) + pipe.to(torch_device) + + out = pipe( + "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0) + ).images[0] + + actual_output = out[:2, :2, :2].flatten().cpu() + expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759]) + + assert torch_all_close(actual_output, expected_output, atol=5e-3) + + def test_encode_decode_f16(self): + vae = ConsistencyDecoderVAE.from_pretrained( + "openai/consistency-decoder", torch_dtype=torch.float16 + ) # TODO - update + vae.to(torch_device) + + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ).resize((256, 256)) + image = ( + torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :] + .half() + .cuda() + ) + + latent = vae.encode(image).latent_dist.mean + + sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample + + actual_output = sample[0, :2, :2, :2].flatten().cpu() + expected_output = torch.tensor( + [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16 + ) + + assert torch_all_close(actual_output, expected_output, atol=5e-3) + + def test_sd_f16(self): + vae = ConsistencyDecoderVAE.from_pretrained( + "openai/consistency-decoder", torch_dtype=torch.float16 + ) # TODO - update + pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None + ) + pipe.to(torch_device) + + out = pipe( + "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0) + ).images[0] + + actual_output = out[:2, :2, :2].flatten().cpu() + expected_output = torch.tensor( + [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16 + ) + + assert torch_all_close(actual_output, expected_output, atol=5e-3) From bf406ea8869b2cfeb826ef32e9c96367ea2185e9 
Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 9 Nov 2023 13:10:24 +0100 Subject: [PATCH 3/3] Correct consist dec (#5722) * uP * Update src/diffusers/models/consistency_decoder_vae.py * uP * uP --- src/diffusers/models/autoencoder_asym_kl.py | 1 + src/diffusers/models/autoencoder_tiny.py | 6 +- .../models/consistency_decoder_vae.py | 71 ++++++++++++++++++- tests/models/test_models_vae.py | 55 ++++++-------- 4 files changed, 96 insertions(+), 37 deletions(-) diff --git a/src/diffusers/models/autoencoder_asym_kl.py b/src/diffusers/models/autoencoder_asym_kl.py index d8099120918b..9f0fa62d34cd 100644 --- a/src/diffusers/models/autoencoder_asym_kl.py +++ b/src/diffusers/models/autoencoder_asym_kl.py @@ -138,6 +138,7 @@ def _decode( def decode( self, z: torch.FloatTensor, + generator: Optional[torch.Generator] = None, image: Optional[torch.FloatTensor] = None, mask: Optional[torch.FloatTensor] = None, return_dict: bool = True, diff --git a/src/diffusers/models/autoencoder_tiny.py b/src/diffusers/models/autoencoder_tiny.py index 407b1906bba4..15bd53ff99d6 100644 --- a/src/diffusers/models/autoencoder_tiny.py +++ b/src/diffusers/models/autoencoder_tiny.py @@ -14,7 +14,7 @@ from dataclasses import dataclass -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch @@ -307,7 +307,9 @@ def encode( return AutoencoderTinyOutput(latents=output) @apply_forward_hook - def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: + def decode( + self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True + ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: if self.use_slicing and x.shape[0] > 1: output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)] output = torch.cat(output) diff --git a/src/diffusers/models/consistency_decoder_vae.py b/src/diffusers/models/consistency_decoder_vae.py index a2bcf1b8b78e..9cb7c60bbbc2 100644 --- a/src/diffusers/models/consistency_decoder_vae.py +++ b/src/diffusers/models/consistency_decoder_vae.py @@ -68,11 +68,76 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): """ @register_to_config - def __init__(self, encoder_args, decoder_args, scaling_factor, block_out_channels, latent_channels): + def __init__( + self, + scaling_factor=0.18215, + latent_channels=4, + encoder_act_fn="silu", + encoder_block_out_channels=(128, 256, 512, 512), + encoder_double_z=True, + encoder_down_block_types=( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + encoder_in_channels=3, + encoder_layers_per_block=2, + encoder_norm_num_groups=32, + encoder_out_channels=4, + decoder_add_attention=False, + decoder_block_out_channels=(320, 640, 1024, 1024), + decoder_down_block_types=( + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ), + decoder_downsample_padding=1, + decoder_in_channels=7, + decoder_layers_per_block=3, + decoder_norm_eps=1e-05, + decoder_norm_num_groups=32, + decoder_num_train_timesteps=1024, + decoder_out_channels=6, + decoder_resnet_time_scale_shift="scale_shift", + decoder_time_embedding_type="learned", + decoder_up_block_types=( + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ), + ): super().__init__() - self.encoder = Encoder(**encoder_args) - self.decoder_unet = UNet2DModel(**decoder_args) + self.encoder = 
Encoder( + act_fn=encoder_act_fn, + block_out_channels=encoder_block_out_channels, + double_z=encoder_double_z, + down_block_types=encoder_down_block_types, + in_channels=encoder_in_channels, + layers_per_block=encoder_layers_per_block, + norm_num_groups=encoder_norm_num_groups, + out_channels=encoder_out_channels, + ) + + self.decoder_unet = UNet2DModel( + add_attention=decoder_add_attention, + block_out_channels=decoder_block_out_channels, + down_block_types=decoder_down_block_types, + downsample_padding=decoder_downsample_padding, + in_channels=decoder_in_channels, + layers_per_block=decoder_layers_per_block, + norm_eps=decoder_norm_eps, + norm_num_groups=decoder_norm_num_groups, + num_train_timesteps=decoder_num_train_timesteps, + out_channels=decoder_out_channels, + resnet_time_scale_shift=decoder_resnet_time_scale_shift, + time_embedding_type=decoder_time_embedding_type, + up_block_types=decoder_up_block_types, + ) self.decoder_scheduler = ConsistencyDecoderScheduler() + self.register_to_config(block_out_channels=encoder_block_out_channels) self.register_buffer( "means", torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None], diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index 1f5b847dd157..3b698624ff87 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -303,39 +303,30 @@ def output_shape(self): @property def init_dict(self): return { - "encoder_args": { - "block_out_channels": [32, 64], - "in_channels": 3, - "out_channels": 4, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - }, - "decoder_args": { - "act_fn": "silu", - "add_attention": False, - "block_out_channels": [32, 64], - "down_block_types": [ - "ResnetDownsampleBlock2D", - "ResnetDownsampleBlock2D", - ], - "downsample_padding": 1, - "downsample_type": "conv", - "dropout": 0.0, - "in_channels": 7, - "layers_per_block": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "num_train_timesteps": 1024, - "out_channels": 6, - "resnet_time_scale_shift": "scale_shift", - "time_embedding_type": "learned", - "up_block_types": [ - "ResnetUpsampleBlock2D", - "ResnetUpsampleBlock2D", - ], - "upsample_type": "conv", - }, + "encoder_block_out_channels": [32, 64], + "encoder_in_channels": 3, + "encoder_out_channels": 4, + "encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "decoder_add_attention": False, + "decoder_block_out_channels": [32, 64], + "decoder_down_block_types": [ + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ], + "decoder_downsample_padding": 1, + "decoder_in_channels": 7, + "decoder_layers_per_block": 1, + "decoder_norm_eps": 1e-05, + "decoder_norm_num_groups": 32, + "decoder_num_train_timesteps": 1024, + "decoder_out_channels": 6, + "decoder_resnet_time_scale_shift": "scale_shift", + "decoder_time_embedding_type": "learned", + "decoder_up_block_types": [ + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], "scaling_factor": 1, - "block_out_channels": [32, 64], "latent_channels": 4, }