Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initial nnx port #3

Merged
merged 34 commits into from
Sep 29, 2024
Merged
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
89aaaeb
chore: inital commit for porting
ariG23498 Sep 2, 2024
e5c0c2d
adding torch code from original repository
ariG23498 Sep 2, 2024
14398a3
dot pdt attn in jnp
ariG23498 Sep 3, 2024
1a00d0c
more changes to init
ariG23498 Sep 3, 2024
3dfd926
start porting
ariG23498 Sep 3, 2024
3bcd923
adding everything to uv
ariG23498 Sep 4, 2024
3e09e36
adding ruff and formating
ariG23498 Sep 4, 2024
92e2c5b
fix(modules/autoencoder): correct invocation of GroupNorm and Conv
SauravMaheshkar Sep 8, 2024
d894b3a
fix(modules/layers): correct invocation of nnx modules
SauravMaheshkar Sep 8, 2024
0414a02
fix: jaxify codebase
SauravMaheshkar Sep 13, 2024
38e97b2
chore: allow for non GPU installation
SauravMaheshkar Sep 13, 2024
363f74e
feat(ci): use uv in CI
SauravMaheshkar Sep 13, 2024
4a3948e
fix: tmp mark test as xfail
SauravMaheshkar Sep 13, 2024
29ad2a0
fix: use Sequential for the middle blocks
SauravMaheshkar Sep 13, 2024
b4b3c37
chore: drop devcontainer
SauravMaheshkar Sep 13, 2024
8c6a590
feat: add loop to cli
SauravMaheshkar Sep 13, 2024
393429a
chore: fix just cmd
SauravMaheshkar Sep 13, 2024
76c05b2
style(mypy): disable no-redef
SauravMaheshkar Sep 13, 2024
25c2bf1
fix: return numpy tensors from tokenizer
SauravMaheshkar Sep 13, 2024
52464ef
feat: use Array from chex
SauravMaheshkar Sep 13, 2024
eadbc6d
docs: update docstrings + use chex
SauravMaheshkar Sep 13, 2024
cb9a5a7
fix: nnx modules use __call__
SauravMaheshkar Sep 13, 2024
1bfa089
feat: jaxify prepare fn
SauravMaheshkar Sep 13, 2024
919aace
fix: nnx modules use __call__
SauravMaheshkar Sep 13, 2024
6c97c8d
docs: add docstrings to denoise fn
SauravMaheshkar Sep 13, 2024
df0e52b
feat: to_device >> device_put
SauravMaheshkar Sep 13, 2024
09d1e08
feat: add dtypes and param dtypes
SauravMaheshkar Sep 13, 2024
52776d6
docs: docstrings for Identity module
SauravMaheshkar Sep 13, 2024
8a64274
feat: explicitly specify the dtypes for QKNorm in SelfAttention module
SauravMaheshkar Sep 13, 2024
3d9dc3f
feat: add tests for embedding layer
SauravMaheshkar Sep 15, 2024
b48c3cc
feat: use official flux as optional deps
SauravMaheshkar Sep 16, 2024
7775066
style: enforce isort
SauravMaheshkar Sep 16, 2024
af24690
feat: add tests for layers
SauravMaheshkar Sep 16, 2024
0c59aa1
feat: add tests for modulation and self-attn
SauravMaheshkar Sep 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
style: enforce isort
SauravMaheshkar committed Sep 16, 2024
commit 7775066c428b82542d812ed9c9e467a32b4c0557
4 changes: 2 additions & 2 deletions jflux/autoencoder.py
Original file line number Diff line number Diff line change
@@ -2,10 +2,10 @@

import jax
import jax.numpy as jnp
from jax.typing import DTypeLike
from chex import Array
from flax import nnx
from einops import rearrange
from flax import nnx
from jax.typing import DTypeLike

from jflux.layers import DiagonalGaussian
from jflux.sampling import interpolate
4 changes: 2 additions & 2 deletions jflux/cli.py
Original file line number Diff line number Diff line change
@@ -4,11 +4,11 @@
from dataclasses import dataclass
from glob import iglob

from fire import Fire

