Skip to content

Commit

Permalink
start porting
Browse files: browse the repository at this point in the history
  • Loading branch information
ariG23498 committed Sep 3, 2024
1 parent 1a00d0c commit 3dfd926
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 18 deletions.
1 change: 0 additions & 1 deletion src/jflux/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from jflux.util import (
configs,
embed_watermark,
load_ae,
load_clip,
load_flow_model,
Expand Down
2 changes: 1 addition & 1 deletion src/jflux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import Tensor, nn

from flux.modules.layers import (
from jflux.modules.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
Expand Down
4 changes: 2 additions & 2 deletions src/jflux/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __call__(self, x: Array) -> Array:
return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(jnp.nnx.Module):
class RMSNorm(nnx.Module):
def __init__(self, dim: int):
super().__init__()
self.scale = nnx.Param(jnp.ones(dim))
Expand All @@ -71,7 +71,7 @@ def __call__(self, x: Array):
return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(jnp.nnx.Module):
class QKNorm(nnx.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = RMSNorm(dim)
Expand Down
8 changes: 4 additions & 4 deletions src/jflux/modules/conditioner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from torch import Tensor, nn
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from transformers import FlaxCLIPTextModel, CLIPTokenizer, FlaxT5EncoderModel, T5Tokenizer


class HFEmbedder(nn.Module):
Expand All @@ -13,15 +13,15 @@ def __init__(self, version: str, max_length: int, **hf_kwargs):
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
version, max_length=max_length
)
self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained(
version, **hf_kwargs
)
else:
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
version, max_length=max_length
)
self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
version, **hf_kwargs
self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained(
version, from_pt=True, **hf_kwargs
)

self.hf_module = self.hf_module.eval().requires_grad_(False)
Expand Down
20 changes: 10 additions & 10 deletions src/jflux/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from dataclasses import dataclass

import torch
from jax import numpy as jnp
from einops import rearrange
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_sft
Expand Down Expand Up @@ -102,7 +102,7 @@ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:


def load_flow_model(
name: str, device: str | torch.device = "cuda", hf_download: bool = True
name: str, device: str | jnp.device = "cuda", hf_download: bool = True
):
# Loading Flux
print("Init model")
Expand All @@ -115,8 +115,8 @@ def load_flow_model(
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

with torch.device("meta" if ckpt_path is not None else device):
model = Flux(configs[name].params).to(torch.bfloat16)
with jnp.device("meta" if ckpt_path is not None else device):
model = Flux(configs[name].params).to(jnp.bfloat16)

if ckpt_path is not None:
print("Loading checkpoint")
Expand All @@ -127,21 +127,21 @@ def load_flow_model(
return model


def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
def load_t5(device: str | jnp.device = "cuda", max_length: int = 512) -> HFEmbedder:
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
return HFEmbedder(
"google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16
"google/t5-v1_1-xxl", max_length=max_length, dtype=jnp.bfloat16
).to(device)


def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
def load_clip(device: str | jnp.device = "cuda") -> HFEmbedder:
return HFEmbedder(
"openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16
"openai/clip-vit-large-patch14", max_length=77, torch_dtype=jnp.bfloat16
).to(device)


def load_ae(
name: str, device: str | torch.device = "cuda", hf_download: bool = True
name: str, device: str | jnp.device = "cuda", hf_download: bool = True
) -> AutoEncoder:
ckpt_path = configs[name].ae_path
if (
Expand All @@ -154,7 +154,7 @@ def load_ae(

# Loading the autoencoder
print("Init AE")
with torch.device("meta" if ckpt_path is not None else device):
with jnp.device("meta" if ckpt_path is not None else device):
ae = AutoEncoder(configs[name].ae_params)

if ckpt_path is not None:
Expand Down

0 comments on commit 3dfd926

Please sign in to comment.