Merge pull request #48 from laksjdjf/speedtest
speedtest and others
laksjdjf authored Apr 17, 2024
2 parents 65c46f4 + 65028e3 commit 683d977
Showing 5 changed files with 196 additions and 6 deletions.
45 changes: 45 additions & 0 deletions modules/dummy/dummy_dataset.py
@@ -0,0 +1,45 @@
from torch.utils.data import Dataset
import torch

class DummyDataset(Dataset):
    def __init__(
        self,
        text_model,
        batch_size = 1,
        size = (512, 512),
        num_batch = 100,
        cache_latent = False,
        cache_text_emb = False,
    ):
        self.batch_size = batch_size
        self.width, self.height = size
        self.num_batch = num_batch
        self.cache_latent = cache_latent
        self.cache_text_emb = cache_text_emb

        self.sdxl = text_model.sdxl

    def __len__(self):
        return self.num_batch

    def __getitem__(self, i):
        batch = {}
        if self.cache_latent:
            batch["latents"] = torch.randn(self.batch_size, 4, self.height // 8, self.width // 8)
        else:
            batch["images"] = torch.randn(self.batch_size, 3, self.height, self.width)

        if self.sdxl:
            size_list = [self.height, self.width, 0, 0, self.height, self.width]
            batch["size_condition"] = torch.tensor(size_list).repeat(self.batch_size, 1)

        if self.cache_text_emb:
            dim = 2048 if self.sdxl else 768  # sd2? don't know about that one
            batch["encoder_hidden_states"] = torch.randn(self.batch_size, 77, dim)
            if self.sdxl:
                batch["pooled_outputs"] = torch.randn(self.batch_size, dim)
        else:
            batch["captions"] = ["" for _ in range(self.batch_size)]

        return batch
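
For reference, a minimal sketch of driving this dataset on its own (the SimpleNamespace stand-in for text_model is hypothetical; the dataset only reads its sdxl attribute):

from types import SimpleNamespace
from torch.utils.data import DataLoader

from modules.dummy.dummy_dataset import DummyDataset

# Hypothetical stand-in: DummyDataset only touches text_model.sdxl.
text_model = SimpleNamespace(sdxl=True)

dataset = DummyDataset(text_model, batch_size=4, size=(1024, 1024),
                       cache_latent=True, cache_text_emb=True)
loader = DataLoader(dataset, batch_size=None)  # each item is already a full batch

batch = next(iter(loader))
print(batch["latents"].shape)                # torch.Size([4, 4, 128, 128])
print(batch["encoder_hidden_states"].shape)  # torch.Size([4, 77, 2048])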
7 changes: 4 additions & 3 deletions modules/trainer.py
@@ -39,6 +39,7 @@ def __init__(self, config, diffusion:DiffusionModel, text_model:TextModel, vae:A
        self.diffusers_scheduler = scheduler  # used only when saving the model
        self.scheduler = BaseScheduler(scheduler.config.prediction_type == "v_prediction")
        self.sdxl = text_model.sdxl
+       self.scaling_factor = 0.13025 if self.sdxl else 0.18215

        if config is not None and config.merging_loras:
            for lora in config.merging_loras:
@@ -218,10 +219,10 @@ def prepare_lr_scheduler(self, total_steps):

    def loss(self, batch):
        if "latents" in batch:
-           latents = batch["latents"].to(self.device) * self.vae.scaling_factor
+           latents = batch["latents"].to(self.device) * self.scaling_factor
        else:
            with torch.autocast("cuda", dtype=self.vae_dtype), torch.no_grad():
-               latents = self.vae.encode(batch['images'].to(self.device)).latent_dist.sample() * self.vae.scaling_factor
+               latents = self.vae.encode(batch['images'].to(self.device)).latent_dist.sample() * self.scaling_factor

        self.batch_size = latents.shape[0]  # also used in the step method
@@ -316,7 +317,7 @@ def sample(
            latents = torch.zeros(batch_size, 4, height // 8, width // 8, device=self.device, dtype=self.autocast_dtype)
        else:
            with torch.autocast("cuda", dtype=self.vae_dtype):
-               latents = self.encode_latents(images) * self.vae.scaling_factor
+               latents = self.encode_latents(images) * self.scaling_factor
            latents.to(dtype=self.autocast_dtype)

        noise = torch.randn_like(latents)
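
For context, the hard-coded constants match the scaling_factor the stock VAEs ship with, so behaviour is unchanged for the standard checkpoints while no longer depending on the loaded VAE's config. A quick check against diffusers' AutoencoderKL (real repo id, requires a download):

from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
print(vae.config.scaling_factor)  # 0.13025 for the SDXL VAE; SD1.x/2.x VAEs use 0.18215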
20 changes: 17 additions & 3 deletions networks/lora.py
@@ -36,9 +36,10 @@ def get_weight(self, multiplier=None):

class LoRAModule(BaseModule):

-   def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, state_dict=None, rank=4, alpha=1):
+   def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, state_dict=None, rank=4, alpha=1, forward_mode="sequential"):
        super().__init__()
        self.lora_name = lora_name
+       self.forward_mode = forward_mode

        if state_dict is not None:
            up_weight = state_dict[f"{lora_name}.lora_up.weight"]
@@ -55,6 +56,9 @@ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, state
            self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
            self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)

+           self.functional = torch.nn.functional.linear
+           self.functional_args = {}
+
        elif 'Conv' in org_module.__class__.__name__:  # ["Conv2d", "LoRACompatibleConv"]
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
@@ -70,6 +74,12 @@ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, state
                in_dim, self.rank, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(
                self.rank, out_dim, (1, 1), (1, 1), bias=False)
+
+           self.functional = torch.nn.functional.conv2d
+           self.functional_args = {
+               "stride": stride,
+               "padding": padding,
+           }

        self.shape = org_module.weight.shape

@@ -108,5 +118,9 @@ def lora_forward(self, x):
    def forward(self, x, scale = None):
        if self.multiplier == 0.0:
            return self.org_forward(x)
-       else:
-           return self.org_forward(x) + self.lora_forward(x)
+       if self.forward_mode == "sequential":
+           return self.org_forward(x) + self.lora_forward(x)
+       elif self.forward_mode == "merge":
+           weight = self.org_module[0].state_dict()["weight"]
+           bias = None if "bias" not in self.org_module[0].state_dict() else self.org_module[0].state_dict()["bias"]
+           return self.functional(x, weight + self.get_weight(), bias, **self.functional_args)
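
The "merge" mode folds the low-rank update into the base weight and runs a single matmul/conv instead of two forward passes; for a Linear layer the two modes should agree to numerical precision. A self-contained sketch of that identity (plain tensors, not the repository's classes; scale stands in for multiplier * alpha / rank):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
org = torch.nn.Linear(64, 32)
lora_down = torch.nn.Linear(64, 4, bias=False)
lora_up = torch.nn.Linear(4, 32, bias=False)
scale = 1.0  # stands in for multiplier * alpha / rank

x = torch.randn(8, 64)
sequential = org(x) + lora_up(lora_down(x)) * scale    # forward_mode="sequential"
delta_w = lora_up.weight @ lora_down.weight * scale    # merged low-rank update
merged = F.linear(x, org.weight + delta_w, org.bias)   # forward_mode="merge"
print(torch.allclose(sequential, merged, atol=1e-6))   # True

The conv case is analogous: F.conv2d with the original stride and padding, which is why the module records them in functional_args.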
5 changes: 5 additions & 0 deletions networks/manager.py
@@ -59,6 +59,11 @@ def __init__(
        te2_keys = [key for key in keys if LORA_PREFIX_TEXT_ENCODER_2 in key]

        self.module = get_attr_from_config(module)
+
+       if hasattr(conv_module_args, "same") and conv_module_args.same:
+           conv_module_args = module_args
+       if hasattr(text_module_args, "same") and text_module_args.same:
+           text_module_args = module_args

        # build the LoRA modules for the unet
        self.unet_modules = []
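
A sketch of the convenience this enables in the network config (shape and key names are illustrative, mirroring the constructor arguments; same: true makes the conv and text-encoder LoRAs reuse the unet arguments):

from omegaconf import OmegaConf

# Hypothetical config shape; only the "same" flag is what the diff adds.
network = OmegaConf.create({
    "module_args": {"rank": 16, "alpha": 8},
    "conv_module_args": {"same": True},   # reuse module_args for conv layers
    "text_module_args": {"same": True},   # reuse module_args for the text encoders
})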
125 changes: 125 additions & 0 deletions speedtest.py
@@ -0,0 +1,125 @@
from omegaconf import OmegaConf
import sys
import math
from accelerate.utils import set_seed
from modules.utils import get_attr_from_config, collate_fn
from modules.config import Config
from tqdm import tqdm
import logging
import subprocess
import time
import json
import pandas as pd
from itertools import product
import torch
import gc

logger = logging.getLogger("テストちゃん")

def get_gpu_memory_usage():
    cmd = ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,noheader,nounits']
    result = subprocess.run(cmd, stdout=subprocess.PIPE)
    return int(result.stdout.decode('utf-8').strip())

def setattr_recursive(obj, key, value):
    if "." in key:
        key, rest = key.split(".", 1)
        setattr_recursive(getattr(obj, key), rest, value)
    else:
        setattr(obj, key, value)

def main(config):

    set_seed(config.main.seed)
    logger.info(f"Seed is {config.main.seed}!")

    logger.info(f"Loading the model from {config.main.model_path}!")
    trainer_cls = get_attr_from_config(config.trainer.module)
    trainer = trainer_cls.from_pretrained(config.main.model_path, config.main.sdxl, config.main.clip_skip, config.trainer)

    dataset_cls = get_attr_from_config(config.dataset.module)
    dataset = dataset_cls(trainer.text_model, **config.dataset.args)

    dataloader_cls = get_attr_from_config(config.dataloader.module)
    dataloader = dataloader_cls(dataset, collate_fn=collate_fn, **config.dataloader.args)

    trainer.prepare_modules_for_training()
    trainer.prepare_network(config.network)
    trainer.prepare_controlnet(config.controlnet)
    trainer.apply_module_settings()

    trainer.prepare_optimizer()

    steps_per_epoch = len(dataloader)
    total_steps = config.main.steps or steps_per_epoch * config.main.epochs
    total_epochs = config.main.epochs or math.ceil(total_steps / steps_per_epoch)  # ceil so the step budget is always reached
    logger.info(f"Total number of steps is {total_steps}!")

    trainer.prepare_lr_scheduler(total_steps)

    peak_memory = get_gpu_memory_usage()
    current_step = 0

    progress_bar = None
    for epoch in range(total_epochs):
        for batch in dataloader:
            if progress_bar is None:
                start_time = time.time()
                progress_bar = tqdm(total=total_steps, desc="Training")
            logs = trainer.step(batch)
            peak_memory = max(peak_memory, get_gpu_memory_usage())
            logs.update({"peak_memory": peak_memory})
            progress_bar.update(1)
            progress_bar.set_postfix(logs)
            current_step += 1

            if current_step == total_steps:
                logger.info("Training is done!")
                end_time = time.time()
                seconds = end_time - start_time
                samples_per_second = total_steps * dataset.batch_size / seconds
                print(f"Total time: {seconds:.2f} seconds!")
                print(f"Peak VRAM usage: {peak_memory}MB!")
                print(f"Samples per second: {samples_per_second}!")
                del trainer.diffusion.unet, trainer.vae, trainer.text_model
                del trainer
                gc.collect()
                torch.cuda.empty_cache()
                return seconds, total_steps, samples_per_second, peak_memory

        logger.info(f"Epoch {epoch+1} finished!")

if __name__ == "__main__":
    base_config = OmegaConf.load(sys.argv[1])
    base_config = OmegaConf.merge(OmegaConf.structured(Config), base_config)

    logging.basicConfig(level=logging.WARNING)
    print(OmegaConf.to_yaml(base_config))

    if len(sys.argv) == 3:
        with open(sys.argv[2], "r") as f:
            variations = json.load(f)

        keys = list(variations.keys())
        values = list(variations.values())
        columns = [key.split(".")[-1] for key in keys] + ["time", "steps", "samples/s", "vram"]
        df = pd.DataFrame(columns=columns)

        for settings in product(*values):
            print({keys[i]: setting for i, setting in enumerate(settings)})
            for i, setting in enumerate(settings):
                setattr_recursive(base_config, keys[i], setting)

            try:
                seconds, steps, samples_per_second, memory = main(base_config)
            except Exception as e:
                print(e)
                seconds, steps, samples_per_second, memory = 0, 0, 0, 0

            data = list(settings) + [seconds, steps, samples_per_second, memory]
            df.loc[len(df)] = data

        df.to_csv("speed_test.csv")

    else:
        main(base_config)
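
For the sweep mode, the second CLI argument is a JSON file mapping dotted config paths (consumed by setattr_recursive) to lists of values; every combination in the Cartesian product is benchmarked and a row is appended to speed_test.csv. A sketch with hypothetical paths:

import json

# Hypothetical dotted config paths; each value list is swept via itertools.product.
variations = {
    "dataset.args.batch_size": [1, 4],
    "dataset.args.cache_latent": [True, False],
    "network.module_args.forward_mode": ["sequential", "merge"],
}
with open("variations.json", "w") as f:
    json.dump(variations, f, indent=2)

# then: python speedtest.py config.yaml variations.json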
