diff --git a/src/jflux/cli.py b/src/jflux/cli.py
index 2877c00..0643ef3 100644
--- a/src/jflux/cli.py
+++ b/src/jflux/cli.py
@@ -13,7 +13,6 @@ from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
 from jflux.util import (
     configs,
-    embed_watermark,
     load_ae,
     load_clip,
     load_flow_model,
diff --git a/src/jflux/model.py b/src/jflux/model.py
index c7f2b35..51f0e5e 100644
--- a/src/jflux/model.py
+++ b/src/jflux/model.py
@@ -3,7 +3,7 @@
 import torch
 from torch import Tensor, nn
 
-from flux.modules.layers import (
+from jflux.modules.layers import (
     DoubleStreamBlock,
     EmbedND,
     LastLayer,
diff --git a/src/jflux/modules/__init__.py b/src/jflux/modules/__init__.py
index e51c541..603f06d 100644
--- a/src/jflux/modules/__init__.py
+++ b/src/jflux/modules/__init__.py
@@ -59,7 +59,7 @@ def __call__(self, x: Array) -> Array:
         return self.out_layer(self.silu(self.in_layer(x)))
 
 
-class RMSNorm(jnp.nnx.Module):
+class RMSNorm(nnx.Module):
     def __init__(self, dim: int):
         super().__init__()
         self.scale = nnx.Param(jnp.ones(dim))
@@ -71,7 +71,7 @@ def __call__(self, x: Array):
         return (x * rrms).to(dtype=x_dtype) * self.scale
 
 
-class QKNorm(jnp.nnx.Module):
+class QKNorm(nnx.Module):
     def __init__(self, dim: int):
         super().__init__()
         self.query_norm = RMSNorm(dim)
diff --git a/src/jflux/modules/conditioner.py b/src/jflux/modules/conditioner.py
index c5c3e16..04ef5ee 100644
--- a/src/jflux/modules/conditioner.py
+++ b/src/jflux/modules/conditioner.py
@@ -1,5 +1,5 @@
 from torch import Tensor, nn
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
+from transformers import FlaxCLIPTextModel, CLIPTokenizer, FlaxT5EncoderModel, T5Tokenizer
 
 
 class HFEmbedder(nn.Module):
@@ -13,15 +13,15 @@ def __init__(self, version: str, max_length: int, **hf_kwargs):
             self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                 version, max_length=max_length
             )
-            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
+            self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained(
                 version, **hf_kwargs
             )
         else:
             self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                 version, max_length=max_length
             )
-            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
-                version, **hf_kwargs
+            self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained(
+                version, from_pt=True, **hf_kwargs
             )
 
         self.hf_module = self.hf_module.eval().requires_grad_(False)
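
Note on src/jflux/modules/conditioner.py: `FlaxCLIPTextModel`, `FlaxT5EncoderModel`, and the `from_pt=True` flag are real transformers APIs (`from_pt=True` converts the PyTorch checkpoint at load time). The unchanged context line `self.hf_module.eval().requires_grad_(False)` will still need a follow-up, since Flax models are stateless and define neither method. A minimal, self-contained sketch of the ported T5 encoder path, assuming only the model name and max_length already used in util.py; the prompt and the printed shape are illustrative:

    from jax import numpy as jnp
    from transformers import FlaxT5EncoderModel, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl", max_length=512)
    encoder = FlaxT5EncoderModel.from_pretrained(
        "google/t5-v1_1-xxl", from_pt=True, dtype=jnp.bfloat16
    )

    batch = tokenizer(
        ["a photo of a forest"],
        truncation=True,
        max_length=512,
        padding="max_length",
        return_tensors="np",  # Flax models consume numpy/JAX arrays, not torch tensors
    )
    out = encoder(input_ids=batch["input_ids"])
    print(out.last_hidden_state.shape)  # (1, 512, 4096) for t5-v1_1-xxl
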
diff --git a/src/jflux/util.py b/src/jflux/util.py
index 366d68a..fcf53e3 100644
--- a/src/jflux/util.py
+++ b/src/jflux/util.py
@@ -1,7 +1,7 @@
 import os
 from dataclasses import dataclass
 
-import torch
+from jax import numpy as jnp
 from einops import rearrange
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file as load_sft
@@ -102,7 +102,7 @@ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
 
 
 def load_flow_model(
-    name: str, device: str | torch.device = "cuda", hf_download: bool = True
+    name: str, device: str | jnp.device = "cuda", hf_download: bool = True
 ):
     # Loading Flux
     print("Init model")
@@ -115,8 +115,8 @@
     ):
         ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
 
-    with torch.device("meta" if ckpt_path is not None else device):
-        model = Flux(configs[name].params).to(torch.bfloat16)
+    with jnp.device("meta" if ckpt_path is not None else device):
+        model = Flux(configs[name].params).to(jnp.bfloat16)
 
     if ckpt_path is not None:
         print("Loading checkpoint")
@@ -127,21 +127,21 @@
     return model
 
 
-def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
+def load_t5(device: str | jnp.device = "cuda", max_length: int = 512) -> HFEmbedder:
     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
     return HFEmbedder(
-        "google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16
+        "google/t5-v1_1-xxl", max_length=max_length, dtype=jnp.bfloat16
     ).to(device)
 
 
-def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
+def load_clip(device: str | jnp.device = "cuda") -> HFEmbedder:
     return HFEmbedder(
-        "openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16
+        "openai/clip-vit-large-patch14", max_length=77, dtype=jnp.bfloat16
     ).to(device)
 
 
 def load_ae(
-    name: str, device: str | torch.device = "cuda", hf_download: bool = True
+    name: str, device: str | jnp.device = "cuda", hf_download: bool = True
 ) -> AutoEncoder:
     ckpt_path = configs[name].ae_path
     if (
@@ -154,7 +154,7 @@
 
     # Loading the autoencoder
     print("Init AE")
-    with torch.device("meta" if ckpt_path is not None else device):
+    with jnp.device("meta" if ckpt_path is not None else device):
         ae = AutoEncoder(configs[name].ae_params)
 
     if ckpt_path is not None:
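
Note on src/jflux/util.py: `jax.numpy` defines neither a `device` type for the `str | jnp.device` annotations nor a `jnp.device(...)` context manager, and JAX has no counterpart to torch's "meta" device, so those lines are placeholders that do not run as written. A minimal sketch of one possible substitute, assuming JAX platform names such as "gpu" rather than "cuda"; `jax.devices`, `jax.Device`, and `jax.default_device` are the relevant JAX APIs, while `Flux`, `configs`, and `name` come from the surrounding code:

    import jax

    def resolve_device(device: str | jax.Device = "gpu") -> jax.Device:
        # Accept a platform name ("cpu", "gpu", "tpu") or a concrete jax.Device.
        return jax.devices(device)[0] if isinstance(device, str) else device

    # Hypothetical replacement for the torch.device("meta") block:
    # with jax.default_device(resolve_device(device)):
    #     model = Flux(configs[name].params)
    # JAX has no "meta" device; parameters materialize on the default device at
    # init, so deferred (shape-only) initialization needs a different mechanism,
    # e.g. jax.eval_shape over the init function.
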