diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..4e82b44 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +imgs filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..59c3fb0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +models +flux-layer-outputs +*captions.json \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..c4f0ae6 --- /dev/null +++ b/README.md @@ -0,0 +1,56 @@ + +# LayerDiffuse-Flux +This repo is a Flux version implementation of LayerDiffuse ([LayerDiffuse](https://github.com/lllyasviel/LayerDiffuse)). + +We train a **new transparent vae** to adapt to Flux and train a **lora** to finetune Flux to generate transparent images. + +|![](./imgs/top_examples/bottle_vis.png)|![](./imgs/top_examples/boat_vis.png)|![](./imgs/top_examples/cat_doctor_vis.png)|![](./imgs/top_examples/dragonball_vis.png)|![](./imgs/top_examples/dress_vis.png)|![](./imgs/top_examples/half_vis.png)|![](./imgs/top_examples/stickers_vis.png)| +|:-:|:-:|:-:|:-:|:-:|:-:|:-:| +## Usage + ++ Clone this repository. +```shell +git clone ... +cd LayerDiffuse-Flux +``` ++ download weights +``` shell + +``` + +## Flux Transparent T2I +### demo +```shell +python demo_t2i.py --ckpt_path /your/path/to/FLUX.1 dev +``` +### examples + +| Examples: top to bottom: flux, flux-layer(ours), sdxl, sdxl-layer([LayerDiffuse](https://github.com/lllyasviel/LayerDiffuse)) | +|------------------------------------| +|![](./imgs/flux_layer_t2i_examples/test0_10_vis.jpg)| +|![](./imgs/flux_layer_t2i_examples/test10_20_vis.jpg)| +|![](./imgs/flux_layer_t2i_examples/test20_30_vis.jpg)| +|![](./imgs/flux_layer_t2i_examples/test30_40_vis.jpg)| +|![](./imgs/flux_layer_t2i_examples/test40_50_vis.jpg)| + + + +## Flux Transparent I2I +```shell +python demo_i2i.py --ckpt_path /your/path/to/FLUX.1 dev --image "./imgs/causal_cut.png" +``` +Prompt: "a handsome man with curly hair, high quality" + +Strength: 0.9 + +| Input (Transparent image) | Output (Transparent image) | +|------------------------------------|--------------------------------------------| +| ![img](imgs/causal_cut_vis.png) | ![img](imgs/causal_cut_output_vis.png) | + +## Acknowledgements +Thanks lllyasviel for their great work [LayerDiffuse](https://github.com/lllyasviel/LayerDiffuse) + +## Contact +If you have any questions about the code, please do not hesitate to contact us! + +Email: sunshuang1@xiaohongshu.com, xiangqiang1601@163.com diff --git a/demo_i2i.py b/demo_i2i.py new file mode 100644 index 0000000..2ea968e --- /dev/null +++ b/demo_i2i.py @@ -0,0 +1,89 @@ +import torch +import argparse +import os +import datetime +from pipeline_flux_img2img import FluxImg2ImgPipeline +from lib_layerdiffuse.vae import TransparentVAE, pad_rgb +from PIL import Image +import numpy as np +from torchvision import transforms +from safetensors.torch import load_file +from PIL import Image, ImageDraw, ImageFont + + +def generate_img(pipe, trans_vae, args): + original_image = (transforms.ToTensor()(Image.open(args.image))).unsqueeze(0) + padding_feed = [x for x in original_image.movedim(1, -1).float().cpu().numpy()] + list_of_np_rgb_padded = [pad_rgb(x) for x in padding_feed] + rgb_padded_bchw_01 = torch.from_numpy(np.stack(list_of_np_rgb_padded, axis=0)).float().movedim(-1, 1).to(original_image.device) + original_image_feed = original_image.clone() + original_image_feed[:, :3, :, :] = original_image_feed[:, :3, :, :] * 2.0 - 1.0 + original_image_rgb = original_image_feed[:, :3, :, :] * original_image_feed[:, 3, :, :] + + original_image_feed = original_image_feed.to("cuda") + original_image_rgb = original_image_rgb.to("cuda") + rgb_padded_bchw_01 = rgb_padded_bchw_01.to("cuda") + trans_vae.to(torch.device('cuda')) + rng = torch.Generator("cuda").manual_seed(args.seed) + + initial_latent = trans_vae.encode(original_image_feed, original_image_rgb, rgb_padded_bchw_01, use_offset=True) + + latents = pipe( + latents=initial_latent, + image=original_image, + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.steps, + output_type="latent", + generator=rng, + guidance_scale=args.guidance, + strength=args.strength, + ).images + + latents = pipe._unpack_latents(latents, args.height, args.width, pipe.vae_scale_factor) + latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor + + with torch.no_grad(): + original_x, x = trans_vae.decode(latents) + + x = x.clamp(0, 1) + x = x.permute(0, 2, 3, 1) + img = Image.fromarray((x*255).float().cpu().numpy().astype(np.uint8)[0]) + return img + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--trans_vae", type=str, default="./models/TransparentVAE.pth") + parser.add_argument("--output_dir", type=str, default="./flux-layer-outputs") + parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype") + parser.add_argument("--seed", type=int, default=43) + parser.add_argument("--steps", type=int, default=50) + parser.add_argument("--guidance", type=float, default=7.0) + parser.add_argument("--strength", type=float, default=0.8) + parser.add_argument("--prompt", type=str, default="a handsome man with curly hair, high quality") + parser.add_argument( + "--lora_weights", + type=str, + default="./models/layerlora.safetensors", + ) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--image", type=str, default="./imgs/causal_cut.png") + args = parser.parse_args() + + pipe = FluxImg2ImgPipeline.from_pretrained(args.ckpt_path, torch_dtype=torch.bfloat16).to('cuda') + pipe.load_lora_weights(args.lora_weights) + + trans_vae = TransparentVAE(pipe.vae, pipe.vae.dtype) + trans_vae.load_state_dict(torch.load(args.trans_vae), strict=False) + + print("all loaded") + + img = generate_img(pipe, trans_vae, args) + + # save image + os.makedirs(args.output_dir, exist_ok=True) + output_path = os.path.join(args.output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + img.save(output_path) diff --git a/demo_t2i.py b/demo_t2i.py new file mode 100644 index 0000000..682f60b --- /dev/null +++ b/demo_t2i.py @@ -0,0 +1,68 @@ +import torch +import argparse +import os +import datetime +from diffusers import FluxPipeline +from lib_layerdiffuse.vae import TransparentVAE +from PIL import Image +import numpy as np + +def generate_img(pipe, trans_vae, args): + + latents = pipe( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.steps, + output_type="latent", + generator=torch.Generator("cuda").manual_seed(args.seed), + guidance_scale=args.guidance, + + ).images + + latents = pipe._unpack_latents(latents, args.height, args.width, pipe.vae_scale_factor) + latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor + + with torch.no_grad(): + original_x, x = trans_vae.decode(latents) + + x = x.clamp(0, 1) + x = x.permute(0, 2, 3, 1) + img = Image.fromarray((x*255).float().cpu().numpy().astype(np.uint8)[0]) + + return img + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--trans_vae", type=str, default="./models/TransparentVAE.pth") + parser.add_argument("--output_dir", type=str, default="./flux-layer-outputs") + parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype") + parser.add_argument("--seed", type=int, default=11111) + parser.add_argument("--steps", type=int, default=50) + parser.add_argument("--guidance", type=float, default=3.5) + parser.add_argument("--prompt", type=str, default="glass bottle, high quality") + parser.add_argument( + "--lora_weights", + type=str, + default="./models/layerlora.safetensors", + ) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--height", type=int, default=1024) + args = parser.parse_args() + + pipe = FluxPipeline.from_pretrained(args.ckpt_path, torch_dtype=torch.bfloat16).to('cuda') + pipe.load_lora_weights(args.lora_weights) + + trans_vae = TransparentVAE(pipe.vae, pipe.vae.dtype) + trans_vae.load_state_dict(torch.load(args.trans_vae), strict=False) + trans_vae.to('cuda') + + print("all loaded") + + img = generate_img(pipe, trans_vae, args) + + # save image + os.makedirs(args.output_dir, exist_ok=True) + output_path = os.path.join(args.output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + img.save(output_path) diff --git a/imgs/causal_cut.png b/imgs/causal_cut.png new file mode 100644 index 0000000..fe31191 Binary files /dev/null and b/imgs/causal_cut.png differ diff --git a/imgs/causal_cut_output.png b/imgs/causal_cut_output.png new file mode 100644 index 0000000..62efdb4 Binary files /dev/null and b/imgs/causal_cut_output.png differ diff --git a/imgs/causal_cut_output_vis.png b/imgs/causal_cut_output_vis.png new file mode 100644 index 0000000..2fe8880 Binary files /dev/null and b/imgs/causal_cut_output_vis.png differ diff --git a/imgs/causal_cut_vis.png b/imgs/causal_cut_vis.png new file mode 100644 index 0000000..844b310 Binary files /dev/null and b/imgs/causal_cut_vis.png differ diff --git a/imgs/flux_layer_t2i_examples/test0_10.png b/imgs/flux_layer_t2i_examples/test0_10.png new file mode 100644 index 0000000..75e31f4 Binary files /dev/null and b/imgs/flux_layer_t2i_examples/test0_10.png differ diff --git a/imgs/flux_layer_t2i_examples/test0_10_vis.jpg b/imgs/flux_layer_t2i_examples/test0_10_vis.jpg new file mode 100644 index 0000000..98ae400 Binary files /dev/null and b/imgs/flux_layer_t2i_examples/test0_10_vis.jpg differ diff --git a/imgs/flux_layer_t2i_examples/test10_20.png b/imgs/flux_layer_t2i_examples/test10_20.png new file mode 100644 index 0000000..a6a1e56 Binary files /dev/null and b/imgs/flux_layer_t2i_examples/test10_20.png differ diff --git a/imgs/flux_layer_t2i_examples/test10_20_vis.jpg b/imgs/flux_layer_t2i_examples/test10_20_vis.jpg new file mode 100644 index 0000000..ef697db Binary files /dev/null and b/imgs/flux_layer_t2i_examples/test10_20_vis.jpg differ diff --git a/imgs/flux_layer_t2i_examples/test20_30.png b/imgs/flux_layer_t2i_examples/test20_30.png new file mode 100644 index 0000000..7738712 Binary files /dev/null and b/imgs/flux_layer_t2i_examples/test20_30.png differ diff --git a/imgs/flux_layer_t2i_examples/test20_30_vis.jpg b/imgs/flux_layer_t2i_examples/test20_30_vis.jpg new file mode 100644 index 0000000..aee1116 Binary files /dev/null and b/imgs/flux_layer_t2i_examples/test20_30_vis.jpg differ diff --git a/imgs/flux_layer_t2i_examples/test30_40.png b/imgs/flux_layer_t2i_examples/test30_40.png new file mode 100644 index 0000000..60979dc Binary files /dev/null and b/imgs/flux_layer_t2i_examples/test30_40.png differ diff --git a/imgs/flux_layer_t2i_examples/test30_40_vis.jpg b/imgs/flux_layer_t2i_examples/test30_40_vis.jpg new file mode 100644 index 0000000..d106a0e Binary files /dev/null and b/imgs/flux_layer_t2i_examples/test30_40_vis.jpg differ diff --git a/imgs/flux_layer_t2i_examples/test40_50.png b/imgs/flux_layer_t2i_examples/test40_50.png new file mode 100644 index 0000000..e7929bd Binary files /dev/null and b/imgs/flux_layer_t2i_examples/test40_50.png differ diff --git a/imgs/flux_layer_t2i_examples/test40_50_vis.jpg b/imgs/flux_layer_t2i_examples/test40_50_vis.jpg new file mode 100644 index 0000000..c61ad33 Binary files /dev/null and b/imgs/flux_layer_t2i_examples/test40_50_vis.jpg differ diff --git a/imgs/top_examples/boat.png b/imgs/top_examples/boat.png new file mode 100644 index 0000000..d0be4ee Binary files /dev/null and b/imgs/top_examples/boat.png differ diff --git a/imgs/top_examples/boat_vis.png b/imgs/top_examples/boat_vis.png new file mode 100644 index 0000000..a53854b Binary files /dev/null and b/imgs/top_examples/boat_vis.png differ diff --git a/imgs/top_examples/bottle.png b/imgs/top_examples/bottle.png new file mode 100644 index 0000000..daae911 Binary files /dev/null and b/imgs/top_examples/bottle.png differ diff --git a/imgs/top_examples/bottle_vis.png b/imgs/top_examples/bottle_vis.png new file mode 100644 index 0000000..dcacce3 Binary files /dev/null and b/imgs/top_examples/bottle_vis.png differ diff --git a/imgs/top_examples/cat_doctor.png b/imgs/top_examples/cat_doctor.png new file mode 100644 index 0000000..a312df1 Binary files /dev/null and b/imgs/top_examples/cat_doctor.png differ diff --git a/imgs/top_examples/cat_doctor_vis.png b/imgs/top_examples/cat_doctor_vis.png new file mode 100644 index 0000000..048db89 Binary files /dev/null and b/imgs/top_examples/cat_doctor_vis.png differ diff --git a/imgs/top_examples/dragonball.png b/imgs/top_examples/dragonball.png new file mode 100644 index 0000000..9bd6753 Binary files /dev/null and b/imgs/top_examples/dragonball.png differ diff --git a/imgs/top_examples/dragonball_vis.png b/imgs/top_examples/dragonball_vis.png new file mode 100644 index 0000000..e1cee58 Binary files /dev/null and b/imgs/top_examples/dragonball_vis.png differ diff --git a/imgs/top_examples/dress.png b/imgs/top_examples/dress.png new file mode 100644 index 0000000..38deb57 Binary files /dev/null and b/imgs/top_examples/dress.png differ diff --git a/imgs/top_examples/dress_vis.png b/imgs/top_examples/dress_vis.png new file mode 100644 index 0000000..a404865 Binary files /dev/null and b/imgs/top_examples/dress_vis.png differ diff --git a/imgs/top_examples/half.png b/imgs/top_examples/half.png new file mode 100644 index 0000000..62ed390 Binary files /dev/null and b/imgs/top_examples/half.png differ diff --git a/imgs/top_examples/half_vis.png b/imgs/top_examples/half_vis.png new file mode 100644 index 0000000..bdd1ee1 Binary files /dev/null and b/imgs/top_examples/half_vis.png differ diff --git a/imgs/top_examples/stickers.png b/imgs/top_examples/stickers.png new file mode 100644 index 0000000..5c74148 Binary files /dev/null and b/imgs/top_examples/stickers.png differ diff --git a/imgs/top_examples/stickers_vis.png b/imgs/top_examples/stickers_vis.png new file mode 100644 index 0000000..cdcc38b Binary files /dev/null and b/imgs/top_examples/stickers_vis.png differ diff --git a/lib_layerdiffuse/vae.py b/lib_layerdiffuse/vae.py new file mode 100644 index 0000000..b297bfe --- /dev/null +++ b/lib_layerdiffuse/vae.py @@ -0,0 +1,447 @@ +import torch.nn as nn +import torch +import cv2 +import numpy as np +import safetensors.torch as sf +from accelerate.logging import get_logger +logger = get_logger(__name__, log_level="INFO") + +from tqdm import tqdm +from typing import Optional, Tuple +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution + +import torchvision + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class LatentTransparencyOffsetEncoder(torch.nn.Module): + def __init__(self, latent_c=4, *args, **kwargs): + super().__init__(*args, **kwargs) + self.blocks = torch.nn.Sequential( + torch.nn.Conv2d(4, 32, kernel_size=3, padding=1, stride=1), + nn.SiLU(), + torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1), + nn.SiLU(), + torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2), + nn.SiLU(), + torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1), + nn.SiLU(), + torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2), + nn.SiLU(), + torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1), + nn.SiLU(), + torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), + nn.SiLU(), + torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), + nn.SiLU(), + zero_module(torch.nn.Conv2d(256, latent_c, kernel_size=3, padding=1, stride=1)), + ) + + def __call__(self, x): + return self.blocks(x) + + +# 1024 * 1024 * 3 -> 16 * 16 * 512 -> 1024 * 1024 * 3 +class UNet1024(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"), + block_out_channels: Tuple[int] = (32, 32, 64, 128, 256, 512, 512), + layers_per_block: int = 2, + mid_block_scale_factor: float = 1, + downsample_padding: int = 1, + downsample_type: str = "conv", + upsample_type: str = "conv", + dropout: float = 0.0, + act_fn: str = "silu", + attention_head_dim: Optional[int] = 8, + norm_num_groups: int = 4, + norm_eps: float = 1e-5, + latent_c: int = 4, + ): + super().__init__() + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + self.latent_conv_in = zero_module(nn.Conv2d(latent_c, block_out_channels[2], kernel_size=1)) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=None, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + downsample_padding=downsample_padding, + resnet_time_scale_shift="default", + downsample_type=downsample_type, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=None, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], + resnet_groups=norm_num_groups, + attn_groups=None, + add_attention=True, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=None, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + resnet_time_scale_shift="default", + upsample_type=upsample_type, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def forward(self, x, latent): + sample_latent = self.latent_conv_in(latent) + sample = self.conv_in(x) + emb = None + + down_block_res_samples = (sample,) + for i, downsample_block in enumerate(self.down_blocks): + if i == 3: + sample = sample + sample_latent + + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + down_block_res_samples += res_samples + + sample = self.mid_block(sample, emb) + + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + sample = upsample_block(sample, res_samples, emb) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + return sample + + +def checkerboard(shape): + return np.indices(shape).sum(axis=0) % 2 + + +def build_alpha_pyramid(color, alpha, dk=1.2): + # Written by lvmin at Stanford + # Massive iterative Gaussian filters are mathematically consistent to pyramid. + + pyramid = [] + current_premultiplied_color = color * alpha + current_alpha = alpha + + while True: + pyramid.append((current_premultiplied_color, current_alpha)) + + H, W, C = current_alpha.shape + if min(H, W) == 1: + break + + current_premultiplied_color = cv2.resize(current_premultiplied_color, (int(W / dk), int(H / dk)), interpolation=cv2.INTER_AREA) + current_alpha = cv2.resize(current_alpha, (int(W / dk), int(H / dk)), interpolation=cv2.INTER_AREA)[:, :, None] + return pyramid[::-1] + + +def pad_rgb(np_rgba_hwc_uint8): + # Written by lvmin at Stanford + # Massive iterative Gaussian filters are mathematically consistent to pyramid. + + np_rgba_hwc = np_rgba_hwc_uint8.astype(np.float32) #/ 255.0 + pyramid = build_alpha_pyramid(color=np_rgba_hwc[..., :3], alpha=np_rgba_hwc[..., 3:]) + + top_c, top_a = pyramid[0] + fg = np.sum(top_c, axis=(0, 1), keepdims=True) / np.sum(top_a, axis=(0, 1), keepdims=True).clip(1e-8, 1e32) + + for layer_c, layer_a in pyramid: + layer_h, layer_w, _ = layer_c.shape + fg = cv2.resize(fg, (layer_w, layer_h), interpolation=cv2.INTER_LINEAR) + fg = layer_c + fg * (1.0 - layer_a) + + return fg + + +def dist_sample_deterministic(dist: DiagonalGaussianDistribution, perturbation: torch.Tensor): + # Modified from diffusers.models.autoencoders.vae.DiagonalGaussianDistribution.sample() + x = dist.mean + dist.std * perturbation.to(dist.std) + return x + +class TransparentVAE(torch.nn.Module): + def __init__(self, sd_vae, dtype=torch.float16, encoder_file=None, decoder_file=None, alpha=300.0, latent_c=16, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dtype = dtype + + self.sd_vae = sd_vae + self.sd_vae.to(dtype=self.dtype) + self.sd_vae.requires_grad_(False) + + self.encoder = LatentTransparencyOffsetEncoder(latent_c=latent_c) + if encoder_file is not None: + temp = sf.load_file(encoder_file) + # del temp['blocks.16.weight'] + # del temp['blocks.16.bias'] + self.encoder.load_state_dict(temp, strict=True) + del temp + self.encoder.to(dtype=self.dtype) + self.alpha = alpha + + self.decoder = UNet1024(in_channels=3, out_channels=4, latent_c=latent_c) + if decoder_file is not None: + temp = sf.load_file(decoder_file) + # del temp['latent_conv_in.weight'] + # del temp['latent_conv_in.bias'] + self.decoder.load_state_dict(temp, strict=True) + del temp + self.decoder.to(dtype=self.dtype) + self.latent_c = latent_c + + + def sd_decode(self, latent): + return self.sd_vae.decode(latent) + + def decode(self, latent, aug=True): + origin_pixel = self.sd_vae.decode(latent).sample + origin_pixel = (origin_pixel * 0.5 + 0.5) + if not aug: + y = self.decoder(origin_pixel.to(self.dtype), latent.to(self.dtype)) + return origin_pixel, y + list_y = [] + for i in range(int(latent.shape[0])): + y = self.estimate_augmented(origin_pixel[i:i + 1].to(self.dtype), latent[i:i + 1].to(self.dtype)) + list_y.append(y) + y = torch.concat(list_y, dim=0) + return origin_pixel, y + + def encode(self, img_rgba, img_rgb, padded_img_rgb, use_offset=True): + a_bchw_01 = img_rgba[:, 3:, :, :] + vae_feed = img_rgb.to(device=self.sd_vae.device, dtype=self.sd_vae.dtype) + latent_dist = self.sd_vae.encode(vae_feed).latent_dist + offset_feed = torch.cat([padded_img_rgb, a_bchw_01], dim=1).to(device=self.sd_vae.device, dtype=self.dtype) + offset = self.encoder(offset_feed) * self.alpha + if use_offset: + latent = dist_sample_deterministic(dist=latent_dist, perturbation=offset) + latent = self.sd_vae.config.scaling_factor * (latent - self.sd_vae.config.shift_factor) + else: + latent = latent_dist.sample() + latent = self.sd_vae.config.scaling_factor * (latent - self.sd_vae.config.shift_factor) + return latent + + def forward(self, img_rgba, img_rgb, padded_img_rgb, use_offset=True): + return self.decode(self.encode(img_rgba, img_rgb, padded_img_rgb, use_offset)) + + @property + def device(self): + return next(self.parameters()).device + + @torch.no_grad() + def estimate_augmented(self, pixel, latent): + args = [ + [False, 0], [False, 1], [False, 2], [False, 3], [True, 0], [True, 1], [True, 2], [True, 3], + ] + + result = [] + + for flip, rok in tqdm(args): + feed_pixel = pixel.clone() + feed_latent = latent.clone() + + if flip: + feed_pixel = torch.flip(feed_pixel, dims=(3,)) + feed_latent = torch.flip(feed_latent, dims=(3,)) + + feed_pixel = torch.rot90(feed_pixel, k=rok, dims=(2, 3)) + feed_latent = torch.rot90(feed_latent, k=rok, dims=(2, 3)) + + eps = self.decoder(feed_pixel, feed_latent).clip(0, 1) + eps = torch.rot90(eps, k=-rok, dims=(2, 3)) + + if flip: + eps = torch.flip(eps, dims=(3,)) + + result += [eps] + + result = torch.stack(result, dim=0) + median = torch.median(result, dim=0).values + return median + + + +class TransparentVAEDecoder(torch.nn.Module): + def __init__(self, filename, dtype=torch.float16, *args, **kwargs): + super().__init__(*args, **kwargs) + sd = sf.load_file(filename) + model = UNet1024(in_channels=3, out_channels=4) + model.load_state_dict(sd, strict=True) + model.to(dtype=dtype) + model.eval() + self.model = model + self.dtype = dtype + return + + @torch.no_grad() + def estimate_single_pass(self, pixel, latent): + y = self.model(pixel, latent) + return y + + @torch.no_grad() + def estimate_augmented(self, pixel, latent): + args = [ + [False, 0], [False, 1], [False, 2], [False, 3], [True, 0], [True, 1], [True, 2], [True, 3], + ] + + result = [] + + for flip, rok in tqdm(args): + feed_pixel = pixel.clone() + feed_latent = latent.clone() + + if flip: + feed_pixel = torch.flip(feed_pixel, dims=(3,)) + feed_latent = torch.flip(feed_latent, dims=(3,)) + + feed_pixel = torch.rot90(feed_pixel, k=rok, dims=(2, 3)) + feed_latent = torch.rot90(feed_latent, k=rok, dims=(2, 3)) + + eps = self.estimate_single_pass(feed_pixel, feed_latent).clip(0, 1) + eps = torch.rot90(eps, k=-rok, dims=(2, 3)) + + if flip: + eps = torch.flip(eps, dims=(3,)) + + result += [eps] + + result = torch.stack(result, dim=0) + median = torch.median(result, dim=0).values + return median + + @torch.no_grad() + def forward(self, sd_vae, latent): + pixel = sd_vae.decode(latent).sample + pixel = (pixel * 0.5 + 0.5).clip(0, 1).to(self.dtype) + latent = latent.to(self.dtype) + result_list = [] + vis_list = [] + + for i in range(int(latent.shape[0])): + y = self.estimate_augmented(pixel[i:i + 1], latent[i:i + 1]) + + y = y.clip(0, 1).movedim(1, -1) + alpha = y[..., :1] + fg = y[..., 1:] + + B, H, W, C = fg.shape + cb = checkerboard(shape=(H // 64, W // 64)) + cb = cv2.resize(cb, (W, H), interpolation=cv2.INTER_NEAREST) + cb = (0.5 + (cb - 0.5) * 0.1)[None, ..., None] + cb = torch.from_numpy(cb).to(fg) + + vis = (fg * alpha + cb * (1 - alpha))[0] + vis = (vis * 255.0).detach().cpu().float().numpy().clip(0, 255).astype(np.uint8) + vis_list.append(vis) + + png = torch.cat([fg, alpha], dim=3)[0] + png = (png * 255.0).detach().cpu().float().numpy().clip(0, 255).astype(np.uint8) + result_list.append(png) + + return result_list, vis_list + + +class TransparentVAEEncoder(torch.nn.Module): + def __init__(self, filename, dtype=torch.float16, alpha=300.0, *args, **kwargs): + super().__init__(*args, **kwargs) + sd = sf.load_file(filename) + self.dtype = dtype + + model = LatentTransparencyOffsetEncoder() + model.load_state_dict(sd, strict=True) + model.to(dtype=self.dtype) + model.eval() + + self.model = model + + # similar to LoRA's alpha to avoid initial zero-initialized outputs being too small + self.alpha = alpha + return + + @torch.no_grad() + def forward(self, sd_vae, list_of_np_rgba_hwc_uint8, use_offset=True): + list_of_np_rgb_padded = [pad_rgb(x) for x in list_of_np_rgba_hwc_uint8] + rgb_padded_bchw_01 = torch.from_numpy(np.stack(list_of_np_rgb_padded, axis=0)).float().movedim(-1, 1) + rgba_bchw_01 = torch.from_numpy(np.stack(list_of_np_rgba_hwc_uint8, axis=0)).float().movedim(-1, 1) / 255.0 + rgb_bchw_01 = rgba_bchw_01[:, :3, :, :] + a_bchw_01 = rgba_bchw_01[:, 3:, :, :] + vae_feed = (rgb_bchw_01 * 2.0 - 1.0) * a_bchw_01 + vae_feed = vae_feed.to(device=sd_vae.device, dtype=sd_vae.dtype) + latent_dist = sd_vae.encode(vae_feed).latent_dist + offset_feed = torch.cat([a_bchw_01, rgb_padded_bchw_01], dim=1).to(device=sd_vae.device, dtype=self.dtype) + offset = self.model(offset_feed) * self.alpha + if use_offset: + latent = dist_sample_deterministic(dist=latent_dist, perturbation=offset) + else: + latent = latent_dist.sample() + return latent diff --git a/pipeline_flux_img2img.py b/pipeline_flux_img2img.py new file mode 100644 index 0000000..4af204d --- /dev/null +++ b/pipeline_flux_img2img.py @@ -0,0 +1,860 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.transformers import FluxTransformer2DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import FluxImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> device = "cuda" + >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" + + >>> images = pipe( + ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + # if latents is not None: + # return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + if latents is not None: + image_latents = latents.to(device=device, dtype=dtype) + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a520018 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,345 @@ +absl-py==2.0.0 +accelerate==1.1.1 +aiohttp==3.8.5 +aiosignal==1.3.1 +annotated-types==0.5.0 +antlr4-python3-runtime==4.9.3 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +asttokens==2.4.0 +astunparse==1.6.3 +async-timeout==4.0.3 +asyncio==3.4.3 +attrs==23.1.0 +audioread==3.0.1 +av==13.1.0 +backcall==0.2.0 +beautifulsoup4==4.12.2 +bitsandbytes==0.41.1 +bleach==6.0.0 +blis==0.7.11 +boto3==1.28.78 +botocore==1.31.78 +braceexpand==0.1.7 +build==1.2.2.post1 +CacheControl==0.14.1 +cachetools==5.3.1 +catalogue==2.0.10 +certifi==2023.7.22 +cffi==1.16.0 +charset-normalizer==3.2.0 +clean-fid==0.1.35 +cleo==2.1.0 +click==8.1.6 +clip==1.0 +cloudpathlib==0.15.1 +cloudpickle==2.2.1 +cmake==3.27.6 +coloredlogs==15.0.1 +comm==0.1.4 +confection==0.1.3 +contourpy==1.1.1 +cos-python-sdk-v5==1.9.26 +coscmd==1.8.6.33 +cpm-kernels==1.0.11 +crashtest==0.4.1 +crcmod==1.7 +cryptography==44.0.0 +cubinlinker==0.3.0+2.gce0680b +cuda-python==12.2.0rc5+5.g84845d1 +cudf==23.8.0 +cugraph==23.8.0 +cugraph-dgl==23.8.0 +cugraph-service-client==23.8.0 +cugraph-service-server==23.8.0 +cuml==23.8.0 +cupy-cuda12x==12.1.0 +cycler==0.12.1 +cymem==2.0.8 +Cython==3.0.3 +dask==2023.7.1 +dask-cuda==23.8.0 +dask-cudf==23.8.0 +datasets==2.14.6 +DateTime==5.2 +debugpy==1.8.0 +decorator==5.1.1 +deepspeed==0.12.2 +defusedxml==0.7.1 +diffusers==0.31.0 +dill==0.3.7 +distlib==0.3.9 +distributed==2023.7.1 +dm-tree==0.1.8 +dulwich==0.21.7 +einops==0.7.0 +en-core-web-sm==3.7.0 +exceptiongroup==1.1.3 +execnet==2.0.2 +executing==2.0.0 +expecttest==0.1.3 +fairscale==0.4.13 +fastChat==0.1.1 +fastjsonschema==2.18.1 +fastrlock==0.8.1 +filelock==3.12.4 +flash-attn==2.6.1 +flatbuffers==24.3.25 +fonttools==4.43.1 +frozenlist==1.4.0 +fsspec==2023.6.0 +ftfy==6.1.1 +gast==0.5.4 +google-auth==2.23.2 +google-auth-oauthlib==0.4.6 +graphsurgeon==0.4.6 +grpcio==1.59.0 +hjson==3.1.0 +huggingface-hub==0.23.4 +humanfriendly==10.0 +hypothesis==5.35.1 +idna==3.4 +image-reward==1.5 +imageio==2.34.2 +imagesize==1.4.1 +importlib-metadata==6.8.0 +iniconfig==2.0.0 +installer==0.7.0 +intel-openmp==2021.4.0 +iopath==0.1.10 +ipdb==0.13.13 +ipykernel==6.25.2 +ipython==8.16.1 +ipython-genutils==0.2.0 +jaraco.classes==3.4.0 +jedi==0.19.1 +jeepney==0.8.0 +Jinja2==3.1.2 +jmespath==1.0.1 +joblib==1.3.2 +json5==0.9.14 +jsonlines==4.0.0 +jsonschema==4.19.1 +jsonschema-specifications==2023.7.1 +jupyter_client==8.3.1 +jupyter_core==5.3.2 +jupyter-tensorboard==0.2.0 +jupyterlab==2.3.2 +jupyterlab-pygments==0.2.2 +jupyterlab-server==1.2.0 +jupytext==1.15.2 +keyring==24.3.1 +kiwisolver==1.4.5 +langcodes==3.3.0 +lazy_loader==0.4 +librosa==0.9.2 +llvmlite==0.40.1 +locket==1.0.0 +Markdown==3.4.4 +markdown-it-py==3.0.0 +MarkupSafe==2.1.3 +matplotlib==3.8.0 +matplotlib-inline==0.1.6 +mdit-py-plugins==0.4.0 +mdurl==0.1.2 +mistune==3.0.2 +mkl==2021.1.1 +mkl-devel==2021.1.1 +mkl-include==2021.1.1 +mock==5.1.0 +more-itertools==10.5.0 +mpmath==1.3.0 +msgpack==1.0.5 +multidict==6.0.4 +multiprocess==0.70.15 +murmurhash==1.0.10 +nbclient==0.8.0 +nbconvert==7.9.2 +nbformat==5.9.2 +nest-asyncio==1.5.8 +networkx==3.3 +ninja==1.11.1.1 +notebook==6.4.10 +numba==0.57.1+1.g5fba9aa8f +numpy==1.26.4 +nvfuser==0.0.20+gitunknown +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-dali-cuda120==1.30.0 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.1.105 +nvidia-nvtx-cu12==12.1.105 +nvidia-pyindex==1.0.9 +nvtx==0.2.5 +oauthlib==3.2.2 +omegaconf==2.3.0 +onnx==1.14.0 +onnx-graphsurgeon==0.3.27 +onnxruntime-gpu==1.15.1 +open-clip-torch==2.23.0 +opencv==4.7.0 +opencv-contrib-python==4.10.0.84 +opencv-python==4.8.0.74 +optimum==1.19.2 +packaging==23.1 +pandas==1.5.3 +pandocfilters==1.5.0 +parso==0.8.3 +partd==1.4.0 +pathy==0.10.2 +peft==0.13.2 +pexpect==4.8.0 +pickleshare==0.7.5 +pillow==10.3.0 +pillow-heif==0.13.1 +pip==23.2.1 +pkginfo==1.11.2 +platformdirs==3.11.0 +pluggy==1.3.0 +ply==3.11 +poetry-core==1.9.1 +poetry-plugin-export==1.8.0 +polygraphy==0.49.0 +pooch==1.7.0 +portalocker==2.10.0 +preshed==3.0.9 +prettytable==3.9.0 +prometheus-client==0.17.1 +prompt-toolkit==3.0.39 +protobuf==4.24.4 +psutil==5.9.4 +ptxcompiler==0.8.1+1.g2cb1b35 +ptyprocess==0.7.0 +pure-eval==0.2.2 +py-cpuinfo==9.0.0 +pyarrow==11.0.0 +pyasn1==0.5.0 +pyasn1-modules==0.3.0 +pybind11==2.11.1 +pybind11-global==2.11.1 +pycocotools==2.0+nv0.7.3 +pycparser==2.21 +pycryptodome==3.19.0 +pydantic==2.4.2 +pydantic_core==2.10.1 +Pygments==2.16.1 +pylibcugraph==23.8.0 +pylibcugraphops==23.8.0 +pylibraft==23.8.0 +pynvml==11.4.1 +pyparsing==3.1.1 +pyproject_hooks==1.2.0 +pytest==7.4.2 +pytest-flakefinder==1.1.0 +pytest-rerunfailures==12.0 +pytest-shard==0.1.2 +pytest-xdist==3.3.1 +python-dateutil==2.8.2 +python-hostlist==1.23.0 +pytorch-quantization==2.1.2 +pytz==2023.3 +PyYAML==6.0.1 +pyzmq==25.1.1 +qwen-vl-utils==0.0.8 +raft-dask==23.8.0 +RapidFuzz==3.10.1 +referencing==0.30.2 +regex==2023.10.3 +requests==2.31.0 +requests-oauthlib==1.3.1 +requests-toolbelt==1.0.0 +resampy==0.4.2 +rich==13.9.4 +rmm==23.8.0 +rpds-py==0.10.4 +rsa==4.9 +s3transfer==0.7.0 +safetensors==0.4.3 +scikit-image==0.24.0 +scikit-learn==1.2.0 +scipy==1.11.1 +seaborn==0.13.0 +SecretStorage==3.3.3 +Send2Trash==1.8.2 +sentencepiece==0.1.99 +setuptools==68.2.2 +shellingham==1.5.4 +six==1.16.0 +smart-open==6.4.0 +sortedcontainers==2.4.0 +soundfile==0.12.1 +soupsieve==2.5 +spacy==3.7.1 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +sphinx-glpi-theme==0.3 +srsly==2.4.8 +stable-fast==0.0.13.post4 +stack-data==0.6.3 +SwissArmyTransformer==0.4.8 +sympy==1.12 +tabulate==0.9.0 +tbb==2021.10.0 +tblib==2.0.0 +tensorboard==2.9.0 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +tensorboardX==2.6.2.2 +tensorrt==8.6.1 +terminado==0.17.1 +thinc==8.2.1 +threadpoolctl==3.2.0 +thriftpy2==0.4.16 +tifffile==2024.7.2 +timm==0.6.13 +tinycss2==1.2.1 +tokenizers==0.20.3 +toml==0.10.2 +tomli==2.0.1 +tomlkit==0.13.2 +toolz==0.12.0 +torch==2.3.0+cu121 +torch-tensorrt==0.0.0 +torchdata==0.7.0a0 +torchtext==0.16.0a0 +torchvision==0.18.0 +tornado==6.3.3 +tqdm==4.66.1 +traitlets==5.9.0 +transformers==4.46.3 +treelite==3.2.0 +treelite-runtime==3.2.0 +triton==2.3.0 +trove-classifiers==2024.10.21.16 +typer==0.9.0 +types-dataclasses==0.6.6 +typing_extensions==4.9.0 +ucx-py==0.33.0 +uff==0.6.9 +urllib3==1.26.16 +virtualenv==20.28.0 +voluptuous==0.14.2 +wasabi==1.1.2 +wcwidth==0.2.8 +weasel==0.3.2 +webdataset==0.2.73 +webencodings==0.5.1 +Werkzeug==3.0.0 +wheel==0.41.2 +xdoctest==1.0.2 +xformers==0.0.26.post1 +xgboost==1.7.5 +xhsBaseCV==0.0.1 +xmltodict==0.13.0 +xxhash==3.4.1 +yarl==1.9.2 +zict==3.0.0 +zipp==3.16.2 +zope.interface==6.1