
chore: t5 and clip are flax models now
ariG23498 authored and SauravMaheshkar committed Oct 11, 2024
1 parent 1fc186f · commit 6115c44
Showing 3 changed files with 26 additions and 26 deletions.
jflux/cli.py (3 changes: 2 additions & 1 deletion)
@@ -3,7 +3,6 @@
 import time
 from dataclasses import dataclass
 from glob import iglob
-import torch
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -14,11 +13,13 @@
 from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
 from jflux.util import configs, load_ae, load_clip, load_flow_model, load_t5
 
+
 def torch2jax(tensor):
     tensor = tensor.float().numpy()
     tensor = jnp.array(tensor, dtype=jnp.bfloat16)
     return tensor
 
+
 @dataclass
 class SamplingOptions:
     prompt: str
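
For context, the torch2jax helper that stays behind in jflux/cli.py is the remaining bridge from torch tensors into JAX. Below is a minimal sketch of its behaviour, assuming torch and jax are installed; the sample tensor and the print call are illustrative, not part of the commit.

# Illustrative sketch, not part of the commit: behaviour of the torch2jax
# helper kept in jflux/cli.py. Assumes torch and jax are installed.
import torch
import jax.numpy as jnp

def torch2jax(tensor):
    tensor = tensor.float().numpy()                 # torch tensor -> float32 numpy array
    tensor = jnp.array(tensor, dtype=jnp.bfloat16)  # numpy -> bf16 JAX array
    return tensor

x = torch.randn(2, 3, dtype=torch.float16)
y = torch2jax(x)
print(y.dtype, y.shape)  # bfloat16 (2, 3)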
jflux/modules/conditioner.py (34 changes: 16 additions & 18 deletions)
@@ -1,51 +1,49 @@
-# Note: This is a torch module not a Jax module
-import jax.numpy as jnp
+from flax import nnx
 from chex import Array
-from torch import nn
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
+from transformers import (
+    FlaxCLIPTextModel,
+    CLIPTokenizer,
+    FlaxT5EncoderModel,
+    T5Tokenizer,
+)
 
 
-class HFEmbedder(nn.Module):
+class HFEmbedder(nnx.Module):
     def __init__(self, version: str, max_length: int, **hf_kwargs):
-        super().__init__()
-        self.is_clip = version.startswith("openai")
+        self.is_clip = version.split("/")[1].startswith("clip")
         self.max_length = max_length
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
 
         if self.is_clip:
             self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                 version, max_length=max_length
             )
-            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
+            self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained(
                 version, **hf_kwargs
             )
         else:
             self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                 version, max_length=max_length
             )
-            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
+            self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained(
                 version, **hf_kwargs
             )
 
-        self.hf_module = self.hf_module.eval().requires_grad_(False)
-
-    def forward(self, text: list[str]) -> Array:
+    def __call__(self, text: list[str]) -> Array:
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
             max_length=self.max_length,
             return_length=False,
             return_overflowing_tokens=False,
             padding="max_length",
-            return_tensors="pt",
+            return_tensors="np",
         )
 
         outputs = self.hf_module(
-            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+            input_ids=batch_encoding["input_ids"],
             attention_mask=None,
             output_hidden_states=False,
         )
-        torch_outputs = outputs[self.output_key]
-
-        jax_outputs = jnp.array(torch_outputs.cpu().float(), dtype=jnp.bfloat16)
-        return jax_outputs
+        outputs = outputs[self.output_key]
+        return outputs
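
With the switch to nnx.Module and __call__, the embedder now reads as a plain Flax module that tokenizes to numpy and returns JAX arrays directly, with no torch round-trip. A hedged usage sketch follows: the checkpoint names come from jflux/util.py below, the prompt and shape comments are illustrative, and actually loading the weights needs network access and substantial memory.

# Illustrative usage of the Flax-backed HFEmbedder; not part of the commit.
from jflux.modules.conditioner import HFEmbedder

clip = HFEmbedder("ariG23498/clip-vit-large-patch14-text-flax", max_length=77)
t5 = HFEmbedder("ariG23498/t5-v1-1-xxl-flax", max_length=512)

vec = clip(["a photo of a forest"])  # pooler_output: pooled embedding, shape (1, d)
txt = t5(["a photo of a forest"])    # last_hidden_state: per-token embeddings, shape (1, 512, d)

Note the new routing check: is_clip now keys on whether the repo name after the "/" starts with "clip", rather than on the "openai" org prefix, so the ported checkpoints hosted under a different namespace still take the CLIP path.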
jflux/util.py (15 changes: 8 additions & 7 deletions)
@@ -1,7 +1,7 @@
 import os
 from dataclasses import dataclass
 
-import torch  # need for t5 and clip
+import torch  # need for torch 2 jax
 from flax import nnx
 from huggingface_hub import hf_hub_download
 import jax
@@ -13,6 +13,7 @@
 from jflux.modules.conditioner import HFEmbedder
 from jflux.port import port_autoencoder, port_flux
 
+
 def torch2jax(torch_tensor):
     intermediate_tensor = torch_tensor.to(torch.float32)
     jax_tensor = jnp.array(intermediate_tensor, dtype=jnp.bfloat16)
@@ -144,17 +145,17 @@ def load_flow_model(name: str, hf_download: bool = True) -> Flux:
 
 
 def load_t5() -> HFEmbedder:
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     return HFEmbedder(
-        "google/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16
-    ).to(device)
+        "ariG23498/t5-v1-1-xxl-flax",
+        max_length=512,
+    )
 
 
 def load_clip() -> HFEmbedder:
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     return HFEmbedder(
-        "openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16
-    ).to(device)
+        "ariG23498/clip-vit-large-patch14-text-flax",
+        max_length=77,
+    )
 
 
 def load_ae(name: str, hf_download: bool = True) -> AutoEncoder:
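
A sketch of how the simplified loaders are consumed after this change: there is no device plumbing or .to(device), since the Flax models manage placement through JAX. The prompt and the shape comments are illustrative, not from the commit.

# Illustrative only; mirrors the loaders defined above.
from jflux.util import load_clip, load_t5

t5 = load_t5()      # FlaxT5EncoderModel under the hood, max_length=512
clip = load_clip()  # FlaxCLIPTextModel under the hood, max_length=77

prompt = ["a watercolor fox"]
txt = t5(prompt)    # per-token embeddings for the flow model
vec = clip(prompt)  # pooled embedding for conditioning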
