From 6115c44404785973f33e3a78b6a5d10d0f7aca8b Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Fri, 11 Oct 2024 05:07:54 +0000 Subject: [PATCH] chore: t5 and clip are flax models now --- jflux/cli.py | 3 ++- jflux/modules/conditioner.py | 34 ++++++++++++++++------------------ jflux/util.py | 15 ++++++++------- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/jflux/cli.py b/jflux/cli.py index 0bc9c3c..dbdad05 100644 --- a/jflux/cli.py +++ b/jflux/cli.py @@ -3,7 +3,6 @@ import time from dataclasses import dataclass from glob import iglob -import torch import jax import jax.numpy as jnp import numpy as np @@ -14,11 +13,13 @@ from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack from jflux.util import configs, load_ae, load_clip, load_flow_model, load_t5 + def torch2jax(tensor): tensor = tensor.float().numpy() tensor = jnp.array(tensor, dtype=jnp.bfloat16) return tensor + @dataclass class SamplingOptions: prompt: str diff --git a/jflux/modules/conditioner.py b/jflux/modules/conditioner.py index 2c29675..fdba287 100644 --- a/jflux/modules/conditioner.py +++ b/jflux/modules/conditioner.py @@ -1,14 +1,16 @@ -# Note: This is a torch module not a Jax module -import jax.numpy as jnp +from flax import nnx from chex import Array -from torch import nn -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer +from transformers import ( + FlaxCLIPTextModel, + CLIPTokenizer, + FlaxT5EncoderModel, + T5Tokenizer, +) -class HFEmbedder(nn.Module): +class HFEmbedder(nnx.Module): def __init__(self, version: str, max_length: int, **hf_kwargs): - super().__init__() - self.is_clip = version.startswith("openai") + self.is_clip = version.split("/")[1].startswith("clip") self.max_length = max_length self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" @@ -16,20 +18,18 @@ def __init__(self, version: str, max_length: int, **hf_kwargs): self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained( version, max_length=max_length ) - self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( + self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained( version, **hf_kwargs ) else: self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained( version, max_length=max_length ) - self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained( + self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained( version, **hf_kwargs ) - self.hf_module = self.hf_module.eval().requires_grad_(False) - - def forward(self, text: list[str]) -> Array: + def __call__(self, text: list[str]) -> Array: batch_encoding = self.tokenizer( text, truncation=True, @@ -37,15 +37,13 @@ def forward(self, text: list[str]) -> Array: return_length=False, return_overflowing_tokens=False, padding="max_length", - return_tensors="pt", + return_tensors="np", ) outputs = self.hf_module( - input_ids=batch_encoding["input_ids"].to(self.hf_module.device), + input_ids=batch_encoding["input_ids"], attention_mask=None, output_hidden_states=False, ) - torch_outputs = outputs[self.output_key] - - jax_outputs = jnp.array(torch_outputs.cpu().float(), dtype=jnp.bfloat16) - return jax_outputs + outputs = outputs[self.output_key] + return outputs diff --git a/jflux/util.py b/jflux/util.py index c4dd169..5938093 100644 --- a/jflux/util.py +++ b/jflux/util.py @@ -1,7 +1,7 @@ import os from dataclasses import dataclass -import torch # need for t5 and clip +import torch # need for torch 2 jax from flax import nnx from huggingface_hub import hf_hub_download import jax @@ -13,6 +13,7 @@ from jflux.modules.conditioner import HFEmbedder from jflux.port import port_autoencoder, port_flux + def torch2jax(torch_tensor): intermediate_tensor = torch_tensor.to(torch.float32) jax_tensor = jnp.array(intermediate_tensor, dtype=jnp.bfloat16) @@ -144,17 +145,17 @@ def load_flow_model(name: str, hf_download: bool = True) -> Flux: def load_t5() -> HFEmbedder: - device = "cuda" if torch.cuda.is_available() else "cpu" return HFEmbedder( - "google/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16 - ).to(device) + "ariG23498/t5-v1-1-xxl-flax", + max_length=512, + ) def load_clip() -> HFEmbedder: - device = "cuda" if torch.cuda.is_available() else "cpu" return HFEmbedder( - "openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16 - ).to(device) + "ariG23498/clip-vit-large-patch14-text-flax", + max_length=77, + ) def load_ae(name: str, hf_download: bool = True) -> AutoEncoder: