Commit

Initial commit

innerway-xq committed Dec 11, 2024
0 parents commit 07cd568
Showing 36 changed files with 1,870 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitattributes
@@ -0,0 +1 @@
imgs filter=lfs diff=lfs merge=lfs -text
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
__pycache__
models
flux-layer-outputs
*captions.json
56 changes: 56 additions & 0 deletions README.md
@@ -0,0 +1,56 @@

# LayerDiffuse-Flux
This repo is a Flux implementation of [LayerDiffuse](https://github.com/lllyasviel/LayerDiffuse).

We train a **new transparent VAE** adapted to Flux and a **LoRA** that fine-tunes Flux to generate transparent images.

|![](./imgs/top_examples/bottle_vis.png)|![](./imgs/top_examples/boat_vis.png)|![](./imgs/top_examples/cat_doctor_vis.png)|![](./imgs/top_examples/dragonball_vis.png)|![](./imgs/top_examples/dress_vis.png)|![](./imgs/top_examples/half_vis.png)|![](./imgs/top_examples/stickers_vis.png)|
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
## Usage

+ Clone this repository.
```shell
git clone ...
cd LayerDiffuse-Flux
```
+ Download the weights. The demo scripts expect `TransparentVAE.pth` and `layerlora.safetensors` under `./models/` by default; see the sketch below for the base model.
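
The download commands are left blank in this commit. A minimal sketch, assuming the base checkpoint is pulled from the `black-forest-labs/FLUX.1-dev` repository on Hugging Face (gated; requires accepting its license); the transparent VAE and LoRA files are not linked here, so only their expected locations are shown:
```shell
# Assumption: fetch the base FLUX.1 [dev] checkpoint with huggingface-cli
huggingface-cli download black-forest-labs/FLUX.1-dev --local-dir /your/path/to/FLUX.1-dev

# Place the LayerDiffuse-Flux weights where the demo scripts look for them by default:
mkdir -p models
# models/TransparentVAE.pth
# models/layerlora.safetensors
```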

## Flux Transparent T2I
### Demo
```shell
python demo_t2i.py --ckpt_path /your/path/to/FLUX.1-dev
```
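
All generation settings are exposed as flags in `demo_t2i.py`; for reference, here is a fuller invocation with the script's own defaults spelled out explicitly (the checkpoint path is a placeholder):
```shell
python demo_t2i.py \
  --ckpt_path /your/path/to/FLUX.1-dev \
  --trans_vae ./models/TransparentVAE.pth \
  --lora_weights ./models/layerlora.safetensors \
  --prompt "glass bottle, high quality" \
  --width 1024 --height 1024 \
  --steps 50 --guidance 3.5 --seed 11111 \
  --output_dir ./flux-layer-outputs
```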
### Examples

| Examples, top to bottom: Flux, Flux-layer (ours), SDXL, SDXL-layer ([LayerDiffuse](https://github.com/lllyasviel/LayerDiffuse)) |
|------------------------------------|
|![](./imgs/flux_layer_t2i_examples/test0_10_vis.jpg)|
|![](./imgs/flux_layer_t2i_examples/test10_20_vis.jpg)|
|![](./imgs/flux_layer_t2i_examples/test20_30_vis.jpg)|
|![](./imgs/flux_layer_t2i_examples/test30_40_vis.jpg)|
|![](./imgs/flux_layer_t2i_examples/test40_50_vis.jpg)|



## Flux Transparent I2I
```shell
python demo_i2i.py --ckpt_path /your/path/to/FLUX.1-dev --image "./imgs/causal_cut.png"
```
Prompt: "a handsome man with curly hair, high quality"

Strength: 0.9

| Input (Transparent image) | Output (Transparent image) |
|------------------------------------|--------------------------------------------|
| ![img](imgs/causal_cut_vis.png) | ![img](imgs/causal_cut_output_vis.png) |
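
For reference, the invocation matching this example, with the prompt and strength above passed explicitly (remaining values are the script's defaults):
```shell
python demo_i2i.py \
  --ckpt_path /your/path/to/FLUX.1-dev \
  --image ./imgs/causal_cut.png \
  --prompt "a handsome man with curly hair, high quality" \
  --strength 0.9 \
  --steps 50 --guidance 7.0 --seed 43
```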

## Acknowledgements
Thanks to lllyasviel for the great work on [LayerDiffuse](https://github.com/lllyasviel/LayerDiffuse).

## Contact
If you have any questions about the code, please do not hesitate to contact us!

Email: [email protected], [email protected]
89 changes: 89 additions & 0 deletions demo_i2i.py
@@ -0,0 +1,89 @@
import torch
import argparse
import os
import datetime
from pipeline_flux_img2img import FluxImg2ImgPipeline
from lib_layerdiffuse.vae import TransparentVAE, pad_rgb
from PIL import Image
import numpy as np
from torchvision import transforms
from safetensors.torch import load_file


def generate_img(pipe, trans_vae, args):
    # Load the RGBA input and build the padded-RGB tensor the transparent VAE expects
    original_image = (transforms.ToTensor()(Image.open(args.image))).unsqueeze(0)
    padding_feed = [x for x in original_image.movedim(1, -1).float().cpu().numpy()]
    list_of_np_rgb_padded = [pad_rgb(x) for x in padding_feed]
    rgb_padded_bchw_01 = torch.from_numpy(np.stack(list_of_np_rgb_padded, axis=0)).float().movedim(-1, 1).to(original_image.device)
    original_image_feed = original_image.clone()
    original_image_feed[:, :3, :, :] = original_image_feed[:, :3, :, :] * 2.0 - 1.0
    original_image_rgb = original_image_feed[:, :3, :, :] * original_image_feed[:, 3, :, :]

    original_image_feed = original_image_feed.to("cuda")
    original_image_rgb = original_image_rgb.to("cuda")
    rgb_padded_bchw_01 = rgb_padded_bchw_01.to("cuda")
    trans_vae.to(torch.device('cuda'))
    rng = torch.Generator("cuda").manual_seed(args.seed)

    # Encode the transparent image into the latent that seeds img2img
    initial_latent = trans_vae.encode(original_image_feed, original_image_rgb, rgb_padded_bchw_01, use_offset=True)

    latents = pipe(
        latents=initial_latent,
        image=original_image,
        prompt=args.prompt,
        height=args.height,
        width=args.width,
        num_inference_steps=args.steps,
        output_type="latent",
        generator=rng,
        guidance_scale=args.guidance,
        strength=args.strength,
    ).images

    # Undo Flux's latent packing and scaling before decoding with the transparent VAE
    latents = pipe._unpack_latents(latents, args.height, args.width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor

    with torch.no_grad():
        original_x, x = trans_vae.decode(latents)

    # Convert the decoded RGBA tensor to a PIL image
    x = x.clamp(0, 1)
    x = x.permute(0, 2, 3, 1)
    img = Image.fromarray((x * 255).float().cpu().numpy().astype(np.uint8)[0])
    return img

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--trans_vae", type=str, default="./models/TransparentVAE.pth")
    parser.add_argument("--output_dir", type=str, default="./flux-layer-outputs")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
    parser.add_argument("--seed", type=int, default=43)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--guidance", type=float, default=7.0)
    parser.add_argument("--strength", type=float, default=0.8)
    parser.add_argument("--prompt", type=str, default="a handsome man with curly hair, high quality")
    parser.add_argument(
        "--lora_weights",
        type=str,
        default="./models/layerlora.safetensors",
    )
    parser.add_argument("--width", type=int, default=1024)
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--image", type=str, default="./imgs/causal_cut.png")
    args = parser.parse_args()

    pipe = FluxImg2ImgPipeline.from_pretrained(args.ckpt_path, torch_dtype=torch.bfloat16).to('cuda')
    pipe.load_lora_weights(args.lora_weights)

    trans_vae = TransparentVAE(pipe.vae, pipe.vae.dtype)
    trans_vae.load_state_dict(torch.load(args.trans_vae), strict=False)

    print("all loaded")

    img = generate_img(pipe, trans_vae, args)

    # save image
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
    img.save(output_path)
68 changes: 68 additions & 0 deletions demo_t2i.py
@@ -0,0 +1,68 @@
import torch
import argparse
import os
import datetime
from diffusers import FluxPipeline
from lib_layerdiffuse.vae import TransparentVAE
from PIL import Image
import numpy as np

def generate_img(pipe, trans_vae, args):

    latents = pipe(
        prompt=args.prompt,
        height=args.height,
        width=args.width,
        num_inference_steps=args.steps,
        output_type="latent",
        generator=torch.Generator("cuda").manual_seed(args.seed),
        guidance_scale=args.guidance,
    ).images

    # Undo Flux's latent packing and scaling before decoding with the transparent VAE
    latents = pipe._unpack_latents(latents, args.height, args.width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor

    with torch.no_grad():
        original_x, x = trans_vae.decode(latents)

    # Convert the decoded RGBA tensor to a PIL image
    x = x.clamp(0, 1)
    x = x.permute(0, 2, 3, 1)
    img = Image.fromarray((x * 255).float().cpu().numpy().astype(np.uint8)[0])

    return img

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--trans_vae", type=str, default="./models/TransparentVAE.pth")
    parser.add_argument("--output_dir", type=str, default="./flux-layer-outputs")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
    parser.add_argument("--seed", type=int, default=11111)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--guidance", type=float, default=3.5)
    parser.add_argument("--prompt", type=str, default="glass bottle, high quality")
    parser.add_argument(
        "--lora_weights",
        type=str,
        default="./models/layerlora.safetensors",
    )
    parser.add_argument("--width", type=int, default=1024)
    parser.add_argument("--height", type=int, default=1024)
    args = parser.parse_args()

    pipe = FluxPipeline.from_pretrained(args.ckpt_path, torch_dtype=torch.bfloat16).to('cuda')
    pipe.load_lora_weights(args.lora_weights)

    trans_vae = TransparentVAE(pipe.vae, pipe.vae.dtype)
    trans_vae.load_state_dict(torch.load(args.trans_vae), strict=False)
    trans_vae.to('cuda')

    print("all loaded")

    img = generate_img(pipe, trans_vae, args)

    # save image
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
    img.save(output_path)
Binary file added imgs/causal_cut.png
Binary file added imgs/causal_cut_output.png
Binary file added imgs/causal_cut_output_vis.png
Binary file added imgs/causal_cut_vis.png
Binary file added imgs/flux_layer_t2i_examples/test0_10.png
Binary file added imgs/flux_layer_t2i_examples/test0_10_vis.jpg
Binary file added imgs/flux_layer_t2i_examples/test10_20.png
Binary file added imgs/flux_layer_t2i_examples/test10_20_vis.jpg
Binary file added imgs/flux_layer_t2i_examples/test20_30.png
Binary file added imgs/flux_layer_t2i_examples/test20_30_vis.jpg
Binary file added imgs/flux_layer_t2i_examples/test30_40.png
Binary file added imgs/flux_layer_t2i_examples/test30_40_vis.jpg
Binary file added imgs/flux_layer_t2i_examples/test40_50.png
Binary file added imgs/flux_layer_t2i_examples/test40_50_vis.jpg
Binary file added imgs/top_examples/boat.png
Binary file added imgs/top_examples/boat_vis.png
Binary file added imgs/top_examples/bottle.png
Binary file added imgs/top_examples/bottle_vis.png
Binary file added imgs/top_examples/cat_doctor.png
Binary file added imgs/top_examples/cat_doctor_vis.png
Binary file added imgs/top_examples/dragonball.png
Binary file added imgs/top_examples/dragonball_vis.png
Binary file added imgs/top_examples/dress.png
Binary file added imgs/top_examples/dress_vis.png
Binary file added imgs/top_examples/half.png
Binary file added imgs/top_examples/half_vis.png
Binary file added imgs/top_examples/stickers.png
Binary file added imgs/top_examples/stickers_vis.png
