Adding FLUX porting code (#11)
Co-authored-by: Saurav Maheshkar <sauravvmaheshkar@gmail.com>
ariG23498 and SauravMaheshkar authored Oct 9, 2024
1 parent 9162e4d · commit 0c9fb04
Showing 8 changed files with 400 additions and 146 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
@@ -1 +1 @@
-* @SauravMaheshkar
+* @SauravMaheshkar @ariG23498
5 changes: 2 additions & 3 deletions jflux/cli.py
@@ -6,13 +6,12 @@
 
 import jax
 import jax.numpy as jnp
-from flax import nnx
+from einops import rearrange
 from fire import Fire
+from flax import nnx
 from jax.typing import DTypeLike
-
 from PIL import Image
 
-from einops import rearrange
 from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
 from jflux.util import configs, load_ae, load_clip, load_flow_model, load_t5
 
4 changes: 2 additions & 2 deletions jflux/modules/conditioner.py
@@ -1,7 +1,7 @@
 # Note: This is a torch module not a Jax module
-from torch import nn
-from chex import Array
 import jax.numpy as jnp
+from chex import Array
+from torch import nn
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
 
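The note kept at the top of this file matters for the port's design: the CLIP and T5 text encoders remain torch modules loaded through transformers, so their outputs must cross the torch-to-JAX boundary before the JAX model can consume them, which is why the file imports both torch (via nn) and jax.numpy. Below is a minimal sketch of that hand-off; the encode helper and the use of the CLIP pooled output are illustrative assumptions, not the file's actual API.

# Sketch only: torch -> JAX hand-off implied by conditioner.py importing
# both `torch` and `jax.numpy`. The `encode` helper below is hypothetical;
# the module's real class and method names are not shown in this diff.
import jax.numpy as jnp
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

def encode(prompt: str) -> jnp.ndarray:
    tokens = tokenizer(
        prompt, padding="max_length", truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        out = encoder(**tokens)
    # Cross the framework boundary: torch.Tensor -> numpy -> jnp.ndarray
    return jnp.asarray(out.pooler_output.detach().cpu().numpy())
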
6 changes: 6 additions & 0 deletions jflux/modules/layers.py
@@ -214,6 +214,7 @@ def __init__(
         self.img_norm1 = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
@@ -229,6 +230,7 @@ def __init__(
         self.img_norm2 = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
@@ -257,6 +259,7 @@ def __init__(
         self.txt_norm1 = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
@@ -272,6 +275,7 @@ def __init__(
         self.txt_norm2 = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
@@ -382,6 +386,7 @@ def __init__(
         self.pre_norm = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
@@ -419,6 +424,7 @@ def __init__(
         self.norm_final = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
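
The pattern is the same in all six hunks: every nnx.LayerNorm in the ported blocks now passes use_bias=False alongside use_scale=False. Flax defaults both flags to True, whereas the reference FLUX torch code builds these norms with nn.LayerNorm(..., elementwise_affine=False), which has no learnable scale or bias at all; without use_bias=False the JAX port would carry an extra bias parameter that has no counterpart in the reference checkpoints. A minimal sketch of the equivalence (hidden_size=3072 is illustrative, not taken from the diff):

# Sketch only: why `use_bias=False` matters for the port. The reference
# torch norms use elementwise_affine=False, i.e. no learnable scale (gamma)
# and no learnable bias (beta), so both must be disabled explicitly in flax.
import jax.numpy as jnp
from flax import nnx

hidden_size = 3072  # illustrative; not read from this diff

norm = nnx.LayerNorm(
    num_features=hidden_size,
    use_scale=False,  # drop learnable gamma
    use_bias=False,   # drop learnable beta -- the line this commit adds
    epsilon=1e-6,
    rngs=nnx.Rngs(0),
)

x = jnp.ones((1, 4, hidden_size))
y = norm(x)  # pure normalization: zero mean, unit variance over features

With both flags off the module holds no parameters, so porting the torch weights requires no mapping for these norm layers.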
