diff --git a/.github/README.md b/.github/README.md
index fb8d5fb..be7669f 100644
--- a/.github/README.md
+++ b/.github/README.md
@@ -12,7 +12,7 @@ $ uv sync
 ## Running
 
 ```shell
-$ uv jflux
+$ uv run jflux
 ```
 
 ## References
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
deleted file mode 100644
index c8756ab..0000000
--- a/.pre-commit-config.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.6.0
-    hooks:
-      - id: end-of-file-fixer
-      - id: trailing-whitespace
-      - id: check-yaml
-      - id: check-toml
-      - id: check-json
-      - id: check-merge-conflict
-      - id: requirements-txt-fixer
-      - id: detect-private-key
-  - repo: https://github.com/psf/black
-    rev: 24.8.0
-    hooks:
-      - id: black
-  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.6.2
-    hooks:
-      - id: ruff
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..270f2d7
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,15 @@
+{
+    "python.testing.pytestArgs": [
+        "tests"
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true,
+    "[python]": {
+        "editor.formatOnSave": true,
+        "editor.codeActionsOnSave": {
+            "source.fixAll": "explicit",
+            "source.organizeImports": "explicit"
+        },
+        "editor.defaultFormatter": "charliermarsh.ruff",
+    },
+}
\ No newline at end of file
diff --git a/jflux/__main__.py b/jflux/__main__.py
deleted file mode 100644
index a6a1f3d..0000000
--- a/jflux/__main__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from jflux.cli import app
-
-if __name__ == "__main__":
-    app()
diff --git a/jflux/cli.py b/jflux/cli.py
index 5c1635f..8f5a1cd 100644
--- a/jflux/cli.py
+++ b/jflux/cli.py
@@ -6,10 +6,9 @@
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 from einops import rearrange
 from fire import Fire
-from flax import nnx
-from jax.typing import DTypeLike
 from PIL import Image
 
 from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
@@ -124,7 +123,8 @@ def main(
         by the index of the sample
         prompt: Prompt used for sampling
         device: Pytorch device
-        num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
+        num_steps: number of sampling steps
+            (default 4 for schnell, 50 for guidance distilled)
         loop: start an interactive session and sample multiple times
         guidance: guidance value used for guidance distillation
         add_sampling_metadata: Add the prompt to the image Exif metadata
@@ -216,7 +216,12 @@ def main(
         x = x.clip(-1, 1)
         x = rearrange(x[0], "c h w -> h w c")
-        img = Image.fromarray((127.5 * (x + 1.0)))
+        x = 127.5 * (x + 1.0)
+        x_numpy = np.array(x.astype(jnp.uint8))
+        img = Image.fromarray(x_numpy)
+
+        img.save(fn, quality=95, subsampling=0)
+
        idx += 1
 
        if loop:
            print("-" * 80)
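Aside: the replacement save path in `jflux/cli.py` above is the part most likely to trip up readers, since `PIL.Image.fromarray` needs a `uint8` NumPy array rather than a float JAX array. A minimal standalone sketch of the same conversion (the `save_image` wrapper is hypothetical; the patch inlines these steps with `x` and `fn` already in scope):

```python
import jax.numpy as jnp
import numpy as np
from PIL import Image


def save_image(x: jnp.ndarray, fn: str) -> None:
    # x: (height, width, channels) array in [-1, 1], as produced by the
    # clip + rearrange steps in the patched main().
    x = 127.5 * (x + 1.0)                    # [-1, 1] -> [0, 255]
    x_numpy = np.array(x.astype(jnp.uint8))  # JAX array -> uint8 NumPy for PIL
    Image.fromarray(x_numpy).save(fn, quality=95, subsampling=0)
```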
diff --git a/jflux/model.py b/jflux/model.py
index 48e6939..40a838d 100644
--- a/jflux/model.py
+++ b/jflux/model.py
@@ -49,7 +49,7 @@ def __init__(self, params: FluxParams):
         self.out_channels = self.in_channels
         if params.hidden_size % params.num_heads != 0:
             raise ValueError(
-                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"  # noqa: E501
             )
         pe_dim = params.hidden_size // params.num_heads
         if sum(params.axes_dim) != pe_dim:
diff --git a/jflux/port.py b/jflux/port.py
index 8bdd151..6fe2e51 100644
--- a/jflux/port.py
+++ b/jflux/port.py
@@ -1,5 +1,6 @@
 from einops import rearrange
 
+
 ##############################################################################################
 # AUTOENCODER MODEL PORTING
 ##############################################################################################
@@ -481,3 +482,5 @@ def port_flux(flux, tensors):
         tensors=tensors,
         prefix="final_layer",
     )
+
+    return flux
diff --git a/jflux/util.py b/jflux/util.py
index e021e31..559c96b 100644
--- a/jflux/util.py
+++ b/jflux/util.py
@@ -1,12 +1,10 @@
 import os
 from dataclasses import dataclass
 
-import jax
 import torch  # need for t5 and clip
 from flax import nnx
 from huggingface_hub import hf_hub_download
 from jax import numpy as jnp
-from jax.typing import DTypeLike
 from safetensors import safe_open
 
 from jflux.model import Flux, FluxParams
diff --git a/pyproject.toml b/pyproject.toml
index cc34d86..d935190 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,6 @@ dependencies = [
     "einops>=0.8.0",
     "fire>=0.6.0",
     "flax>=0.9.0",
-    "jflux",
     # FIXME: Allow for local installation without GPUs as well `jax[cuda12]`
     "jax>=0.4.31",
     "mypy>=1.11.2",
@@ -22,6 +21,7 @@ dependencies = [
 jflux = "jflux.cli:app"
 
 [tool.uv]
+package = true
 dev-dependencies = [
     "flux",
     "pytest>=8.3.3",
@@ -32,7 +32,10 @@ jflux = { workspace = true }
 flux = { git = "https://github.com/black-forest-labs/flux.git" }
 
 [tool.ruff.lint]
-select = ["I001"]
+select = ["E", "F", "I001", "W"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
 
 [tool.ruff.lint.pydocstyle]
 convention = "google"
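The test-suite changes below all follow one migration pattern: inputs come from a seeded `jax.random.normal` instead of `np.random.randn`, the torch copy is built from the JAX array via `__array__()`, and comparisons go through `chex.assert_trees_all_close` with the `torch2jax` helper from `tests/utils.py`. A minimal sketch of the pattern (the `torch2jax` body shown here is an assumption; only its name and the `Array` import appear in this diff):

```python
import chex
import jax
import jax.numpy as jnp
import torch
from chex import Array


def torch2jax(tensor: torch.Tensor) -> Array:
    # Assumed implementation: detach, move to NumPy, re-wrap as a JAX array.
    return jnp.asarray(tensor.detach().cpu().numpy())


# Round-trip a seeded input and check agreement within tolerances.
jax_input = jax.random.normal(key=jax.random.PRNGKey(42), shape=(2, 8))
torch_input = torch.from_numpy(jax_input.__array__()).to(torch.float32)
chex.assert_trees_all_close(jax_input, torch2jax(torch_input), rtol=1e-5, atol=1e-5)
```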
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/modules/test_autoencoder.py b/tests/modules/test_autoencoder.py
index 677bae9..b19eea6 100644
--- a/tests/modules/test_autoencoder.py
+++ b/tests/modules/test_autoencoder.py
@@ -1,5 +1,6 @@
+import chex
+import jax
 import jax.numpy as jnp
-import numpy as np
 import torch
 from einops import rearrange
 from flax import nnx
@@ -291,7 +292,7 @@ def port_autoencoder(
     return jax_autoencoder
 
 
-class AutoEncodersTestCase(np.testing.TestCase):
+class AutoEncodersTestCase(chex.TestCase):
     def test_attn_block(self):
         # Initialize layers
         in_channels = 32
@@ -309,20 +310,23 @@ def test_attn_block(self):
         )
 
         # Generate random inputs
-        np_input = np.random.randn(2, 32, 4, 4).astype(np.float32)
-        jax_input = jnp.array(np_input, dtype=jnp.float32)
-        torch_input = torch.from_numpy(np_input).to(torch.float32)
+        jax_input = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(2, 32, 4, 4), dtype=jnp.float32
+        )
+        torch_input = torch.from_numpy(jax_input.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())
+        chex.assert_trees_all_close(
+            jax_input, torch2jax(torch_input), rtol=1e-5, atol=1e-5
+        )
 
         # Forward pass
         torch_output = torch_attn_block(torch_input)
         jax_output = jax_attn_block(rearrange(jax_input, "b c h w -> b h w c"))
 
         # Assertions
-        np.testing.assert_allclose(
-            np.array(rearrange(jax_output, "b h w c -> b c h w")),
-            torch_output.detach().numpy(),
+        chex.assert_trees_all_close(
+            rearrange(jax_output, "b h w c -> b c h w"),
+            torch2jax(torch_output),
             rtol=1e-5,
             atol=1e-5,
         )
@@ -350,20 +354,23 @@ def test_resnet_block(self):
         )
 
         # Generate random inputs
-        np_input = np.random.randn(2, 32, 4, 4).astype(np.float32)
-        jax_input = jnp.array(np_input, dtype=jnp.float32)
-        torch_input = torch.from_numpy(np_input).to(torch.float32)
+        jax_input = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(2, 32, 4, 4), dtype=jnp.float32
+        )
+        torch_input = torch.from_numpy(jax_input.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())
+        chex.assert_trees_all_close(
+            jax_input, torch2jax(torch_input), rtol=1e-5, atol=1e-5
+        )
 
         # Forward pass
         torch_output = torch_resnet_block(torch_input)
         jax_output = jax_resnet_block(rearrange(jax_input, "b c h w -> b h w c"))
 
         # Assertions
-        np.testing.assert_allclose(
-            np.array(rearrange(jax_output, "b h w c -> b c h w")),
-            torch_output.detach().numpy(),
+        chex.assert_trees_all_close(
+            rearrange(jax_output, "b h w c -> b c h w"),
+            torch2jax(torch_output),
             rtol=1e-5,
             atol=1e-5,
         )
@@ -387,20 +394,23 @@ def test_downsample(self):
         )
 
         # Generate random inputs
-        np_input = np.random.randn(2, 32, 4, 4).astype(np.float32)
-        jax_input = jnp.array(np_input, dtype=jnp.float32)
-        torch_input = torch.from_numpy(np_input).to(torch.float32)
+        jax_input = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(2, 32, 4, 4), dtype=jnp.float32
+        )
+        torch_input = torch.from_numpy(jax_input.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())
+        chex.assert_trees_all_close(
+            jax_input, torch2jax(torch_input), rtol=1e-5, atol=1e-5
+        )
 
         # Forward pass
         torch_output = torch_downsample(torch_input)
         jax_output = jax_downsample(rearrange(jax_input, "b c h w -> b h w c"))
 
         # Assertions
-        np.testing.assert_allclose(
-            np.array(rearrange(jax_output, "b h w c -> b c h w")),
-            torch_output.detach().numpy(),
+        chex.assert_trees_all_close(
+            rearrange(jax_output, "b h w c -> b c h w"),
+            torch2jax(torch_output),
             rtol=1e-5,
             atol=1e-5,
         )
@@ -424,20 +434,23 @@ def test_upsample(self):
         )
 
         # Generate random inputs
-        np_input = np.random.randn(2, 32, 4, 4).astype(np.float32)
-        jax_input = jnp.array(np_input, dtype=jnp.float32)
-        torch_input = torch.from_numpy(np_input).to(torch.float32)
+        jax_input = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(2, 32, 4, 4), dtype=jnp.float32
+        )
+        torch_input = torch.from_numpy(jax_input.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())
+        chex.assert_trees_all_close(
+            jax_input, torch2jax(torch_input), rtol=1e-5, atol=1e-5
+        )
 
         # Forward pass
         torch_output = torch_upsample(torch_input)
         jax_output = jax_upsample(rearrange(jax_input, "b c h w -> b h w c"))
 
         # Assertions
-        np.testing.assert_allclose(
-            np.array(rearrange(jax_output, "b h w c -> b c h w")),
-            torch_output.detach().numpy(),
+        chex.assert_trees_all_close(
+            rearrange(jax_output, "b h w c -> b c h w"),
+            torch2jax(torch_output),
             rtol=1e-5,
             atol=1e-5,
         )
@@ -476,22 +489,25 @@ def test_encoder(self):
         jax_encoder = port_encoder(jax_encoder=jax_encoder, torch_encoder=torch_encoder)
 
         # Generate random inputs
-        np_input = np.random.randn(1, in_channels, resolution, resolution).astype(
-            np.float32
+        jax_input = jax.random.normal(
+            key=jax.random.PRNGKey(42),
+            shape=(1, in_channels, resolution, resolution),
+            dtype=jnp.float32,
         )
-        jax_input = jnp.array(np_input, dtype=jnp.float32)
-        torch_input = torch.from_numpy(np_input).to(torch.float32)
+        torch_input = torch.from_numpy(jax_input.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())
+        chex.assert_trees_all_close(
+            jax_input, torch2jax(torch_input), rtol=1e-5, atol=1e-5
+        )
 
         # Forward pass
         torch_output = torch_encoder(torch_input)
         jax_output = jax_encoder(rearrange(jax_input, "b c h w -> b h w c"))
 
         # Assertions
-        np.testing.assert_allclose(
-            np.array(rearrange(jax_output, "b h w c -> b c h w")),
-            torch_output.detach().numpy(),
+        chex.assert_trees_all_close(
+            rearrange(jax_output, "b h w c -> b c h w"),
+            torch2jax(torch_output),
             rtol=1e-5,
             atol=1e-5,
         )
@@ -533,22 +549,30 @@ def test_decoder(self):
         jax_decoder = port_decoder(jax_decoder=jax_decoder, torch_decoder=torch_decoder)
 
         # Generate random inputs
-        np_input = np.random.randn(
-            1, z_channels, resolution // len(ch_mult), resolution // len(ch_mult)
-        ).astype(np.float32)
-        jax_input = jnp.array(np_input, dtype=jnp.float32)
-        torch_input = torch.from_numpy(np_input).to(torch.float32)
+        jax_input = jax.random.normal(
+            key=jax.random.PRNGKey(42),
+            shape=(
+                1,
+                z_channels,
+                resolution // len(ch_mult),
+                resolution // len(ch_mult),
+            ),
+            dtype=jnp.float32,
+        )
+        torch_input = torch.from_numpy(jax_input.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())
+        chex.assert_trees_all_close(
+            jax_input, torch2jax(torch_input), rtol=1e-5, atol=1e-5
+        )
 
         # Forward pass
         torch_output = torch_decoder(torch_input)
         jax_output = jax_decoder(rearrange(jax_input, "b c h w -> b h w c"))
 
         # Assertions
-        np.testing.assert_allclose(
-            np.array(rearrange(jax_output, "b h w c -> b c h w")),
-            torch_output.detach().numpy(),
+        chex.assert_trees_all_close(
+            rearrange(jax_output, "b h w c -> b c h w"),
+            torch2jax(torch_output),
             rtol=1e-5,
             atol=1e-5,
         )
@@ -591,8 +615,8 @@ def test_autoencoder(self):
             z_channels=z_channels,
             scale_factor=scale_factor,
             shift_factor=shift_factor,
-            rngs=nnx.Rngs(default=42),
-            param_dtype=jnp.float32,
+            rngs=rngs,
+            param_dtype=param_dtype,
         )
 
         torch_autoencoder = TorchAutoEncoder(params=torch_params)
@@ -608,23 +632,26 @@ def test_autoencoder(self):
             torch_autoencoder=torch_autoencoder,
         )
 
-        # inputs
-        np_input = np.random.randn(1, in_channels, resolution, resolution).astype(
-            np.float32
+        # Generate random inputs
+        jax_input = jax.random.normal(
+            key=jax.random.PRNGKey(42),
+            shape=(1, in_channels, resolution, resolution),
+            dtype=jnp.float32,
         )
-        jax_input = jnp.array(np_input, dtype=jnp.float32)
-        torch_input = torch.from_numpy(np_input).to(torch.float32)
+        torch_input = torch.from_numpy(jax_input.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())
+        chex.assert_trees_all_close(
+            jax_input, torch2jax(torch_input), rtol=1e-5, atol=1e-5
+        )
 
         # forward pass
         torch_output = torch_autoencoder(torch_input)
         jax_output = jax_autoencoder(jax_input)
 
         # Assertions
-        np.testing.assert_allclose(
-            np.array(jax_output),
-            torch_output.detach().numpy(),
+        chex.assert_trees_all_close(
+            jax_output,
+            torch2jax(torch_output),
             rtol=1e-5,
             atol=1e-5,
         )
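Note the `rearrange` calls bracketing every autoencoder forward pass above: the torch reference modules are channels-first (NCHW) while the flax ports are channels-last (NHWC), so inputs are transposed on the way in and outputs on the way out before comparison. A sketch of that convention in isolation:

```python
import jax
import jax.numpy as jnp
from einops import rearrange

# torch modules take (batch, channels, height, width); flax ports take NHWC.
x_nchw = jax.random.normal(key=jax.random.PRNGKey(0), shape=(2, 32, 4, 4))
x_nhwc = rearrange(x_nchw, "b c h w -> b h w c")

# Outputs are mapped back to NCHW before comparing against torch.
assert jnp.array_equal(x_nchw, rearrange(x_nhwc, "b h w c -> b c h w"))
```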
diff --git a/tests/modules/test_layers.py b/tests/modules/test_layers.py
index ef54900..71533c5 100644
--- a/tests/modules/test_layers.py
+++ b/tests/modules/test_layers.py
@@ -1,11 +1,10 @@
+import chex
 import jax
 import jax.numpy as jnp
-import numpy as np
 import torch
 from einops import rearrange, repeat
 from flax import nnx
 from flux.modules.layers import DoubleStreamBlock as TorchDoubleStreamBlock
-from flux.modules.layers import EmbedND as TorchEmbedND
 from flux.modules.layers import MLPEmbedder as TorchMLPEmbedder
 from flux.modules.layers import Modulation as TorchModulation
 from flux.modules.layers import QKNorm as TorchQKNorm
@@ -145,16 +144,19 @@ def port_double_stream_block(
     return jax_double_stream_block
 
 
-class LayersTestCase(np.testing.TestCase):
+class LayersTestCase(chex.TestCase):
     def test_timestep_embedding(self):
         t_vec_torch = torch.tensor([1.0], dtype=torch.float32)
         t_vec_jax = jnp.array([1.0], dtype=jnp.float32)
 
         jax_output = jax_timestep_embedding(t=t_vec_jax, dim=256)
         torch_output = torch_timesetp_embedding(t=t_vec_torch, dim=256)
-        print(jax_output.shape)
-        np.testing.assert_allclose(
-            np.array(jax_output), torch_output.numpy(), atol=1e-4, rtol=1e-4
+
+        chex.assert_trees_all_close(
+            jax_output,
+            torch2jax(torch_output),
+            rtol=1e-4,
+            atol=1e-4,
         )
 
     def test_mlp_embedder(self):
@@ -181,20 +183,26 @@ def test_mlp_embedder(self):
         )
 
         # Generate random inputs
-        np_input = np.random.randn(1, in_dim).astype(np.float32)
-        jax_input = jnp.array(np_input, dtype=jnp.float32)
-        torch_input = torch.from_numpy(np_input).to(torch.float32)
+        jax_input = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(1, in_dim), dtype=jnp.float32
+        )
+        torch_input = torch.from_numpy(jax_input.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())
+        chex.assert_trees_all_close(
+            jax_input,
+            torch2jax(torch_input),
+            rtol=1e-5,
+            atol=1e-5,
+        )
 
         # Forward pass
         torch_output = torch_mlp_embedder(torch_input)
         jax_output = jax_mlp_embedder(jax_input)
 
         # Assertions
-        np.testing.assert_allclose(
-            np.array(jax_output),
-            torch_output.detach().numpy(),
+        chex.assert_trees_all_close(
+            jax_output,
+            torch2jax(torch_output),
             rtol=1e-5,
             atol=1e-5,
         )
@@ -214,20 +222,26 @@ def test_rms_norm(self):
         )
 
         # Generate random inputs
-        np_input = np.random.randn(2, dim).astype(np.float32)
-        jax_input = jnp.array(np_input, dtype=jnp.float32)
-        torch_input = torch.from_numpy(np_input).to(torch.float32)
+        jax_input = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(2, dim), dtype=jnp.float32
+        )
+        torch_input = torch.from_numpy(jax_input.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())
+        chex.assert_trees_all_close(
+            jax_input,
+            torch2jax(torch_input),
+            rtol=1e-5,
+            atol=1e-5,
+        )
 
         # Forward pass
         torch_output = torch_rms_norm(torch_input)
         jax_output = jax_rms_norm(jax_input)
 
         # Assertions
-        np.testing.assert_allclose(
-            np.array(jax_output),
-            torch_output.detach().numpy(),
+        chex.assert_trees_all_close(
+            jax_output,
+            torch2jax(torch_output),
             rtol=1e-5,
             atol=1e-5,
         )
@@ -246,37 +260,49 @@ def test_qknorm(self):
         jax_qknorm = port_qknorm(jax_qknorm=jax_qknorm, torch_qknorm=torch_qknorm)
 
         # Generate random inputs
-        np_q = np.random.randn(2, seq_len, dim).astype(np.float32)
-        np_k = np.random.randn(2, seq_len, dim).astype(np.float32)
-        np_v = np.random.randn(2, seq_len, dim).astype(np.float32)
-
-        jax_q = jnp.array(np_q, dtype=jnp.float32)
-        torch_q = torch.from_numpy(np_q).to(torch.float32)
-
-        jax_k = jnp.array(np_k, dtype=jnp.float32)
-        torch_k = torch.from_numpy(np_k).to(torch.float32)
-
-        jax_v = jnp.array(np_v, dtype=jnp.float32)
-        torch_v = torch.from_numpy(np_v).to(torch.float32)
-
-        np.testing.assert_allclose(np.array(jax_q), torch_q.numpy())
-        np.testing.assert_allclose(np.array(jax_k), torch_k.numpy())
+        jax_q = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(2, seq_len, dim), dtype=jnp.float32
+        )
+        jax_k = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(2, seq_len, dim), dtype=jnp.float32
+        )
+        jax_v = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(2, seq_len, dim), dtype=jnp.float32
+        )
 
-        jax_output = jax_qknorm(q=jax_q, k=jax_k, v=jax_v)
-        torch_output = torch_qknorm(q=torch_q, k=torch_k, v=torch_v)
+        torch_q = torch.from_numpy(jax_q.__array__()).to(torch.float32)
+        torch_k = torch.from_numpy(jax_k.__array__()).to(torch.float32)
+        torch_v = torch.from_numpy(jax_v.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(
-            np.array(jax_output[0]),
-            torch_output[0].detach().numpy(),
+        chex.assert_trees_all_close(
+            jax_q,
+            torch2jax(torch_q),
             rtol=1e-5,
             atol=1e-5,
         )
-        np.testing.assert_allclose(
-            np.array(jax_output[1]),
-            torch_output[1].detach().numpy(),
+        chex.assert_trees_all_close(
+            jax_k,
+            torch2jax(torch_k),
             rtol=1e-5,
             atol=1e-5,
         )
+        chex.assert_trees_all_close(
+            jax_v,
+            torch2jax(torch_v),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+
+        jax_output = jax_qknorm(q=jax_q, k=jax_k, v=jax_v)
+        torch_output = torch_qknorm(q=torch_q, k=torch_k, v=torch_v)
+
+        for i in range(len(jax_output)):
+            chex.assert_trees_all_close(
+                jax_output[i],
+                torch2jax(torch_output[i]),
+                rtol=1e-5,
+                atol=1e-5,
+            )
 
     def test_modulation(self):
         # Initialize the layer
@@ -295,32 +321,38 @@ def test_modulation(self):
         )
 
         # Generate random inputs
-        np_input = np.random.randn(2, dim).astype(np.float32)
-        jax_input = jnp.array(np_input, dtype=jnp.float32)
-        torch_input = torch.from_numpy(np_input).to(torch.float32)
+        jax_input = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(2, dim), dtype=jnp.float32
+        )
+        torch_input = torch.from_numpy(jax_input.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())
+        chex.assert_trees_all_close(
+            jax_input,
+            torch2jax(torch_input),
+            rtol=1e-5,
+            atol=1e-5,
+        )
 
         torch_output = torch_modulation(torch_input)
         jax_output = jax_modulation(jax_input)
 
         # Assertions
-        for i in range(2):
-            np.testing.assert_allclose(
-                np.array(jax_output[i].shift),
-                torch_output[i].shift.detach().numpy(),
+        for i in range(len(jax_output)):
+            chex.assert_trees_all_close(
+                jax_output[i].shift,
+                torch2jax(torch_output[i].shift),
                 rtol=1e-5,
                 atol=1e-5,
             )
-            np.testing.assert_allclose(
-                np.array(jax_output[i].scale),
-                torch_output[i].scale.detach().numpy(),
+            chex.assert_trees_all_close(
+                jax_output[i].scale,
+                torch2jax(torch_output[i].scale),
                 rtol=1e-5,
                 atol=1e-5,
             )
-            np.testing.assert_allclose(
-                np.array(jax_output[i].gate),
-                torch_output[i].gate.detach().numpy(),
+            chex.assert_trees_all_close(
+                jax_output[i].gate,
+                torch2jax(torch_output[i].gate),
                 rtol=1e-5,
                 atol=1e-5,
             )
@@ -356,20 +388,23 @@ def test_double_stream_block(self):
         )
 
         # Create the dummy inputs
-        np_img = np.random.randn(1, 4080, hidden_size).astype(np.float32)
-        np_txt = np.random.randn(1, 256, hidden_size).astype(np.float32)
-        np_vec = np.random.randn(1, hidden_size).astype(np.float32)
-        np_pe = np.random.randn(1, 1, 4336, 64, 2, 2).astype(np.float32)
-
-        jax_img = jnp.array(np_img, dtype=jnp.float32)
-        jax_txt = jnp.array(np_txt, dtype=jnp.float32)
-        jax_vec = jnp.array(np_vec, dtype=jnp.float32)
-        jax_pe = jnp.array(np_pe, dtype=jnp.float32)
+        jax_img = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(1, 4080, hidden_size), dtype=jnp.float32
+        )
+        jax_txt = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(1, 256, hidden_size), dtype=jnp.float32
+        )
+        jax_vec = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(1, hidden_size), dtype=jnp.float32
+        )
+        jax_pe = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(1, 1, 4336, 64, 2, 2), dtype=jnp.float32
+        )
 
-        torch_img = torch.from_numpy(np_img).to(torch.float32)
-        torch_txt = torch.from_numpy(np_txt).to(torch.float32)
-        torch_vec = torch.from_numpy(np_vec).to(torch.float32)
-        torch_pe = torch.from_numpy(np_pe).to(torch.float32)
+        torch_img = torch.from_numpy(jax_img.__array__()).to(torch.float32)
+        torch_txt = torch.from_numpy(jax_txt.__array__()).to(torch.float32)
+        torch_vec = torch.from_numpy(jax_vec.__array__()).to(torch.float32)
+        torch_pe = torch.from_numpy(jax_pe.__array__()).to(torch.float32)
 
         # Forward pass through the DoubleStreamBlock
         torch_img_out, torch_txt_out = torch_double_stream_block(
@@ -387,7 +422,7 @@ def test_double_stream_block(self):
 
     def test_embednd(self):
         # noise
-        bs, c, h, w = (1, 16, 96, 170)
+        bs, _, h, w = (1, 16, 96, 170)
 
         img_ids = jnp.zeros((h // 2, w // 2, 3))
         img_ids = img_ids.at[..., 1].set(jnp.arange(h // 2)[:, None])
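One caveat worth flagging in the ported tests above: `jax.random.PRNGKey(42)` is reused for every tensor in a given test, so for example `jax_q`, `jax_k`, and `jax_v` in `test_qknorm` are identical arrays rather than independent samples. The parity checks still hold, but if independent inputs were intended, the idiomatic fix is `jax.random.split`; a sketch:

```python
import jax

key = jax.random.PRNGKey(42)
q_key, k_key, v_key = jax.random.split(key, num=3)  # three independent subkeys

# Each draw now uses its own subkey, so q, k, and v differ.
q = jax.random.normal(key=q_key, shape=(2, 32, 64))
k = jax.random.normal(key=k_key, shape=(2, 32, 64))
v = jax.random.normal(key=v_key, shape=(2, 32, 64))
```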
diff --git a/tests/test_math.py b/tests/test_math.py
index c10a47a..1a603c6 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -1,20 +1,19 @@
-import unittest
-
+import chex
+import jax
 import jax.numpy as jnp
-import numpy as np
 import torch
 from flux.math import apply_rope as torch_apply_rope
-from flux.math import attention as torch_attention
 from flux.math import rope as torch_rope
 
 from jflux.math import apply_rope as jax_apply_rope
-from jflux.math import attention as jax_attention
 from jflux.math import rope as jax_rope
 
+from .utils import torch2jax
+
 
-class TestMath(np.testing.TestCase):
+class TestMath(chex.TestCase):
     def test_rope(self):
-        B, L, H, D = (
+        B, L, _, D = (
             2,
             4,
             2,
@@ -22,18 +21,24 @@ class TestMath(chex.TestCase):
         )  # Batch size, sequence length, number of heads, embedding dimension
         theta = 10000
 
-        np_positions = np.expand_dims(np.arange(L), 0).repeat(B, 1).astype(np.int32)
-        torch_positions = torch.from_numpy(np_positions).to(torch.int32)
-        jax_positions = jnp.array(np_positions, dtype=jnp.int32)
+        jax_positions = jnp.expand_dims(jnp.arange(L, dtype=jnp.int32), axis=0).repeat(
+            B, axis=1
+        )
+        torch_positions = torch.from_numpy(jax_positions.__array__()).to(torch.int32)
 
-        np.testing.assert_allclose(np.array(jax_positions), torch_positions.numpy())
+        chex.assert_trees_all_close(
+            jax_positions,
+            torch2jax(torch_positions),
+            rtol=1e-5,
+            atol=1e-5,
+        )
 
         torch_pe = torch_rope(pos=torch_positions, dim=D, theta=theta)
         jax_pe = jax_rope(pos=jax_positions, dim=D, theta=theta)
 
-        np.testing.assert_allclose(
-            np.array(jax_pe),
-            torch_pe.numpy(),
+        chex.assert_trees_all_close(
+            jax_pe,
+            torch2jax(torch_pe),
             rtol=1e-5,
             atol=1e-5,
         )
@@ -48,31 +53,44 @@ class TestMath(chex.TestCase):
         theta = 10000
 
         # Inputs
-        np_q = np.random.randn(B, H, L, D).astype(np.float32)
-        np_k = np.random.randn(B, H, L, D).astype(np.float32)
-
-        jax_q = jnp.array(np_q, dtype=jnp.float32)
-        jax_k = jnp.array(np_k, dtype=jnp.float32)
+        jax_q = jax.random.normal(key=jax.random.PRNGKey(42), shape=(B, H, L, D))
+        jax_k = jax.random.normal(key=jax.random.PRNGKey(42), shape=(B, H, L, D))
 
-        torch_q = torch.from_numpy(np_q).to(torch.float32)
-        torch_k = torch.from_numpy(np_k).to(torch.float32)
+        torch_q = torch.from_numpy(jax_q.__array__()).to(torch.float32)
+        torch_k = torch.from_numpy(jax_k.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_q), torch_q.numpy())
-        np.testing.assert_allclose(np.array(jax_k), torch_k.numpy())
+        chex.assert_trees_all_close(
+            jax_q,
+            torch2jax(torch_q),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+        chex.assert_trees_all_close(
+            jax_k,
+            torch2jax(torch_k),
+            rtol=1e-5,
+            atol=1e-5,
+        )
 
         # Position indices (e.g., positions in the sequence)
-        np_positions = np.random.randn(1, L).astype(np.float32)
-        torch_positions = torch.from_numpy(np_positions).to(torch.float32)
-        jax_positions = jnp.array(np_positions, dtype=jnp.float32)
+        jax_positions = jax.random.normal(
+            key=jax.random.PRNGKey(42), shape=(B, L), dtype=jnp.float32
+        )
+        torch_positions = torch.from_numpy(jax_positions.__array__()).to(torch.float32)
 
-        np.testing.assert_allclose(np.array(jax_positions), torch_positions.numpy())
+        chex.assert_trees_all_close(
+            jax_positions,
+            torch2jax(torch_positions),
+            rtol=1e-5,
+            atol=1e-5,
+        )
 
         torch_pe = torch_rope(pos=torch_positions, dim=(3072 // 24), theta=theta)
         jax_pe = jax_rope(pos=jax_positions, dim=(3072 // 24), theta=theta)
 
-        np.testing.assert_allclose(
-            np.array(jax_pe),
-            torch_pe.numpy(),
+        chex.assert_trees_all_close(
+            jax_pe,
+            torch2jax(torch_pe),
             rtol=1e-5,
             atol=1e-5,
         )
@@ -90,51 +108,16 @@ class TestMath(chex.TestCase):
             xq=jax_q, xk=jax_k, freqs_cis=jax_pe
         )
 
-        np.testing.assert_allclose(
-            np.array(jax_q_rotated),
-            torch_q_rotated.numpy(),
+        chex.assert_trees_all_close(
+            jax_q_rotated,
+            torch2jax(torch_q_rotated),
             rtol=1e-5,
             atol=1e-5,
         )
-        np.testing.assert_allclose(
-            np.array(jax_k_rotated),
-            torch_k_rotated.numpy(),
+
+        chex.assert_trees_all_close(
+            jax_k_rotated,
+            torch2jax(torch_k_rotated),
             rtol=1e-5,
             atol=1e-5,
         )
-
-    # def test_attention(self):
-    #     # Generate random inputs
-    #     np_input = np.random.randn(2, 32, 4, 4).astype(np.float32)
-    #     jax_input = jnp.array(np_input, dtype=jnp.float32)
-    #     torch_input = torch.from_numpy(np_input).to(torch.float32)
-
-    #     np.testing.assert_allclose(np.array(jax_input), torch_input.numpy())
-
-    #     # Forward pass
-    #     torch_output = torch_downsample(torch_input)
-    #     jax_output = jax_downsample(rearrange(jax_input, "b c h w -> b h w c"))
-
-    #     # Assertions
-    #     np.testing.assert_allclose(
-    #         np.array(rearrange(jax_output, "b h w c -> b c h w")),
-    #         torch_output.detach().numpy(),
-    #         rtol=1e-5,
-    #         atol=1e-5,
-    #     )
-
-    # @pytest.mark.xfail
-    # def test_attention(self):
-    #     pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0)
-    #     pos = jnp.repeat(pos, self.batch_size, axis=0)
-
-    #     freqs_cis = rope(pos, self.dim, self.theta)
-    #     attention_output = attention(self.q, self.k, self.v, freqs_cis)
-
-    #     expected_shape = (self.batch_size, self.seq_len, self.num_heads * self.dim)
-
-    #     self.assertEqual(
-    #         attention_output.shape,
-    #         expected_shape,
-    #         "attention function output shape is incorrect",
-    #     )
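The commented-out `test_attention` blocks deleted at the end of `tests/test_math.py` were copy-paste scaffolding (one even forwarded through `torch_downsample`). If the attention parity test is ever revived, a shape-level smoke test adapted from the second commented block might look like the following sketch, assuming `jflux.math.attention` takes `(q, k, v, pe)` as the commented code implies:

```python
import jax
import jax.numpy as jnp

from jflux.math import attention, rope

B, H, L, D = 1, 2, 4, 16  # batch, heads, sequence length, head dim

key = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(key, num=3)
q = jax.random.normal(q_key, shape=(B, H, L, D))
k = jax.random.normal(k_key, shape=(B, H, L, D))
v = jax.random.normal(v_key, shape=(B, H, L, D))

pos = jnp.repeat(jnp.expand_dims(jnp.arange(L), axis=0), B, axis=0)
pe = rope(pos, D, 10000)

out = attention(q, k, v, pe)
assert out.shape == (B, L, H * D)  # heads folded back into the feature dim
```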
diff --git a/tests/test_model.py b/tests/test_model.py
index 78cc247..18343ec 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,4 +1,4 @@
-import numpy as np
+import chex
 import pytest
 from flax import nnx
 from jax import numpy as jnp
@@ -6,7 +6,7 @@
 from jflux.model import Flux, FluxParams
 
 
-class ModelTestCase(np.testing.TestCase):
+class ModelTestCase(chex.TestCase):
     @pytest.mark.skip
     def test_model(self):
         # Initialize
diff --git a/tests/test_sampling.py b/tests/test_sampling.py
index 592c525..5f901c1 100644
--- a/tests/test_sampling.py
+++ b/tests/test_sampling.py
@@ -1,6 +1,5 @@
 import chex
 import jax
-import numpy as np
 import torch
 
 from flux.sampling import get_noise as torch_get_noise
diff --git a/tests/utils.py b/tests/utils.py
index 583ab1d..25643b1 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -2,6 +2,7 @@
 import torch
 from chex import Array
 
+
 __all__ = ["torch2jax"]
diff --git a/uv.lock b/uv.lock
index 0930d42..efc549b 100644
--- a/uv.lock
+++ b/uv.lock
@@ -331,7 +331,7 @@ wheels = [
 [[package]]
 name = "jflux"
 version = "0.1.0"
-source = { virtual = "." }
+source = { editable = "." }
 dependencies = [
     { name = "einops" },
     { name = "fire" },
@@ -355,7 +355,6 @@ requires-dist = [
     { name = "fire", specifier = ">=0.6.0" },
     { name = "flax", specifier = ">=0.9.0" },
     { name = "jax", specifier = ">=0.4.31" },
-    { name = "jflux", virtual = "." },
     { name = "mypy", specifier = ">=1.11.2" },
    { name = "pillow", specifier = ">=10.4.0" },
    { name = "ruff", specifier = ">=0.6.3" },