Skip to content

Commit

Permalink
chore: inital commit for porting
Browse files Browse the repository at this point in the history
  • Loading branch information
ariG23498 committed Sep 2, 2024
1 parent 7b744dd commit 89aaaeb
Show file tree
Hide file tree
Showing 8 changed files with 947 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .github/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# FL(U/A)X

JAX Implementation of Black Forest Labs' Flux.1 family of models


## Installation

```shell
$ python venv .venv
$ pip install -U "jax[cuda12]"
$ pip install flax
```
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ models/
notebooks/

**/.DS_Store
.*
File renamed without changes.
253 changes: 253 additions & 0 deletions src/jflux/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

from einops import rearrange
from fire import Fire
from PIL import ExifTags, Image

from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from jflux.util import (configs, embed_watermark, load_ae, load_clip,
load_flow_model, load_t5)
from transformers import pipeline

NSFW_THRESHOLD = 0.85

@dataclass
class SamplingOptions:
prompt: str
width: int
height: int
num_steps: int
guidance: float
seed: int | None


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the prompt or write a command starting with a slash:\n"
"- '/w <width>' will set the width of the generated image\n"
"- '/h <height>' will set the height of the generated image\n"
"- '/s <seed>' sets the next seed\n"
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
"- '/n <steps>' sets the number of steps\n"
"- '/q' to quit"
)

while (prompt := input(user_question)).startswith("/"):
if prompt.startswith("/w"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, width = prompt.split()
options.width = 16 * (int(width) // 16)
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height *options.width/1e6:.2f}MP)"
)
elif prompt.startswith("/h"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, height = prompt.split()
options.height = 16 * (int(height) // 16)
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height *options.width/1e6:.2f}MP)"
)
elif prompt.startswith("/g"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, guidance = prompt.split()
options.guidance = float(guidance)
print(f"Setting guidance to {options.guidance}")
elif prompt.startswith("/s"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, seed = prompt.split()
options.seed = int(seed)
print(f"Setting seed to {options.seed}")
elif prompt.startswith("/n"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, steps = prompt.split()
options.num_steps = int(steps)
print(f"Setting seed to {options.num_steps}")
elif prompt.startswith("/q"):
print("Quitting")
return None
else:
if not prompt.startswith("/h"):
print(f"Got invalid command '{prompt}'\n{usage}")
print(usage)
if prompt != "":
options.prompt = prompt
return options


@torch.inference_mode()
def main(
name: str = "flux-schnell",
width: int = 1360,
height: int = 768,
seed: int | None = None,
prompt: str = (
"a photo of a forest with mist swirling around the tree trunks. The word "
'"FLUX" is painted over it in big, red brush strokes with visible texture'
),
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int | None = None,
loop: bool = False,
guidance: float = 3.5,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
single image.
Args:
name: Name of the model to load
height: height of the sample in pixels (should be a multiple of 16)
width: width of the sample in pixels (should be a multiple of 16)
seed: Set a seed for sampling
output_name: where to save the output image, `{idx}` will be replaced
by the index of the sample
prompt: Prompt used for sampling
device: Pytorch device
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

if name not in configs:
available = ", ".join(configs.keys())
raise ValueError(f"Got unknown model name: {name}, chose from {available}")

torch_device = torch.device(device)
if num_steps is None:
num_steps = 4 if name == "flux-schnell" else 50

# allow for packing and conversion to latent space
height = 16 * (height // 16)
width = 16 * (width // 16)

output_name = os.path.join(output_dir, "img_{idx}.jpg")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
idx = 0
else:
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
if len(fns) > 0:
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
else:
idx = 0

# init all components
t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
clip = load_clip(torch_device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)

rng = torch.Generator(device="cpu")
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
)

if loop:
opts = parse_prompt(opts)

while opts is not None:
if opts.seed is None:
opts.seed = rng.seed()
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
t0 = time.perf_counter()

# prepare input
x = get_noise(
1,
opts.height,
opts.width,
device=torch_device,
dtype=torch.bfloat16,
seed=opts.seed,
)
opts.seed = None
if offload:
ae = ae.cpu()
torch.cuda.empty_cache()
t5, clip = t5.to(torch_device), clip.to(torch_device)
inp = prepare(t5, clip, x, prompt=opts.prompt)
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

# offload TEs to CPU, load model to gpu
if offload:
t5, clip = t5.cpu(), clip.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)

# denoise initial noise
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

# offload model, load autoencoder to gpu
if offload:
model.cpu()
torch.cuda.empty_cache()
ae.decoder.to(x.device)

# decode latents to pixel space
x = unpack(x.float(), opts.height, opts.width)
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
x = ae.decode(x)
t1 = time.perf_counter()

fn = output_name.format(idx=idx)
print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
# bring into PIL format and save
x = x.clamp(-1, 1)
x = embed_watermark(x.float())
x = rearrange(x[0], "c h w -> h w c")

img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]

if nsfw_score < NSFW_THRESHOLD:
exif_data = Image.Exif()
exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
exif_data[ExifTags.Base.Model] = name
if add_sampling_metadata:
exif_data[ExifTags.Base.ImageDescription] = prompt
img.save(fn, exif=exif_data, quality=95, subsampling=0)
idx += 1
else:
print("Your generated image may contain NSFW content.")

if loop:
print("-" * 80)
opts = parse_prompt(opts)
else:
opts = None


def app():
Fire(main)


if __name__ == "__main__":
app()
31 changes: 31 additions & 0 deletions src/jflux/math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from jax import Array
from jax import numpy as jnp
from flax import nnx
from einops import rearrange


def attention(q: Array, k: Array, v: Array, pe: Array) -> Array:
q, k = apply_rope(q, k, pe)

x = nnx.dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")

return x


def rope(pos: Array, dim: int, theta: int) -> Array:
assert dim % 2 == 0
scale = jnp.arange(0, dim, 2, dtype=jnp.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = jnp.einsum("...n,d->...nd", pos, omega)
out = jnp.stack([jnp.cos(out), -jnp.sin(out), jnp.sin(out), jnp.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()


def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
Loading

0 comments on commit 89aaaeb

Please sign in to comment.