generated from SauravMaheshkar/python-template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
947 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,12 @@ | ||
# FL(U/A)X | ||
|
||
JAX Implementation of Black Forest Labs' Flux.1 family of models | ||
|
||
|
||
## Installation | ||
|
||
```shell | ||
$ python -m venv .venv | ||
$ pip install -U "jax[cuda12]" | ||
$ pip install flax | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,3 +12,4 @@ models/ | |
notebooks/ | ||
|
||
**/.DS_Store | ||
.* |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from einops import rearrange
from fire import Fire
from PIL import ExifTags, Image
from transformers import pipeline

from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from jflux.util import (configs, embed_watermark, load_ae, load_clip,
                        load_flow_model, load_t5)
|
||
# Images whose "nsfw" classifier score reaches this value are not written to
# disk; `main` only saves an image when its score is strictly below it.
NSFW_THRESHOLD = 0.85
|
||
@dataclass
class SamplingOptions:
    """Mutable settings for one sampling run; `parse_prompt` edits them in place."""

    prompt: str  # text prompt used for the next image
    width: int  # image width in pixels
    height: int  # image height in pixels
    num_steps: int  # number of sampling steps
    guidance: float  # guidance value handed to the denoiser
    seed: int | None  # None means `main` draws a fresh seed before sampling
|
||
|
||
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    """Read the next prompt from stdin, applying slash commands along the way.

    Lines starting with ``/`` are treated as commands that update ``options``
    in place (``/w``, ``/h``, ``/g``, ``/s``, ``/n``), print the help text
    (a bare ``/h``) or quit (``/q``). The first line that does not start with
    a slash ends the loop: a non-empty line replaces ``options.prompt``, an
    empty line keeps the previous prompt.

    Args:
        options: Current sampling options; mutated in place.

    Returns:
        The same ``options`` object, or ``None`` when the user asked to quit.
    """
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )
    value_commands = ("/w", "/h", "/g", "/s", "/n")

    while (prompt := input(user_question)).startswith("/"):
        command = prompt[:2]
        invalid = f"Got invalid command '{prompt}'\n{usage}"

        if command == "/q":
            print("Quitting")
            return None
        if prompt.strip() == "/h":
            # A bare "/h" is the help request advertised in the question;
            # "/h <height>" is handled with the other value commands below.
            print(usage)
            continue
        if command not in value_commands or prompt.count(" ") != 1:
            print(invalid)
            continue

        try:
            # Unpacking fails for e.g. "/w " and the conversions fail for
            # non-numeric values; both are user typos, not fatal errors.
            _, value = prompt.split()
            if command == "/w":
                options.width = 16 * (int(value) // 16)
            elif command == "/h":
                options.height = 16 * (int(value) // 16)
            elif command == "/g":
                options.guidance = float(value)
            elif command == "/s":
                options.seed = int(value)
            else:
                options.num_steps = int(value)
        except ValueError:
            print(invalid)
            continue

        if command in ("/w", "/h"):
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height *options.width/1e6:.2f}MP)"
            )
        elif command == "/g":
            print(f"Setting guidance to {options.guidance}")
        elif command == "/s":
            print(f"Setting seed to {options.seed}")
        else:
            print(f"Setting number of steps to {options.num_steps}")

    if prompt != "":
        options.prompt = prompt
    return options
|
||
|
||
@torch.inference_mode()
def main(
    name: str = "flux-schnell",
    width: int = 1360,
    height: int = 768,
    seed: int | None = None,
    prompt: str = (
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    ),
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    guidance: float = 3.5,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image.

    Images whose NSFW classifier score reaches `NSFW_THRESHOLD` are not saved.

    Args:
        name: Name of the model to load
        height: height of the sample in pixels (rounded down to a multiple of 16)
        width: width of the sample in pixels (rounded down to a multiple of 16)
        seed: Set a seed for sampling
        prompt: Prompt used for sampling
        device: Pytorch device
        num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation
        offload: load the flow model and autoencoder on the CPU and move each
            component to `device` only while it is in use
        output_dir: directory where images are saved as `img_{idx}.jpg`;
            numbering continues after the highest index already present
        add_sampling_metadata: Add the prompt to the image Exif metadata

    Raises:
        ValueError: if `name` is not a key of `configs`.
    """
    # Validate the cheap argument first so a typo fails before any model loads.
    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, chose from {available}")

    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

    torch_device = torch.device(device)
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 50

    # allow for packing and conversion to latent space
    height = 16 * (height // 16)
    width = 16 * (width // 16)

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    os.makedirs(output_dir, exist_ok=True)
    # Continue numbering after the highest existing img_<n>.jpg (0 if none).
    fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
    idx = max((int(fn.split("_")[-1].split(".")[0]) for fn in fns), default=-1) + 1

    # init all components
    t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
    clip = load_clip(torch_device)
    model = load_flow_model(name, device="cpu" if offload else torch_device)
    ae = load_ae(name, device="cpu" if offload else torch_device)

    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        opts = parse_prompt(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        # Reset so the next iteration draws a new seed unless the user sets one.
        opts.seed = None
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)
        inp = prepare(t5, clip, x, prompt=opts.prompt)
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)
        t1 = time.perf_counter()

        fn = output_name.format(idx=idx)
        print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
        # bring into PIL format and save
        x = x.clamp(-1, 1)
        x = embed_watermark(x.float())
        x = rearrange(x[0], "c h w -> h w c")

        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
        nsfw_score = [result["score"] for result in nsfw_classifier(img) if result["label"] == "nsfw"][0]

        if nsfw_score < NSFW_THRESHOLD:
            exif_data = Image.Exif()
            exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
            exif_data[ExifTags.Base.Make] = "Black Forest Labs"
            exif_data[ExifTags.Base.Model] = name
            if add_sampling_metadata:
                # Record the prompt actually used for this image; in loop mode
                # it may differ from the initial `prompt` argument.
                exif_data[ExifTags.Base.ImageDescription] = opts.prompt
            img.save(fn, exif=exif_data, quality=95, subsampling=0)
            idx += 1
        else:
            print("Your generated image may contain NSFW content.")

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
        else:
            opts = None
|
||
|
||
def app():
    """Console entry point: hand `main` to `Fire`."""
    Fire(main)
|
||
|
||
# Start the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    app()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from jax import Array | ||
from jax import numpy as jnp | ||
from flax import nnx | ||
from einops import rearrange | ||
|
||
|
||
def attention(q: Array, k: Array, v: Array, pe: Array) -> Array:
    """Apply rotary embeddings to ``q``/``k`` and run dot-product attention.

    Args:
        q, k, v: Query, key and value arrays laid out as
            (batch, heads, length, head_dim).
        pe: Rotary position encoding forwarded to ``apply_rope``.

    Returns:
        Attention output with heads merged: (batch, length, heads * head_dim).
    """
    q, k = apply_rope(q, k, pe)

    # flax's dot_product_attention expects (batch, length, heads, head_dim);
    # feeding it (batch, heads, length, head_dim) would attend across heads
    # instead of across sequence positions.
    q, k, v = (rearrange(t, "B H L D -> B L H D") for t in (q, k, v))
    x = nnx.dot_product_attention(q, k, v)

    return rearrange(x, "B L H D -> B L (H D)")
|
||
|
||
def rope(pos: Array, dim: int, theta: int) -> Array:
    """Build 2x2 rotation matrices for rotary position embeddings.

    Args:
        pos: Positions; the last axis indexes sequence positions.
        dim: Embedding size to cover; must be even (one rotation per pair).
        theta: Base of the geometric frequency progression.

    Returns:
        float32 array of shape ``(*pos.shape, dim // 2, 2, 2)`` holding
        ``[[cos, -sin], [sin, cos]]`` for every position/frequency pair.
    """
    assert dim % 2 == 0
    # Integer arange + true division yields JAX's default float dtype
    # (float64 only when x64 is enabled), avoiding the truncation warning an
    # explicit float64 request triggers. Device placement is left to JAX.
    scale = jnp.arange(0, dim, 2) / dim
    omega = 1.0 / (theta**scale)
    out = jnp.einsum("...n,d->...nd", pos, omega)
    # jnp.stack takes `axis`; `dim` is the torch spelling.
    out = jnp.stack([jnp.cos(out), -jnp.sin(out), jnp.sin(out), jnp.cos(out)], axis=-1)
    # Split the trailing 4 entries into a 2x2 matrix (row-major).
    out = out.reshape(*out.shape[:-1], 2, 2)
    return out.astype(jnp.float32)
|
||
|
||
def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
    """Rotate query and key features with precomputed 2x2 rotation matrices.

    The last axis of ``xq``/``xk`` is viewed as consecutive pairs and each
    pair is multiplied by the matching 2x2 matrix in ``freqs_cis``.

    Args:
        xq: Query array; the last axis must have even length.
        xk: Key array; the last axis must have even length.
        freqs_cis: Rotation matrices whose trailing axes are ``(pairs, 2, 2)``
            and whose leading axes broadcast against ``xq``/``xk``.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the input shapes and dtypes.
    """
    # Compute in float32, then cast back; JAX arrays have no torch-style
    # `.float()` / `.type_as()` methods.
    xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return (
        xq_out.reshape(*xq.shape).astype(xq.dtype),
        xk_out.reshape(*xk.shape).astype(xk.dtype),
    )
Oops, something went wrong.