import jax
import jax.numpy as jnp
from fire import Fire
from jax.typing import DTypeLike

from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from jflux.util import (
configs,
6 changes: 3 additions & 3 deletions jflux/conditioner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from chex import Array
from flax import nnx
from transformers import (
FlaxCLIPTextModel,
CLIPTokenizer,
FlaxCLIPTextModel,
FlaxT5EncoderModel,
T5Tokenizer,
)
from chex import Array
from flax import nnx


class HFEmbedder(nnx.Module):
4 changes: 2 additions & 2 deletions jflux/layers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import math
from functools import partial

import jax
import jax.numpy as jnp
from jax.typing import DTypeLike
from chex import Array
from flax import nnx
from functools import partial
from jax.typing import DTypeLike

from jflux.math import rope

4 changes: 2 additions & 2 deletions jflux/math.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import typing

from einops import rearrange
from chex import Array
import jax
from chex import Array
from einops import rearrange
from jax import numpy as jnp


9 changes: 5 additions & 4 deletions jflux/model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from dataclasses import dataclass

import jax.dtypes
from chex import Array
from flax import nnx
from jax import numpy as jnp
from chex import Array
from jax.typing import DTypeLike

from jflux.layers import (
Identity,
Embed,
AdaLayerNorm,
Embed,
Identity,
timestep_embedding,
)
from jflux.modules import DoubleStreamBlock, SingleStreamBlock, MLPEmbedder
from jflux.modules import DoubleStreamBlock, MLPEmbedder, SingleStreamBlock


@dataclass
14 changes: 9 additions & 5 deletions jflux/modules.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import typing
from dataclasses import dataclass
from chex import Array
from flax import nnx

import jax
import jax.numpy as jnp
from chex import Array
from einops import rearrange
from flax import nnx
from jax.typing import DTypeLike

from einops import rearrange
from jflux.math import attention
from jflux.layers import QKNorm
from jflux.math import attention


class MLPEmbedder(nnx.Module):
@@ -31,6 +32,9 @@ def __init__(
dtype: DTypeLike = jax.dtypes.bfloat16,
param_dtype: DTypeLike = None,
) -> None:
if param_dtype is None:
param_dtype = dtype

self.in_layer = nnx.Linear(
in_features=in_dim,
out_features=hidden_dim,
@@ -40,7 +44,7 @@ def __init__(
rngs=rngs,
)
self.out_layer = nnx.Linear(
in_features=in_dim,
in_features=hidden_dim,
out_features=hidden_dim,
dtype=dtype,
param_dtype=param_dtype,
10 changes: 5 additions & 5 deletions jflux/sampling.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import math
from typing import Callable

from einops import rearrange, repeat

import jax
from jax.image import ResizeMethod
from chex import Array, Device, PRNGKey
from einops import rearrange, repeat
from jax import numpy as jnp
from jax.image import ResizeMethod
from jax.typing import DTypeLike
from chex import Array, PRNGKey, Device
from jflux.model import Flux

from jflux.conditioner import HFEmbedder
from jflux.model import Flux


def get_noise(
6 changes: 3 additions & 3 deletions jflux/util.py
Original file line number Diff line number Diff line change
@@ -2,15 +2,15 @@
from dataclasses import dataclass

import jax
from jax.typing import DTypeLike
from flax import nnx
from jax import numpy as jnp
from huggingface_hub import hf_hub_download
from jax import numpy as jnp
from jax.typing import DTypeLike
from safetensors.numpy import load_file as load_sft

from jflux.model import Flux, FluxParams
from jflux.autoencoder import AutoEncoder, AutoEncoderParams
from jflux.conditioner import HFEmbedder
from jflux.model import Flux, FluxParams


@dataclass
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
@@ -21,5 +21,5 @@ test:

# Basic linting
lint:
ruff check jflux
ruff check jflux --fix
mypy jflux
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -31,6 +31,12 @@ dev-dependencies = [
flux-jax = { workspace = true }
flux = { git = "https://github.com/black-forest-labs/flux.git" }

[tool.ruff.lint]
select = ["I001"]

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.mypy]
disable_error_code = "no-redef"