-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 07cd568
Showing
36 changed files
with
1,870 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
imgs filter=lfs diff=lfs merge=lfs -text |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
__pycache__ | ||
models | ||
flux-layer-outputs | ||
*captions.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
|
||
# LayerDiffuse-Flux | ||
This repo is a Flux version implementation of LayerDiffuse ([LayerDiffuse](https://github.com/lllyasviel/LayerDiffuse)). | ||
|
||
We train a **new transparent vae** to adapt to Flux and train a **lora** to finetune Flux to generate transparent images. | ||
|
||
|![](./imgs/top_examples/bottle_vis.png)|![](./imgs/top_examples/boat_vis.png)|![](./imgs/top_examples/cat_doctor_vis.png)|![](./imgs/top_examples/dragonball_vis.png)|![](./imgs/top_examples/dress_vis.png)|![](./imgs/top_examples/half_vis.png)|![](./imgs/top_examples/stickers_vis.png)| | ||
|:-:|:-:|:-:|:-:|:-:|:-:|:-:| | ||
## Usage | ||
|
||
+ Clone this repository. | ||
```shell | ||
git clone ... | ||
cd LayerDiffuse-Flux | ||
``` | ||
+ download weights | ||
``` shell | ||
|
||
``` | ||
|
||
## Flux Transparent T2I | ||
### demo | ||
```shell | ||
python demo_t2i.py --ckpt_path /your/path/to/FLUX.1 dev | ||
``` | ||
### examples | ||
|
||
| Examples: top to bottom: flux, flux-layer(ours), sdxl, sdxl-layer([LayerDiffuse](https://github.com/lllyasviel/LayerDiffuse)) | | ||
|------------------------------------| | ||
|![](./imgs/flux_layer_t2i_examples/test0_10_vis.jpg)| | ||
|![](./imgs/flux_layer_t2i_examples/test10_20_vis.jpg)| | ||
|![](./imgs/flux_layer_t2i_examples/test20_30_vis.jpg)| | ||
|![](./imgs/flux_layer_t2i_examples/test30_40_vis.jpg)| | ||
|![](./imgs/flux_layer_t2i_examples/test40_50_vis.jpg)| | ||
|
||
|
||
|
||
## Flux Transparent I2I | ||
```shell | ||
python demo_i2i.py --ckpt_path /your/path/to/FLUX.1 dev --image "./imgs/causal_cut.png" | ||
``` | ||
Prompt: "a handsome man with curly hair, high quality" | ||
|
||
Strength: 0.9 | ||
|
||
| Input (Transparent image) | Output (Transparent image) | | ||
|------------------------------------|--------------------------------------------| | ||
| ![img](imgs/causal_cut_vis.png) | ![img](imgs/causal_cut_output_vis.png) | | ||
|
||
## Acknowledgements | ||
Thanks lllyasviel for their great work [LayerDiffuse](https://github.com/lllyasviel/LayerDiffuse) | ||
|
||
## Contact | ||
If you have any questions about the code, please do not hesitate to contact us! | ||
|
||
Email: [email protected], [email protected] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import torch | ||
import argparse | ||
import os | ||
import datetime | ||
from pipeline_flux_img2img import FluxImg2ImgPipeline | ||
from lib_layerdiffuse.vae import TransparentVAE, pad_rgb | ||
from PIL import Image | ||
import numpy as np | ||
from torchvision import transforms | ||
from safetensors.torch import load_file | ||
from PIL import Image, ImageDraw, ImageFont | ||
|
||
|
||
def generate_img(pipe, trans_vae, args): | ||
original_image = (transforms.ToTensor()(Image.open(args.image))).unsqueeze(0) | ||
padding_feed = [x for x in original_image.movedim(1, -1).float().cpu().numpy()] | ||
list_of_np_rgb_padded = [pad_rgb(x) for x in padding_feed] | ||
rgb_padded_bchw_01 = torch.from_numpy(np.stack(list_of_np_rgb_padded, axis=0)).float().movedim(-1, 1).to(original_image.device) | ||
original_image_feed = original_image.clone() | ||
original_image_feed[:, :3, :, :] = original_image_feed[:, :3, :, :] * 2.0 - 1.0 | ||
original_image_rgb = original_image_feed[:, :3, :, :] * original_image_feed[:, 3, :, :] | ||
|
||
original_image_feed = original_image_feed.to("cuda") | ||
original_image_rgb = original_image_rgb.to("cuda") | ||
rgb_padded_bchw_01 = rgb_padded_bchw_01.to("cuda") | ||
trans_vae.to(torch.device('cuda')) | ||
rng = torch.Generator("cuda").manual_seed(args.seed) | ||
|
||
initial_latent = trans_vae.encode(original_image_feed, original_image_rgb, rgb_padded_bchw_01, use_offset=True) | ||
|
||
latents = pipe( | ||
latents=initial_latent, | ||
image=original_image, | ||
prompt=args.prompt, | ||
height=args.height, | ||
width=args.width, | ||
num_inference_steps=args.steps, | ||
output_type="latent", | ||
generator=rng, | ||
guidance_scale=args.guidance, | ||
strength=args.strength, | ||
).images | ||
|
||
latents = pipe._unpack_latents(latents, args.height, args.width, pipe.vae_scale_factor) | ||
latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor | ||
|
||
with torch.no_grad(): | ||
original_x, x = trans_vae.decode(latents) | ||
|
||
x = x.clamp(0, 1) | ||
x = x.permute(0, 2, 3, 1) | ||
img = Image.fromarray((x*255).float().cpu().numpy().astype(np.uint8)[0]) | ||
return img | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--ckpt_path", type=str, required=True) | ||
parser.add_argument("--trans_vae", type=str, default="./models/TransparentVAE.pth") | ||
parser.add_argument("--output_dir", type=str, default="./flux-layer-outputs") | ||
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype") | ||
parser.add_argument("--seed", type=int, default=43) | ||
parser.add_argument("--steps", type=int, default=50) | ||
parser.add_argument("--guidance", type=float, default=7.0) | ||
parser.add_argument("--strength", type=float, default=0.8) | ||
parser.add_argument("--prompt", type=str, default="a handsome man with curly hair, high quality") | ||
parser.add_argument( | ||
"--lora_weights", | ||
type=str, | ||
default="./models/layerlora.safetensors", | ||
) | ||
parser.add_argument("--width", type=int, default=1024) | ||
parser.add_argument("--height", type=int, default=1024) | ||
parser.add_argument("--image", type=str, default="./imgs/causal_cut.png") | ||
args = parser.parse_args() | ||
|
||
pipe = FluxImg2ImgPipeline.from_pretrained(args.ckpt_path, torch_dtype=torch.bfloat16).to('cuda') | ||
pipe.load_lora_weights(args.lora_weights) | ||
|
||
trans_vae = TransparentVAE(pipe.vae, pipe.vae.dtype) | ||
trans_vae.load_state_dict(torch.load(args.trans_vae), strict=False) | ||
|
||
print("all loaded") | ||
|
||
img = generate_img(pipe, trans_vae, args) | ||
|
||
# save image | ||
os.makedirs(args.output_dir, exist_ok=True) | ||
output_path = os.path.join(args.output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") | ||
img.save(output_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import torch | ||
import argparse | ||
import os | ||
import datetime | ||
from diffusers import FluxPipeline | ||
from lib_layerdiffuse.vae import TransparentVAE | ||
from PIL import Image | ||
import numpy as np | ||
|
||
def generate_img(pipe, trans_vae, args): | ||
|
||
latents = pipe( | ||
prompt=args.prompt, | ||
height=args.height, | ||
width=args.width, | ||
num_inference_steps=args.steps, | ||
output_type="latent", | ||
generator=torch.Generator("cuda").manual_seed(args.seed), | ||
guidance_scale=args.guidance, | ||
|
||
).images | ||
|
||
latents = pipe._unpack_latents(latents, args.height, args.width, pipe.vae_scale_factor) | ||
latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor | ||
|
||
with torch.no_grad(): | ||
original_x, x = trans_vae.decode(latents) | ||
|
||
x = x.clamp(0, 1) | ||
x = x.permute(0, 2, 3, 1) | ||
img = Image.fromarray((x*255).float().cpu().numpy().astype(np.uint8)[0]) | ||
|
||
return img | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--ckpt_path", type=str, required=True) | ||
parser.add_argument("--trans_vae", type=str, default="./models/TransparentVAE.pth") | ||
parser.add_argument("--output_dir", type=str, default="./flux-layer-outputs") | ||
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype") | ||
parser.add_argument("--seed", type=int, default=11111) | ||
parser.add_argument("--steps", type=int, default=50) | ||
parser.add_argument("--guidance", type=float, default=3.5) | ||
parser.add_argument("--prompt", type=str, default="glass bottle, high quality") | ||
parser.add_argument( | ||
"--lora_weights", | ||
type=str, | ||
default="./models/layerlora.safetensors", | ||
) | ||
parser.add_argument("--width", type=int, default=1024) | ||
parser.add_argument("--height", type=int, default=1024) | ||
args = parser.parse_args() | ||
|
||
pipe = FluxPipeline.from_pretrained(args.ckpt_path, torch_dtype=torch.bfloat16).to('cuda') | ||
pipe.load_lora_weights(args.lora_weights) | ||
|
||
trans_vae = TransparentVAE(pipe.vae, pipe.vae.dtype) | ||
trans_vae.load_state_dict(torch.load(args.trans_vae), strict=False) | ||
trans_vae.to('cuda') | ||
|
||
print("all loaded") | ||
|
||
img = generate_img(pipe, trans_vae, args) | ||
|
||
# save image | ||
os.makedirs(args.output_dir, exist_ok=True) | ||
output_path = os.path.join(args.output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") | ||
img.save(output_path) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.