-
Notifications
You must be signed in to change notification settings - Fork 125
/
utils.py
36 lines (29 loc) · 1.33 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from modules.autoencoder import AutoEncoder, AutoEncoderParams
from modules.conditioner import HFEmbedder
from safetensors.torch import load_file as load_sft
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    """Create the ``google/t5-v1_1-xxl`` embedder in bfloat16 and move it to ``device``.

    Args:
        device: Target device passed to ``.to()`` on the constructed embedder.
        max_length: Forwarded to ``HFEmbedder`` as its ``max_length`` argument.

    Returns:
        The result of moving the constructed ``HFEmbedder`` to ``device``.
    """
    # Max lengths of 64, 128, 256 and 512 should work (if your sequence is short enough).
    t5_embedder = HFEmbedder(
        "google/t5-v1_1-xxl",
        max_length=max_length,
        torch_dtype=torch.bfloat16,
    )
    return t5_embedder.to(device)
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    """Create the ``openai/clip-vit-large-patch14`` embedder in bfloat16 on ``device``.

    The embedder is always built with ``max_length=77``.

    Args:
        device: Target device passed to ``.to()`` on the constructed embedder.

    Returns:
        The result of moving the constructed ``HFEmbedder`` to ``device``.
    """
    clip_embedder = HFEmbedder(
        "openai/clip-vit-large-patch14",
        max_length=77,
        torch_dtype=torch.bfloat16,
    )
    return clip_embedder.to(device)
def load_clap(device: str | torch.device = "cuda", max_length: int = 256) -> HFEmbedder:
    """Create the ``laion/larger_clap_music`` embedder in bfloat16 on ``device``.

    Args:
        device: Target device passed to ``.to()`` on the constructed embedder.
        max_length: Forwarded to ``HFEmbedder`` as its ``max_length`` argument.
            Defaults to 256, the value this loader has always used in
            practice, so calling it without ``max_length`` is unchanged.

    Returns:
        The result of moving the constructed ``HFEmbedder`` to ``device``.
    """
    # Honour the caller's max_length instead of silently discarding it.
    # NOTE(review): values other than 256 are untested here — confirm that
    # HFEmbedder and the CLAP text encoder accept the requested length.
    clap_embedder = HFEmbedder(
        "laion/larger_clap_music",
        max_length=max_length,
        torch_dtype=torch.bfloat16,
    )
    return clap_embedder.to(device)
def load_ae(ckpt_path, device: str | torch.device = "cuda") -> AutoEncoder:
    """Build the autoencoder and load its weights from a safetensors checkpoint.

    The checkpoint is applied non-strictly: parameter names that are missing
    from the checkpoint, or present in it but unknown to the model, do not
    raise. They are reported through a ``logging`` warning so that a
    mismatched checkpoint does not go unnoticed.

    Args:
        ckpt_path: Path of the safetensors file holding the autoencoder
            weights; passed straight to ``safetensors.torch.load_file``.
        device: Device the autoencoder is moved to after the weights are
            loaded.

    Returns:
        The ``AutoEncoder`` instance with the checkpoint applied, on ``device``.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    ae_params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )

    # Loading the autoencoder
    ae = AutoEncoder(ae_params)
    sd = load_sft(ckpt_path)
    # strict=False tolerates key mismatches; assign=True makes the module take
    # over the checkpoint tensors rather than copying into its own.
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)

    logger = logging.getLogger(__name__)
    if missing:
        logger.warning(
            "Autoencoder checkpoint %s is missing %d key(s): %s",
            ckpt_path,
            len(missing),
            ", ".join(missing),
        )
    if unexpected:
        logger.warning(
            "Autoencoder checkpoint %s has %d unexpected key(s): %s",
            ckpt_path,
            len(unexpected),
            ", ".join(unexpected),
        )

    ae.to(device)
    return ae