diff --git a/jflux/cli.py b/jflux/cli.py index 7597d98..628b4bd 100644 --- a/jflux/cli.py +++ b/jflux/cli.py @@ -10,13 +10,7 @@ from jax.typing import DTypeLike from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack -from jflux.util import ( - configs, - load_ae, - load_clip, - load_flow_model, - load_t5, -) +from jflux.util import configs, load_ae, load_clip, load_flow_model, load_t5 @dataclass diff --git a/jflux/math.py b/jflux/math.py index 73a30a4..b502b03 100644 --- a/jflux/math.py +++ b/jflux/math.py @@ -1,41 +1,21 @@ -import typing - import jax from chex import Array from einops import rearrange +from flax import nnx from jax import numpy as jnp -@typing.no_type_check def attention(q: Array, k: Array, v: Array, pe: Array) -> Array: - # TODO (ariG23498): Change all usage of attention to use this function q, k = apply_rope(q, k, pe) - # jax expects this shape - x = rearrange(x, "B H L D -> B L H D") # noqa - x = jax.nn.dot_product_attention(q, k, v) - x = rearrange(x, "B L H D -> B L (H D)") # reshape again + x = nnx.dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") return x def rope(pos: Array, dim: int, theta: int) -> Array: - """ - Generate Rotary Position Embedding (RoPE) for positional encoding. - - Args: - pos (Array): Positional values, typically a sequence of positions in an array format. - dim (int): The embedding dimension, which must be an even number. - theta (int): A scaling parameter for RoPE that controls the frequency range of rotations. - - Returns: - Array: Rotary embeddings with cosine and sine components for each position and dimension. - """ - - # Embedding dimension must be an even number assert dim % 2 == 0 - - # Generate the RoPE embeddings scale = jnp.arange(0, dim, 2, dtype=jnp.float64, device=pos.device) / dim omega = 1.0 / (theta**scale) out = jnp.einsum("...n,d->...nd", pos, omega) @@ -45,26 +25,10 @@ def rope(pos: Array, dim: int, theta: int) -> Array: def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: - """ - Apply RoPE to the input query and key tensors. - - Args: - xq (Array): Query tensor. - xk (Array): Key tensor. - freqs_cis (Array): RoPE frequencies. - - Returns: - tuple[Array, Array]: Query and key tensors with RoPE applied. - """ - # Reshape and typecast the input tensors xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2) xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2) - - # Apply RoPE to the input tensors xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - - # Reshape and typecast the output tensors return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype( xk.dtype ) diff --git a/jflux/modules/layers.py b/jflux/modules/layers.py index 130defe..16bf91d 100644 --- a/jflux/modules/layers.py +++ b/jflux/modules/layers.py @@ -1,29 +1,18 @@ import math -from functools import partial +from dataclasses import dataclass import jax import jax.numpy as jnp from chex import Array +from einops import rearrange from flax import nnx from jax.typing import DTypeLike -from jflux.math import rope +from jflux.math import attention, rope -class Embed(nnx.Module): - """ - Embedding module for Positional Embeddings. - - Args: - dim (int): Dimension of the embedding. - theta (int): theta parameter for the RoPE embedding - axes_dim (list[int]): List of axes dimensions. - - Returns: - RoPE embeddings - """ - - def __init__(self, dim: int, theta: int, axes_dim: list[int]) -> None: +class EmbedND(nnx.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): self.dim = dim self.theta = theta self.axes_dim = axes_dim @@ -32,13 +21,12 @@ def __call__(self, ids: Array) -> Array: n_axes = ids.shape[-1] emb = jnp.concat( [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], - axis=-3, + dim=-3, ) - return jnp.expand_dims(emb, 1) + return jnp.expand_dims(emb, axis=1) -@partial(jax.jit, static_argnums=(1, 2, 3)) def timestep_embedding( t: Array, dim: int, max_period=10000, time_factor: float = 1000.0 ) -> Array: @@ -46,72 +34,89 @@ def timestep_embedding( Generate timestep embeddings. Args: - t (Array): An array of timesteps to be embedded. - dim (int): The desired dimensionality of the output embedding. - max_period (int, optional): The maximum period for the sinusoidal functions. Defaults to 10000. - time_factor (float, optional): A scaling factor applied to the input timesteps. Defaults to 1000.0. + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + time_factor: Tensor of positional embeddings. Returns: timestep embeddings. """ - # Pre-Processing: - # * scales the input timesteps by the given time factor t = time_factor * t half = dim // 2 - # Determine frequencies using exponential decay freqs = jnp.exp( -math.log(max_period) * jnp.arange(start=0, stop=half, dtype=jnp.float32, device=t.device) / half - ) + ).astype(dtype=t.device) - # Create embeddings by concatenating sine and cosines args = t[:, None].astype(jnp.float32) * freqs[None] embedding = jnp.concat([jnp.cos(args), jnp.sin(args)], axis=-1) - # Handle odd dimensions if dim % 2: embedding = jnp.concat([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1) - # If timestamps are floating types ensure so is the embedding if jnp.issubdtype(t.device(), jnp.floating): embedding = embedding.astype(t.device()) return embedding -class QKNorm(nnx.Module): - """ - Normalization layer for query and key values. +class MLPEmbedder(nnx.Module): + def __init__( + self, + in_dim: int, + hidden_dim: int, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jax.dtypes.bfloat16, + ): + self.in_layer = nnx.Linear( + in_features=in_dim, + out_features=hidden_dim, + use_bias=True, + rngs=rngs, + param_dtype=param_dtype, + ) + self.silu = nnx.silu + self.out_layer = nnx.Linear( + in_features=hidden_dim, + out_features=hidden_dim, + use_bias=True, + rngs=rngs, + param_dtype=param_dtype, + ) - Args: - dim (int): Dimension of the hidden layer. - rngs (nnx.Rngs): RNGs for the layer. - dtype (DTypeLike): Data type for the layer. - param_dtype (DTypeLike): Data type for the layer parameters. + def __call__(self, x: Array) -> Array: + return self.out_layer(self.silu(self.in_layer(x))) - Returns: - Normalized query and key values - """ +class RMSNorm(nnx.Module): def __init__( self, dim: int, rngs: nnx.Rngs, - dtype: DTypeLike = jax.dtypes.bfloat16, - param_dtype: DTypeLike = None, - ) -> None: - if param_dtype is None: - param_dtype = dtype + param_dtype: DTypeLike = jax.dtypes.bfloat16, + ): + self.scale = nnx.Variable(jnp.ones(dim, dtype=param_dtype)) - # RMS Normalization for query and key - self.query_norm = nnx.RMSNorm( - dim, rngs=rngs, dtype=dtype, param_dtype=param_dtype - ) - self.key_norm = nnx.RMSNorm( - dim, rngs=rngs, dtype=dtype, param_dtype=param_dtype - ) + def __call__(self, x: Array): + x_dtype = x.dtype + x = x.astype(jnp.float32) + rrms = jnp.reciprocal(jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + 1e-6)) + return (x * rrms).astype(dtype=x_dtype) * self.scale + + +class QKNorm(nnx.Module): + def __init__( + self, + dim: int, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jax.dtypes.bfloat16, + ): + self.query_norm = RMSNorm(dim, rngs=rngs, param_dtype=param_dtype) + self.key_norm = RMSNorm(dim, rngs=rngs, param_dtype=param_dtype) def __call__(self, q: Array, k: Array, v: Array) -> tuple[Array, Array]: q = self.query_norm(q) @@ -119,95 +124,327 @@ def __call__(self, q: Array, k: Array, v: Array) -> tuple[Array, Array]: return q.to_device(v.device), k.to_device(v.device) -class AdaLayerNorm(nnx.Module): - """ - Normalization layer modified to incorporate timestep embeddings. +class SelfAttention(nnx.Module): + def __init__( + self, + dim: int, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jax.dtypes.bfloat16, + num_heads: int = 8, + qkv_bias: bool = False, + ): + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nnx.Linear( + in_features=dim, + out_features=dim * 3, + use_bias=qkv_bias, + rngs=rngs, + param_dtype=param_dtype, + ) + self.norm = QKNorm(dim=head_dim, rngs=rngs, param_dtype=param_dtype) + self.proj = nnx.Linear( + in_features=dim, + out_features=dim, + use_bias=True, + rngs=rngs, + param_dtype=param_dtype, + ) - Args: - hidden_size (int): Dimension of the hidden layer. - patch_size (int): patch size. - out_channels (int): Number of output channels. - rngs (nnx.Rngs): RNGs for the layer. - dtype (DTypeLike): Data type for the layer. - param_dtype (DTypeLike): Data type for the layer parameters. + def __call__(self, x: Array, pe: Array) -> Array: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x - Returns: - Normalized layer incorporating timestep embeddings. - """ +@dataclass +class ModulationOut: + shift: Array + scale: Array + gate: Array + + +class Modulation(nnx.Module): + def __init__( + self, + dim: int, + double: bool, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jax.dtypes.bfloat16, + ): + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nnx.Linear( + in_features=dim, + out_features=self.multiplier * dim, + use_bias=True, + rngs=rngs, + param_dtype=param_dtype, + ) + + def __call__(self, vec: Array) -> tuple[ModulationOut, ModulationOut | None]: + out = jnp.split(self.lin(nnx.silu(vec))[:, None, :], self.multiplier, axis=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nnx.Module): def __init__( self, hidden_size: int, - patch_size: int, - out_channels: int, + num_heads: int, + mlp_ratio: float, rngs: nnx.Rngs, - dtype: DTypeLike = jax.dtypes.bfloat16, - param_dtype: DTypeLike = None, - ) -> None: - if param_dtype is None: - param_dtype = dtype + param_dtype: DTypeLike = jax.dtypes.bfloat16, + qkv_bias: bool = False, + ): + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation( + dim=hidden_size, double=True, rngs=rngs, param_dtype=param_dtype + ) + self.img_norm1 = nnx.LayerNorm( + num_features=hidden_size, + use_scale=False, + epsilon=1e-6, + rngs=rngs, + param_dtype=param_dtype, + ) + self.img_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + rngs=rngs, + param_dtype=param_dtype, + ) - self.norm_final = nnx.LayerNorm( - hidden_size, epsilon=1e-6, rngs=rngs, dtype=dtype, param_dtype=param_dtype + self.img_norm2 = nnx.LayerNorm( + num_features=hidden_size, + use_scale=False, + epsilon=1e-6, + rngs=rngs, + param_dtype=param_dtype, ) - self.linear = nnx.Linear( - hidden_size, - patch_size * patch_size * out_channels, - use_bias=True, + self.img_mlp = nnx.Sequential( + nnx.Linear( + in_features=hidden_size, + out_features=mlp_hidden_dim, + use_bias=True, + rngs=rngs, + param_dtype=param_dtype, + ), + nnx.gelu, + nnx.Linear( + in_features=mlp_hidden_dim, + out_features=hidden_size, + use_bias=True, + rngs=rngs, + param_dtype=param_dtype, + ), + ) + + self.txt_mod = Modulation( + dim=hidden_size, double=True, rngs=rngs, param_dtype=param_dtype + ) + self.txt_norm1 = nnx.LayerNorm( + num_features=hidden_size, + use_scale=False, + epsilon=1e-6, rngs=rngs, - dtype=dtype, param_dtype=param_dtype, ) - self.adaLN_modulation = nnx.Sequential( - nnx.silu, + self.txt_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + rngs=rngs, + param_dtype=param_dtype, + ) + + self.txt_norm2 = nnx.LayerNorm( + num_features=hidden_size, + use_scale=False, + epsilon=1e-6, + rngs=rngs, + param_dtype=param_dtype, + ) + self.txt_mlp = nnx.Sequential( nnx.Linear( - hidden_size, - 2 * hidden_size, + in_features=hidden_size, + out_features=mlp_hidden_dim, + use_bias=True, + rngs=rngs, + param_dtype=param_dtype, + ), + nnx.gelu, + nnx.Linear( + in_features=mlp_hidden_dim, + out_features=hidden_size, use_bias=True, rngs=rngs, - dtype=dtype, param_dtype=param_dtype, ), ) - def __call__(self, x: Array, vec: Array) -> Array: - modulation_output = self.adaLN_modulation(vec) - shift, scale = jnp.split(modulation_output, indices_or_sections=2, axis=1) - x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] - x = self.linear(x) - return x + def __call__( + self, img: Array, txt: Array, vec: Array, pe: Array + ) -> tuple[Array, Array]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange( + img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange( + txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + # run actual attention + q = jnp.concat((txt_q, img_q), axis=2) + k = jnp.concat((txt_k, img_k), axis=2) + v = jnp.concat((txt_v, img_v), axis=2) -class DiagonalGaussian(nnx.Module): - """ - A module that represents a diagonal Gaussian distribution. + attn = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - Args: - sample (bool, optional): Whether to sample from the distribution. Defaults to True. - chunk_dim (int, optional): The dimension along which to chunk the input. Defaults to 1. + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp( + (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift + ) - Returns: - Array: The output array representing the sampled or mean values. + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp( + (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift + ) + return img, txt + + +class SingleStreamBlock(nnx.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. """ - def __init__(self, key: Array, sample: bool = True, chunk_dim: int = 1) -> None: - self.sample = sample - self.chunk_dim = chunk_dim - self.key = key + def __init__( + self, + hidden_size: int, + num_heads: int, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jax.dtypes.bfloat16, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nnx.Linear( + in_features=hidden_size, + out_features=hidden_size * 3 + self.mlp_hidden_dim, + rngs=rngs, + param_dtype=param_dtype, + ) + # proj and mlp_out + self.linear2 = nnx.Linear( + in_features=hidden_size + self.mlp_hidden_dim, + out_features=hidden_size, + rngs=rngs, + param_dtype=param_dtype, + ) + + self.norm = QKNorm(dim=head_dim, rngs=rngs, param_dtype=param_dtype) - def __call__(self, z: Array) -> Array: - mean, logvar = jnp.split(z, indices_or_sections=2, axis=self.chunk_dim) - if self.sample: - std = jnp.exp(0.5 * logvar) - return mean + std * jax.random.normal( - key=self.key, shape=mean.shape, dtype=z.dtype - ) - else: - return mean + self.hidden_size = hidden_size + self.pre_norm = nnx.LayerNorm( + num_features=hidden_size, + use_scale=False, + epsilon=1e-6, + rngs=rngs, + param_dtype=param_dtype, + ) + self.mlp_act = nnx.gelu + self.modulation = Modulation( + hidden_size, double=False, rngs=rngs, param_dtype=param_dtype + ) -class Identity(nnx.Module): - """Identity module.""" + def __call__(self, x: Array, vec: Array, pe: Array) -> Array: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = jnp.split( + self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 + ) - def __call__(self, x: Array) -> Array: + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(jnp.concatenate((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + +class LastLayer(nnx.Module): + def __init__( + self, + hidden_size: int, + patch_size: int, + out_channels: int, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jax.dtypes.bfloat16, + ): + self.norm_final = nnx.LayerNorm( + num_features=hidden_size, + use_scale=False, + epsilon=1e-6, + rngs=rngs, + param_dtype=param_dtype, + ) + self.linear = nnx.Linear( + in_features=hidden_size, + out_features=patch_size * patch_size * out_channels, + use_bias=True, + rngs=rngs, + param_dtype=param_dtype, + ) + self.adaLN_modulation = nnx.Sequential( + nnx.silu, + nnx.Linear( + in_features=hidden_size, + out_features=2 * hidden_size, + use_bias=True, + rngs=rngs, + param_dtype=param_dtype, + ), + ) + + def forward(self, x: Array, vec: Array) -> Array: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) return x diff --git a/jflux/sampling.py b/jflux/sampling.py index bc2503b..e7359a0 100644 --- a/jflux/sampling.py +++ b/jflux/sampling.py @@ -77,9 +77,9 @@ def prepare( img = repeat(img, "1 ... -> bs ...", bs=bs) # prepare image ids - img_ids = jnp.zeros(shape=(h // 2, w // 2), device=device) - img_ids[..., 1] = img_ids[..., 1] + jnp.arange(h // 2)[:, None] - img_ids[..., 2] = img_ids[..., 2] + jnp.arange(w // 2)[None, :] + img_ids = jnp.zeros(shape=(h // 2, w // 2, 3), device=device) + img_ids = img_ids.at[..., 1].set(jnp.arange(h // 2)[:, None]) + img_ids = img_ids.at[..., 2].set(jnp.arange(w // 2)[None, :]) img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) # prepare txt diff --git a/jflux/util.py b/jflux/util.py index 722a2e4..53928d9 100644 --- a/jflux/util.py +++ b/jflux/util.py @@ -8,9 +8,9 @@ from jax.typing import DTypeLike from safetensors.numpy import load_file as load_sft +from jflux.model import Flux, FluxParams from jflux.modules.autoencoder import AutoEncoder, AutoEncoderParams from jflux.modules.conditioner import HFEmbedder -from jflux.model import Flux, FluxParams @dataclass diff --git a/pyproject.toml b/pyproject.toml index 89efb10..cc34d86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "flux-jax" +name = "jflux" version = "0.1.0" description = "Inference codebase for Flux in Jax" readme = "README.md" @@ -9,7 +9,7 @@ dependencies = [ "einops>=0.8.0", "fire>=0.6.0", "flax>=0.9.0", - "flux-jax", + "jflux", # FIXME: Allow for local installation without GPUs as well `jax[cuda12]` "jax>=0.4.31", "mypy>=1.11.2", @@ -28,7 +28,7 @@ dev-dependencies = [ ] [tool.uv.sources] -flux-jax = { workspace = true } +jflux = { workspace = true } flux = { git = "https://github.com/black-forest-labs/flux.git" } [tool.ruff.lint] diff --git a/tests/modules/test_autoencoder.py b/tests/modules/test_autoencoder.py index a4aa73a..1fea730 100644 --- a/tests/modules/test_autoencoder.py +++ b/tests/modules/test_autoencoder.py @@ -1,27 +1,25 @@ -from einops import rearrange import jax.numpy as jnp +import numpy as np import torch +from einops import rearrange from flax import nnx - from flux.modules.autoencoder import AttnBlock as TorchAttnBlock -from flux.modules.autoencoder import ResnetBlock as TorchResnetBlock -from flux.modules.autoencoder import Downsample as TorchDownsample -from flux.modules.autoencoder import Upsample as TorchUpsample -from flux.modules.autoencoder import Encoder as TorchEncoder -from flux.modules.autoencoder import Decoder as TorchDecoder from flux.modules.autoencoder import AutoEncoder as TorchAutoEncoder from flux.modules.autoencoder import AutoEncoderParams as TorchAutoEncoderParams +from flux.modules.autoencoder import Decoder as TorchDecoder +from flux.modules.autoencoder import Downsample as TorchDownsample +from flux.modules.autoencoder import Encoder as TorchEncoder +from flux.modules.autoencoder import ResnetBlock as TorchResnetBlock +from flux.modules.autoencoder import Upsample as TorchUpsample from jflux.modules.autoencoder import AttnBlock as JaxAttnBlock -from jflux.modules.autoencoder import ResnetBlock as JaxResnetBlock -from jflux.modules.autoencoder import Downsample as JaxDownsample -from jflux.modules.autoencoder import Upsample as JaxUpsample -from jflux.modules.autoencoder import Encoder as JaxEncoder -from jflux.modules.autoencoder import Decoder as JaxDecoder from jflux.modules.autoencoder import AutoEncoder as JaxAutoEncoder from jflux.modules.autoencoder import AutoEncoderParams as JaxAutoEncoderParams - -import numpy as np +from jflux.modules.autoencoder import Decoder as JaxDecoder +from jflux.modules.autoencoder import Downsample as JaxDownsample +from jflux.modules.autoencoder import Encoder as JaxEncoder +from jflux.modules.autoencoder import ResnetBlock as JaxResnetBlock +from jflux.modules.autoencoder import Upsample as JaxUpsample from tests.utils import torch2jax diff --git a/tests/modules/test_layers.py b/tests/modules/test_layers.py new file mode 100644 index 0000000..ceed318 --- /dev/null +++ b/tests/modules/test_layers.py @@ -0,0 +1,289 @@ +import jax.numpy as jnp +import numpy as np +import torch +from einops import rearrange +from flax import nnx +from flux.modules.layers import DoubleStreamBlock as TorchDoubleStreamBlock +from flux.modules.layers import MLPEmbedder as TorchMLPEmbedder +from flux.modules.layers import Modulation as TorchModulation +from flux.modules.layers import QKNorm as TorchQKNorm +from flux.modules.layers import RMSNorm as TorchRMSNorm + +from jflux.modules.layers import DoubleStreamBlock as JaxDoubleStreamBlock +from jflux.modules.layers import MLPEmbedder as JaxMLPEmbedder +from jflux.modules.layers import Modulation as JaxModulation +from jflux.modules.layers import QKNorm as JaxQKNorm +from jflux.modules.layers import RMSNorm as JaxRMSNorm +from tests.utils import torch2jax + + +def port_mlp_embedder( + jax_mlp_embedder: JaxMLPEmbedder, torch_mlp_embedder: JaxMLPEmbedder +): + # linear layers + jax_mlp_embedder.in_layer.kernel.value = torch2jax( + rearrange(torch_mlp_embedder.in_layer.weight, "i o -> o i") + ) + jax_mlp_embedder.in_layer.bias.value = torch2jax(torch_mlp_embedder.in_layer.bias) + + jax_mlp_embedder.out_layer.kernel.value = torch2jax( + rearrange(torch_mlp_embedder.out_layer.weight, "i o -> o i") + ) + jax_mlp_embedder.out_layer.bias.value = torch2jax(torch_mlp_embedder.out_layer.bias) + return jax_mlp_embedder + + +def port_rms_norm(jax_rms_norm: JaxRMSNorm, torch_rms_norm: TorchRMSNorm): + jax_rms_norm.scale.value = torch2jax(torch_rms_norm.scale) + return jax_rms_norm + + +def port_qknorm(jax_qknorm: JaxQKNorm, torch_qknorm: TorchQKNorm): + # query norm + jax_qknorm.query_norm = port_rms_norm( + jax_rms_norm=jax_qknorm.query_norm, + torch_rms_norm=torch_qknorm.query_norm, + ) + # key norm + jax_qknorm.key_norm = port_rms_norm( + jax_rms_norm=jax_qknorm.key_norm, + torch_rms_norm=torch_qknorm.key_norm, + ) + + return jax_qknorm + + +def port_modulation( + jax_modulation: JaxModulation, + torch_modulation: TorchModulation, +): + jax_modulation.lin.kernel.value = torch2jax( + rearrange(torch_modulation.lin.weight, "i o -> o i") + ) + jax_modulation.lin.bias.value = torch2jax(torch_modulation.lin.bias) + return jax_modulation + + +class LayersTestCase(np.testing.TestCase): + def test_mlp_embedder(self): + # Initialize layers + in_dim = 32 + hidden_dim = 64 + rngs = nnx.Rngs(default=42) + param_dtype = jnp.float32 + + torch_mlp_embedder = TorchMLPEmbedder( + in_dim=in_dim, + hidden_dim=hidden_dim, + ) + jax_mlp_embedder = JaxMLPEmbedder( + in_dim=in_dim, + hidden_dim=hidden_dim, + rngs=rngs, + param_dtype=param_dtype, + ) + + # port the weights of the torch model into jax + jax_mlp_embedder = port_mlp_embedder( + jax_mlp_embedder=jax_mlp_embedder, torch_mlp_embedder=torch_mlp_embedder + ) + + # Generate random inputs + np_input = np.random.randn(2, in_dim).astype(np.float32) + jax_input = jnp.array(np_input, dtype=jnp.float32) + torch_input = torch.from_numpy(np_input).to(torch.float32) + + np.testing.assert_allclose(np.array(jax_input), torch_input.numpy()) + + # Forward pass + torch_output = torch_mlp_embedder(torch_input) + jax_output = jax_mlp_embedder(jax_input) + + # Assertions + np.testing.assert_allclose( + np.array(jax_output), + torch_output.detach().numpy(), + rtol=1e-5, + atol=1e-5, + ) + + def test_rms_norm(self): + # Initialize the layer + dim = 3 + rngs = nnx.Rngs(default=42) + param_dtype = jnp.float32 + + torch_rms_norm = TorchRMSNorm(dim=dim) + jax_rms_norm = JaxRMSNorm(dim=dim, rngs=rngs, param_dtype=param_dtype) + + # port the weights of the torch model into jax + jax_rms_norm = port_rms_norm( + jax_rms_norm=jax_rms_norm, torch_rms_norm=torch_rms_norm + ) + + # Generate random inputs + np_input = np.random.randn(2, dim).astype(np.float32) + jax_input = jnp.array(np_input, dtype=jnp.float32) + torch_input = torch.from_numpy(np_input).to(torch.float32) + + np.testing.assert_allclose(np.array(jax_input), torch_input.numpy()) + + # Forward pass + torch_output = torch_rms_norm(torch_input) + jax_output = jax_rms_norm(jax_input) + + # Assertions + np.testing.assert_allclose( + np.array(jax_output), + torch_output.detach().numpy(), + rtol=1e-5, + atol=1e-5, + ) + + def test_qknorm(self): + # Initialize the layer + dim = 16 + seq_len = 4 + rngs = nnx.Rngs(default=42) + param_dtype = jnp.float32 + + torch_qknorm = TorchQKNorm(dim=dim) + jax_qknorm = JaxQKNorm(dim=dim, rngs=rngs, param_dtype=param_dtype) + + # port the model + jax_qknorm = port_qknorm(jax_qknorm=jax_qknorm, torch_qknorm=torch_qknorm) + + # Generate random inputs + np_q = np.random.randn(2, seq_len, dim).astype(np.float32) + np_k = np.random.randn(2, seq_len, dim).astype(np.float32) + np_v = np.random.randn(2, seq_len, dim).astype(np.float32) + + jax_q = jnp.array(np_q, dtype=jnp.float32) + torch_q = torch.from_numpy(np_q).to(torch.float32) + + jax_k = jnp.array(np_k, dtype=jnp.float32) + torch_k = torch.from_numpy(np_k).to(torch.float32) + + jax_v = jnp.array(np_v, dtype=jnp.float32) + torch_v = torch.from_numpy(np_v).to(torch.float32) + + np.testing.assert_allclose(np.array(jax_q), torch_q.numpy()) + np.testing.assert_allclose(np.array(jax_k), torch_k.numpy()) + + jax_output = jax_qknorm(q=jax_q, k=jax_k, v=jax_v) + torch_output = torch_qknorm(q=torch_q, k=torch_k, v=torch_v) + + np.testing.assert_allclose( + np.array(jax_output[0]), + torch_output[0].detach().numpy(), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose( + np.array(jax_output[1]), + torch_output[1].detach().numpy(), + rtol=1e-5, + atol=1e-5, + ) + + def test_modulation(self): + # Initialize the layer + dim = 4 + rngs = nnx.Rngs(default=42) + param_dtype = jnp.float32 + + torch_modulation = TorchModulation(dim=dim, double=True) + jax_modulation = JaxModulation( + dim=dim, double=True, rngs=rngs, param_dtype=param_dtype + ) + + jax_modulation = port_modulation( + jax_modulation=jax_modulation, + torch_modulation=torch_modulation, + ) + + # Generate random inputs + np_input = np.random.randn(2, dim).astype(np.float32) + jax_input = jnp.array(np_input, dtype=jnp.float32) + torch_input = torch.from_numpy(np_input).to(torch.float32) + + np.testing.assert_allclose(np.array(jax_input), torch_input.numpy()) + + torch_output = torch_modulation(torch_input) + jax_output = jax_modulation(jax_input) + + # Assertions + for i in range(2): + np.testing.assert_allclose( + np.array(jax_output[i].shift), + torch_output[i].shift.detach().numpy(), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose( + np.array(jax_output[i].scale), + torch_output[i].scale.detach().numpy(), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose( + np.array(jax_output[i].gate), + torch_output[i].gate.detach().numpy(), + rtol=1e-5, + atol=1e-5, + ) + + def test_double_stream_block(self): + # Initialize layer + hidden_size = 64 + num_heads = 8 + mlp_ratio = 4.0 + qkv_bias = False + rngs = nnx.Rngs(default=42) + param_dtype = jnp.float32 + + # Initialize the DoubleStreamBlock + torch_double_stream_block = TorchDoubleStreamBlock( + hidden_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias + ) + jax_double_stream_block = JaxDoubleStreamBlock( + hidden_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + rngs=rngs, + param_dtype=param_dtype, + ) + + # Create the dummy inputs + np_img = np.random.randn(2, 10, hidden_size).astype(np.float32) # Batch size 2, sequence length 10, hidden size 64 (image input) + np_txt = np.random.randn(2, 15, hidden_size).astype(np.float32) # Batch size 2, sequence length 15, hidden size 64 (text input) + np_vec = np.random.randn(2, hidden_size).astype(np.float32) # Batch size 2, hidden size 64 (modulation vector) + np_pe = np.random.randn(2, 25, hidden_size).astype(np.float32) # Batch size 2, total length 25 (10 + 15), hidden size 64 (positional embedding) + + jax_img = jnp.array(np_img, dtype=jnp.float32) + jax_txt = jnp.array(np_txt, dtype=jnp.float32) + jax_vec = jnp.array(np_vec, dtype=jnp.float32) + jax_pe = jnp.array(np_pe, dtype=jnp.float32) + + torch_img = torch.from_numpy(np_img).to(torch.float32) + torch_txt = torch.from_numpy(np_txt).to(torch.float32) + torch_vec = torch.from_numpy(np_vec).to(torch.float32) + torch_pe = torch.from_numpy(np_pe).to(torch.float32) + + # Forward pass through the DoubleStreamBlock + torch_img_out, torch_txt_out = torch_double_stream_block( + img=torch_img, + txt=torch_txt, + vec=torch_vec, + pe=torch_pe, + ) + jax_img_out, jax_txt_out = jax_double_stream_block( + img=jax_img, + txt=jax_txt, + vec=jax_vec, + pe=jax_pe, + ) diff --git a/tests/test_basic.py b/tests/test_basic.py deleted file mode 100644 index 322fe62..0000000 --- a/tests/test_basic.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Basic test to always pass""" - -from __future__ import annotations - - -def test_always_passes(): - """Simple Test""" - assert True diff --git a/tests/test_layers.py b/tests/test_layers.py deleted file mode 100644 index 236432b..0000000 --- a/tests/test_layers.py +++ /dev/null @@ -1,103 +0,0 @@ -import chex -import jax -import jax.numpy as jnp -import torch -from flax import nnx -from flux.modules.autoencoder import DiagonalGaussian as PytorchDiagonalGaussian -from flux.modules.layers import EmbedND, LastLayer -from flux.modules.layers import QKNorm as PytorchQKNorm - -from jflux.modules.layers import AdaLayerNorm, Embed -from jflux.modules.layers import DiagonalGaussian as JaxDiagonalGaussian -from jflux.modules.layers import QKNorm as JaxQKNorm -from tests.utils import torch2jax - - -class LayersTestCase(chex.TestCase): - def test_embed(self): - # Initialize layers - pytorch_embed_layer = EmbedND(512, 10000, [64, 64, 64, 64]) - jax_embed_layer = Embed(512, 10000, [64, 64, 64, 64]) - - # Generate random inputs - torch_ids = torch.randint(0, 10000, (1, 32, 4)) - jax_ids = torch2jax(torch_ids) - - # Forward pass - jax_output = jax_embed_layer(jax_ids) - pytorch_output = pytorch_embed_layer(torch_ids) - - # Assertions - chex.assert_equal_shape([jax_output, torch2jax(pytorch_output)]) - chex.assert_trees_all_close( - jax_output, torch2jax(pytorch_output), rtol=1e-3, atol=1e-3 - ) - - def test_qk_norm(self): - # Initialize layers - pytorch_qk_norm_layer = PytorchQKNorm(512) - jax_qk_norm_layer = JaxQKNorm(512, rngs=nnx.Rngs(default=42), dtype=jnp.float32) - - # Generate random inputs - torch_query = torch.randn(1, 32, 512, dtype=torch.float32) - torch_key = torch.randn(1, 32, 512, dtype=torch.float32) - torch_value = torch.randn(1, 32, 512, dtype=torch.float32) - jax_query = torch2jax(torch_query) - jax_key = torch2jax(torch_key) - jax_value = torch2jax(torch_value) - - # Forward pass - jax_output = jax_qk_norm_layer(jax_query, jax_key, jax_value) - pytorch_output = pytorch_qk_norm_layer(torch_query, torch_key, torch_value) - - # Assertions - assert len(jax_output) == len(pytorch_output) - for i in range(len(jax_output)): - chex.assert_equal_shape([jax_output[i], torch2jax(pytorch_output[i])]) - chex.assert_trees_all_close( - jax_output[i], torch2jax(pytorch_output[i]), rtol=1e-3, atol=1e-3 - ) - - def test_adalayer_norm(self): - # Initialize layers - pytorch_adalayer_norm_layer = LastLayer( - hidden_size=512, - patch_size=16, - out_channels=512, - ) - jax_adalayer_norm_layer = AdaLayerNorm( - hidden_size=512, - patch_size=16, - out_channels=512, - rngs=nnx.Rngs(default=42), - dtype=jnp.float32, - ) - - # Generate random inputs - torch_hidden = torch.randn(1, 32, 512, dtype=torch.float32) - torch_vec = torch.randn(1, 512, dtype=torch.float32) - jax_hidden = torch2jax(torch_hidden) - jax_vec = torch2jax(torch_vec) - - # Forward pass - jax_output = jax_adalayer_norm_layer(jax_hidden, jax_vec) - pytorch_output = pytorch_adalayer_norm_layer(torch_hidden, torch_vec) - - # Assertions - chex.assert_equal_shape([jax_output, torch2jax(pytorch_output)]) - - def test_diagonal_gaussian(self): - # Initialize layers - pytorch_diagonal_gaussian_layer = PytorchDiagonalGaussian() - jax_diagonal_gaussian_layer = JaxDiagonalGaussian(key=jax.random.key(42)) - - # Generate random inputs - torch_input = torch.randn(1, 32, 512, dtype=torch.float32) - jax_input = torch2jax(torch_input) - - # Forward pass - jax_output = jax_diagonal_gaussian_layer(jax_input) - pytorch_output = pytorch_diagonal_gaussian_layer(torch_input) - - # Assertions - chex.assert_equal_shape([jax_output, torch2jax(pytorch_output)]) diff --git a/tests/test_math.py b/tests/test_math.py index 20e36bc..4c4b897 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -1,61 +1,171 @@ import unittest import jax.numpy as jnp -import pytest +import numpy as np +import torch +from flux.math import apply_rope as torch_apply_rope +from flux.math import attention as torch_attention +from flux.math import rope as torch_rope -from jflux.math import apply_rope, attention, rope +from jflux.math import apply_rope as jax_apply_rope +from jflux.math import attention as jax_attention +from jflux.math import rope as jax_rope -class TestAttentionMechanism(unittest.TestCase): - def setUp(self): - self.batch_size = 2 - self.num_heads = 4 - self.seq_len = 8 - self.dim = 64 - self.theta = 10000 +class TestMath(np.testing.TestCase): + def test_rope(self): + B, L, H, D = ( + 2, + 4, + 2, + 8, + ) # Batch size, sequence length, number of heads, embedding dimension + theta = 10000 - self.q = jnp.ones((self.batch_size, self.num_heads, self.seq_len, self.dim)) - self.k = jnp.ones((self.batch_size, self.num_heads, self.seq_len, self.dim)) - self.v = jnp.ones((self.batch_size, self.num_heads, self.seq_len, self.dim)) + # Position indices (e.g., positions in the sequence) + np_positions = ( + np.expand_dims(np.arange(L), 0).repeat(B, 1).astype(np.int32) + ) # Shape: [B, L] + torch_positions = torch.from_numpy(np_positions).to(torch.int32) + jax_positions = jnp.array(np_positions, dtype=jnp.int32) - def test_rope(self): - pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0) - pos = jnp.repeat(pos, self.batch_size, axis=0) + np.testing.assert_allclose(np.array(jax_positions), torch_positions.numpy()) - rope_output = rope(pos, self.dim, self.theta) - expected_shape = (self.batch_size, self.seq_len, self.dim // 2, 2, 2) + torch_pe = torch_rope(pos=torch_positions, dim=D, theta=theta) + jax_pe = jax_rope( + pos=jax_positions, dim=D, theta=theta + ) # Shape: [B, L, D/2, 2, 2] - self.assertEqual( - rope_output.shape, expected_shape, "rope function output shape is incorrect" + np.testing.assert_allclose( + np.array(jax_pe), + torch_pe.numpy(), + rtol=1e-5, + atol=1e-5, ) - @pytest.mark.xfail def test_apply_rope(self): - pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0) - pos = jnp.repeat(pos, self.batch_size, axis=0) + B, H, L, D = ( + 1, + 24, + 4336, + 128, + ) + theta = 10000 + + # Inputs + np_q = np.random.randn(B, H, L, D).astype(np.float32) + np_k = np.random.randn(B, H, L, D).astype(np.float32) + + jax_q = jnp.array(np_q, dtype=jnp.float32) + jax_k = jnp.array(np_k, dtype=jnp.float32) + + torch_q = torch.from_numpy(np_q).to(torch.float32) + torch_k = torch.from_numpy(np_k).to(torch.float32) + + np.testing.assert_allclose(np.array(jax_q), torch_q.numpy()) + np.testing.assert_allclose(np.array(jax_k), torch_k.numpy()) + + # Position indices (e.g., positions in the sequence) + np_positions = np.random.randn(1, L).astype(np.float32) + torch_positions = torch.from_numpy(np_positions).to(torch.float32) + jax_positions = jnp.array(np_positions, dtype=jnp.float32) + + np.testing.assert_allclose(np.array(jax_positions), torch_positions.numpy()) + + torch_pe = torch_rope(pos=torch_positions, dim=(3072 // 24), theta=theta) + jax_pe = jax_rope(pos=jax_positions, dim=(3072 // 24), theta=theta) + + np.testing.assert_allclose( + np.array(jax_pe), + torch_pe.numpy(), + rtol=1e-5, + atol=1e-5, + ) + + torch_pe = torch_pe.unsqueeze(1).expand( + -1, H, -1, -1, -1, -1 + ) # Shape: [B, H, L, D//2, 2, 2] + jax_pe = jnp.repeat(jnp.expand_dims(jax_pe, axis=1), repeats=H, axis=1) - freqs_cis = rope(pos, self.dim, self.theta) - xq_out, xk_out = apply_rope(self.q, self.k, freqs_cis) + # Apply RoPE to q and k + torch_q_rotated, torch_k_rotated = torch_apply_rope( + xq=torch_q, xk=torch_k, freqs_cis=torch_pe + ) + jax_q_rotated, jax_k_rotated = jax_apply_rope( + xq=jax_q, xk=jax_k, freqs_cis=jax_pe + ) - self.assertEqual( - xq_out.shape, self.q.shape, "apply_rope xq output shape is incorrect" + np.testing.assert_allclose( + np.array(jax_q_rotated), + torch_q_rotated.numpy(), + rtol=1e-5, + atol=1e-5, ) - self.assertEqual( - xk_out.shape, self.k.shape, "apply_rope xk output shape is incorrect" + np.testing.assert_allclose( + np.array(jax_k_rotated), + torch_k_rotated.numpy(), + rtol=1e-5, + atol=1e-5, ) - @pytest.mark.xfail - def test_attention(self): - pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0) - pos = jnp.repeat(pos, self.batch_size, axis=0) + # def test_attention(self): + # # Generate random inputs + # np_input = np.random.randn(2, 32, 4, 4).astype(np.float32) + # jax_input = jnp.array(np_input, dtype=jnp.float32) + # torch_input = torch.from_numpy(np_input).to(torch.float32) - freqs_cis = rope(pos, self.dim, self.theta) - attention_output = attention(self.q, self.k, self.v, freqs_cis) + # np.testing.assert_allclose(np.array(jax_input), torch_input.numpy()) - expected_shape = (self.batch_size, self.seq_len, self.num_heads * self.dim) + # # Forward pass + # torch_output = torch_downsample(torch_input) + # jax_output = jax_downsample(rearrange(jax_input, "b c h w -> b h w c")) - self.assertEqual( - attention_output.shape, - expected_shape, - "attention function output shape is incorrect", - ) + # # Assertions + # np.testing.assert_allclose( + # np.array(rearrange(jax_output, "b h w c -> b c h w")), + # torch_output.detach().numpy(), + # rtol=1e-5, + # atol=1e-5, + # ) + + # def test_rope(self): + # pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0) + # pos = jnp.repeat(pos, self.batch_size, axis=0) + + # rope_output = rope(pos, self.dim, self.theta) + # expected_shape = (self.batch_size, self.seq_len, self.dim // 2, 2, 2) + + # self.assertEqual( + # rope_output.shape, expected_shape, "rope function output shape is incorrect" + # ) + + # @pytest.mark.xfail + # def test_apply_rope(self): + # pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0) + # pos = jnp.repeat(pos, self.batch_size, axis=0) + + # freqs_cis = rope(pos, self.dim, self.theta) + # xq_out, xk_out = apply_rope(self.q, self.k, freqs_cis) + + # self.assertEqual( + # xq_out.shape, self.q.shape, "apply_rope xq output shape is incorrect" + # ) + # self.assertEqual( + # xk_out.shape, self.k.shape, "apply_rope xk output shape is incorrect" + # ) + + # @pytest.mark.xfail + # def test_attention(self): + # pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0) + # pos = jnp.repeat(pos, self.batch_size, axis=0) + + # freqs_cis = rope(pos, self.dim, self.theta) + # attention_output = attention(self.q, self.k, self.v, freqs_cis) + + # expected_shape = (self.batch_size, self.seq_len, self.num_heads * self.dim) + + # self.assertEqual( + # attention_output.shape, + # expected_shape, + # "attention function output shape is incorrect", + # ) diff --git a/tests/test_modules.py b/tests/test_modules.py new file mode 100644 index 0000000..424293c --- /dev/null +++ b/tests/test_modules.py @@ -0,0 +1,85 @@ +import chex +import jax.numpy as jnp +import pytest +import torch +from flax import nnx +from flux.modules.layers import MLPEmbedder +from flux.modules.layers import Modulation as PytorchModulation +from flux.modules.layers import SelfAttention as PytorchSelfAttention + +from jflux.modules.layers import MLPEmbedder as JaxMLPEmbedder +from jflux.modules.layers import Modulation as JaxModulation +from jflux.modules.layers import SelfAttention as JaxSelfAttention +from tests.utils import torch2jax + + +class ModulesTestCase(chex.TestCase): + def test_mlp_embedder(self): + # Initialize layers + pytorch_mlp_embedder = MLPEmbedder(in_dim=512, hidden_dim=256) + jax_mlp_embedder = JaxMLPEmbedder( + in_dim=512, + hidden_dim=256, + rngs=nnx.Rngs(default=42), + param_dtype=jnp.float32, + ) + + # Generate random inputs + torch_input = torch.randn(1, 32, 512, dtype=torch.float32) + jax_input = torch2jax(torch_input) + + # Forward pass + jax_output = jax_mlp_embedder(jax_input) + pytorch_output = pytorch_mlp_embedder(torch_input) + + # Assertions + chex.assert_equal_shape([jax_output, torch2jax(pytorch_output)]) + + @pytest.mark.skip(reason="Blocked by apply_rope") + def test_self_attention(self): + # Initialize layers + pytorch_self_attention = PytorchSelfAttention(dim=512) + jax_self_attention = JaxSelfAttention( + dim=512, rngs=nnx.Rngs(default=42), param_dtype=jnp.float32 + ) + + # Generate random inputs + torch_input = torch.randn(1, 32, 512, dtype=torch.float32) + torch_pe = torch.randn(1, 32, 512, dtype=torch.float32) + jax_input = torch2jax(torch_input) + jax_pe = torch2jax(torch_pe) + + # Forward pass + jax_output = jax_self_attention(jax_input, jax_pe) + pytorch_output = pytorch_self_attention(torch_input, torch_pe) + + # Assertions + chex.assert_equal_shape([jax_output, torch2jax(pytorch_output)]) + + def test_modulation(self): + # Initialize layers + pytorch_modulation = PytorchModulation(dim=512, double=True) + jax_modulation = JaxModulation( + dim=512, double=True, rngs=nnx.Rngs(default=42), param_dtype=jnp.float32 + ) + + # Generate random inputs + torch_input = torch.randn(1, 32, 512, dtype=torch.float32) + jax_input = torch2jax(torch_input) + + # Forward pass + jax_output = jax_modulation(jax_input) + pytorch_output = pytorch_modulation(torch_input) + + # Convert Modulation output to individual tensors + jax_tensors = [jax_output[0].shift, jax_output[0].scale, jax_output[0].gate] + torch_tensors = [ + torch2jax(pytorch_output[0].shift), + torch2jax(pytorch_output[0].scale), + torch2jax(pytorch_output[0].gate), + ] + + # Assertions + assert len(jax_output) == len(pytorch_output) + for i in range(len(jax_output)): + chex.assert_equal_shape([jax_tensors[i], torch_tensors[i]]) diff --git a/uv.lock b/uv.lock index 48553f2..0930d42 100644 --- a/uv.lock +++ b/uv.lock @@ -207,46 +207,6 @@ dependencies = [ { name = "transformers" }, ] -[[package]] -name = "flux-jax" -version = "0.1.0" -source = { virtual = "." } -dependencies = [ - { name = "einops" }, - { name = "fire" }, - { name = "flax" }, - { name = "jax" }, - { name = "mypy" }, - { name = "pillow" }, - { name = "ruff" }, - { name = "transformers" }, -] - -[package.dev-dependencies] -dev = [ - { name = "flux" }, - { name = "pytest" }, -] - -[package.metadata] -requires-dist = [ - { name = "einops", specifier = ">=0.8.0" }, - { name = "fire", specifier = ">=0.6.0" }, - { name = "flax", specifier = ">=0.9.0" }, - { name = "flux-jax", virtual = "." }, - { name = "jax", specifier = ">=0.4.31" }, - { name = "mypy", specifier = ">=1.11.2" }, - { name = "pillow", specifier = ">=10.4.0" }, - { name = "ruff", specifier = ">=0.6.3" }, - { name = "transformers", specifier = ">=4.44.2" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "flux", git = "https://github.com/black-forest-labs/flux.git" }, - { name = "pytest", specifier = ">=8.3.3" }, -] - [[package]] name = "fsspec" version = "2024.6.1" @@ -368,6 +328,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/13/1bb2bcb4d9f4719dd5f3d98f5c2fc2235f961ced576366b040372eebdb17/jaxlib-0.4.31-cp312-cp312-win_amd64.whl", hash = "sha256:c4bfd15315e30525514b7262d555bea00745b09ac9818bb14c20ef8afbbab072", size = 56299104 }, ] +[[package]] +name = "jflux" +version = "0.1.0" +source = { virtual = "." } +dependencies = [ + { name = "einops" }, + { name = "fire" }, + { name = "flax" }, + { name = "jax" }, + { name = "mypy" }, + { name = "pillow" }, + { name = "ruff" }, + { name = "transformers" }, +] + +[package.dev-dependencies] +dev = [ + { name = "flux" }, + { name = "pytest" }, +] + +[package.metadata] +requires-dist = [ + { name = "einops", specifier = ">=0.8.0" }, + { name = "fire", specifier = ">=0.6.0" }, + { name = "flax", specifier = ">=0.9.0" }, + { name = "jax", specifier = ">=0.4.31" }, + { name = "jflux", virtual = "." }, + { name = "mypy", specifier = ">=1.11.2" }, + { name = "pillow", specifier = ">=10.4.0" }, + { name = "ruff", specifier = ">=0.6.3" }, + { name = "transformers", specifier = ">=4.44.2" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "flux", git = "https://github.com/black-forest-labs/flux.git" }, + { name = "pytest", specifier = ">=8.3.3" }, +] + [[package]] name = "jinja2" version = "3.1.4" @@ -634,7 +634,6 @@ version = "12.1.3.1" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774 }, - { url = "https://files.pythonhosted.org/packages/c5/ef/32a375b74bea706c93deea5613552f7c9104f961b21df423f5887eca713b/nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906", size = 439918445 }, ] [[package]] @@ -643,7 +642,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015 }, - { url = "https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 }, ] [[package]] @@ -652,7 +650,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734 }, - { url = "https://files.pythonhosted.org/packages/ad/1d/f76987c4f454eb86e0b9a0e4f57c3bf1ac1d13ad13cd1a4da4eb0e0c0ce9/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed", size = 19331863 }, ] [[package]] @@ -661,7 +658,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 }, ] [[package]] @@ -673,7 +669,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, - { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892 }, ] [[package]] @@ -682,7 +677,6 @@ version = "11.0.2.54" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 }, - { url = "https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 }, ] [[package]] @@ -691,7 +685,6 @@ version = "10.3.2.106" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784 }, - { url = "https://files.pythonhosted.org/packages/5c/97/4c9c7c79efcdf5b70374241d48cf03b94ef6707fd18ea0c0f53684931d0b/nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a", size = 55995813 }, ] [[package]] @@ -705,7 +698,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, - { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 }, ] [[package]] @@ -717,7 +709,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, - { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 }, ] [[package]] @@ -736,7 +727,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/58/8c/69c9e39cd6bfa813852a94e9bd3c075045e2707d163e9dc2326c82d2c330/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b3fd0779845f68b92063ab1393abab1ed0a23412fc520df79a8190d098b5cd6b", size = 19253287 }, { url = "https://files.pythonhosted.org/packages/a8/48/a9775d377cb95585fb188b469387f58ba6738e268de22eae2ad4cedb2c41/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl", hash = "sha256:125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab", size = 19725597 }, - { url = "https://files.pythonhosted.org/packages/00/d5/02af3b39427ed71e8c40b6912271499ec186a72405bcb7e4ca26ff70678c/nvidia_nvjitlink_cu12-12.6.68-py3-none-win_amd64.whl", hash = "sha256:a55744c98d70317c5e23db14866a8cc2b733f7324509e941fc96276f9f37801d", size = 161730369 }, ] [[package]] @@ -745,7 +735,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138 }, - { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, ] [[package]]