first runnable version
with pretrained VAE
but the VAE still needs to be modified to cast latents to an image-like shape
CIntellifusion committed May 22, 2024
1 parent fbbe556 commit 4cadbd8
Showing 14 changed files with 162 additions and 70 deletions.
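The commit message points at a shape mismatch: the MLP-style VAE here produces flat latent vectors, while the UNet denoiser consumes image-shaped tensors. A minimal sketch of the required round-trip (the dimension names below are illustrative, not taken from the repository):

```python
import torch

batch, channels, h, w = 16, 1, 28, 28
latent_dim = channels * h * w  # only works if latent_dim factors into C*H*W

flat_latents = torch.randn(batch, latent_dim)           # what vae.encode returns
image_like = flat_latents.view(batch, channels, h, w)   # what the UNet consumes
restored = image_like.view(batch, latent_dim)           # back to what vae.decode expects

assert torch.equal(flat_latents, restored)  # the cast is lossless
```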
Binary file removed LatentDiffusion/VAE/samples/0/generated_images.png
Binary file not shown.
Binary file removed LatentDiffusion/VAE/samples/10/generated_images.png
Binary file not shown.
Binary file removed LatentDiffusion/VAE/samples/20/generated_images.png
Binary file not shown.
Binary file not shown.
108 changes: 72 additions & 36 deletions LatentDiffusion/main.py
@@ -23,7 +23,8 @@
 from data.data_wrapper import MNISTDataModule,CelebDataModule
 from models.unet import UNet
 from schedulers.ddpm import DDPM
-from vae.vae import VAE
+from vae import VAE
+from util import images2gif
 ## sorry to use global value
 imsize = 32
 """
@@ -60,16 +61,19 @@ def __init__(self,
         scheduler = "CosineAnnealingLR",
         sample_output_dir = "./samples",
         sample_epoch_interval = 20,
-        vae_config = {"x_dim":784,'hidden_dim':400,"latent_dim":200},
+        vae_config = {"x_dim":784,'hidden_dim':400,"latent_dim":784,"device":"cuda:0"},
+        vae_pretrained_path = None
         ):
         super(LatentDiffusion, self).__init__()
         self.save_hyperparameters() # Save hyperparameters for logging
         image_shape = [channels,imsize,imsize]
-        #TODO: latent shape
-        self.latent_shape = None
-        self.denoiser = UNet( n_steps=N, image_shape=self.latent_shape)
+        self.latent_shape = [vae_config["latent_dim"]]
+        self.denoiser = UNet( n_steps=N, image_shape=image_shape)
         self.ddpm = DDPM(min_beta=min_beta,max_beta=max_beta,N=N)
         self.vae = VAE(**vae_config)
+        self.vae_config = vae_config
+        self.config_vae(vae_pretrained_path)
         self.criterion = nn.MSELoss()
         self.N = N
         self.lr = lr
@@ -80,23 +84,24 @@ def __init__(self,

         self.sample_output_dir = sample_output_dir
         self.sample_epoch_interval = sample_epoch_interval
-    def forward(self, batch):
-        # print(batch)
-        images,_= batch
-        # print(images.shape)
-        bs = images.shape[0]
-        t = torch.randint(0,self.N,(bs,),device = images.device)
-        eps = torch.randn_like(images,device=images.device)
-        x_t = self.ddpm.sample_forward(images, t, eps)
-        # print(images.max(),images.min(),x_t.max(),x_t.min())
-        eps_theta = self.denoiser(x_t, t.reshape(bs, 1))
-        # print(t)
-        # print("training ",eps.max(),x_t.max(),eps_theta.max())
-        loss = self.criterion(eps,eps_theta)
-        return loss
-
+    def config_vae(self,pretrained_path):
+        ckpt = torch.load(pretrained_path)
+        if "state_dict" in ckpt.keys():
+            ckpt = ckpt["state_dict"]
+        new_state_dict = {}
+        for k,v in ckpt.items():
+            new_state_dict[k.replace("model.","")] = v
+        self.vae.load_state_dict(new_state_dict)
+
+        # freeze parameters
+        for param in self.vae.parameters():
+            param.requires_grad = False
+        # eval mode
+        self.vae.eval()

     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+        optimizer = torch.optim.Adam(self.denoiser.parameters(), lr=self.lr)
         if self.scheduler == "ReduceLROnPlateau":

             scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)
@@ -131,14 +136,38 @@ def lr_scheduler_step(self, epoch, batch_idx, optimizer,**kwargs):
         self.scheduler.step() # Update the scheduler

     def AE_encode(self,x):
-        return self.vae.encode(x)
+        # x = torch.concat(x)
+        # print("AE_encode",x.shape)# [128, 1, 28, 28]
+        bs = x.shape[0]
+        x = x.view(bs, *self.latent_shape)
+        return self.vae.encode(x) # encoder posterior ; tensor

     def AE_decode(self,x):
         return self.vae.decode(x)

     def get_input(self,batch):
-        return self.AE_encode(batch)
+        images,_= batch
+        return self.AE_encode(images)

+    def forward(self, batch):
+        # print(batch)
+        latents = batch
+        bs = latents.shape[0]
+        # print("forward ae encode output",latents.shape)
+        latents = latents.reshape(bs,*self.image_shape)
+        # print(latents.shape)
+
+        t = torch.randint(0,self.N,(bs,),device = latents.device)
+        eps = torch.randn_like(latents,device=latents.device)
+        x_t = self.ddpm.sample_forward(latents, t, eps)
+        # print(latents.max(),latents.min(),x_t.max(),x_t.min())
+        eps_theta = self.denoiser(x_t, t.reshape(bs, 1))
+        # print(t)
+        # print("training ",eps.max(),x_t.max(),eps_theta.max())
+        loss = self.criterion(eps,eps_theta)
+        return loss
     def validation_step(self, batch, batch_idx):
+        batch = self.get_input(batch)
         val_loss = self(batch)
         self.log('val_loss', val_loss, on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)
         return val_loss
@@ -155,17 +184,20 @@ def sample_images(self, output_dir, n_sample=9, device="cuda", simple_var=True):
         max_batch_size = 32
         self.to(device)
         self.denoiser.eval()
+        name = "generated_images.png"
         os.makedirs(output_dir, exist_ok=True)
         with torch.no_grad():
             for i in range(0, n_sample, max_batch_size):
                 # shape = (min(max_batch_size, n_sample - i),*self.image_shape)
-                shape = (min(max_batch_size, n_sample - i),*self.latent_shape)
-                latents = self.ddpm.sample_backward(shape, self.denoiser, device=device, simple_var=simple_var).detach().cpu()
-                imgs = self.AE_decoce(latents)
+                bs = min(max_batch_size, n_sample - i)
+                shape = (bs,*self.image_shape)
+                latents = self.ddpm.sample_backward(shape, self.denoiser, device=device, simple_var=simple_var)
+                imgs = self.AE_decode(latents.view((bs,*self.latent_shape))).detach().cpu()

                 # print("in sample images: ",imgs.max(),imgs.min())
                 # imgs = (imgs + 1) / 2 * 255
                 # imgs = imgs.clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)#.numpy()
-                output_file = os.path.join(output_dir, "generated_images.png")
+                output_file = os.path.join(output_dir,name )
                 channels,h,w = self.image_shape
                 save_image(imgs.view(n_sample,channels,h,w),output_file, nrow=3, normalize=True)

@@ -174,6 +206,14 @@ def on_train_epoch_end(self):
             output_dir = os.path.join(self.sample_output_dir, f'{self.current_epoch+1:05}')
             self.sample_images(output_dir=output_dir,n_sample=9,device="cuda",simple_var=True)

+    # after training, call images2gif
+    def on_fit_end(self):
+        folder = self.sample_output_dir
+        savepath = os.path.join(folder, "generated_video.gif")
+        subfolders = sorted(os.listdir(self.sample_output_dir))
+        name = "generated_images.png"
+        image_files = sorted([os.path.join(folder,sf,name) for sf in subfolders])
+        images2gif(image_files,savepath)

 ### parse args
 def parse_args():
@@ -196,8 +236,9 @@ def parse_args():
     parser.add_argument('--channels', type=int, default=3, help='channels of image ')
     parser.add_argument('--imsize', type=int, default=64, help='image size ')
     parser.add_argument('--scheduler', type=str, default="None", help='lr policy')
-    parser.add_argument('--dataset', type=str, default="celeba", help='dataset')
-    parser.add_argument('--sample_epoch_interval', type=int, default=10, help='sample interval')
+    parser.add_argument('--dataset', type=str, default="mnist", help='dataset')
+    parser.add_argument('--sample_epoch_interval', type=int, default=1, help='sample interval')
+    parser.add_argument('--vae_ckpt', type=str, default="/home/haoyu/research/simplemodels/LatentDiffusion/checkpoints/vae/model-epoch=81-val_loss=10050.03320.ckpt", help='pretrained vae checkpoint path')
     args = parser.parse_args()

     return args
@@ -231,7 +272,8 @@ def parse_args():
         lr = args.lr,
         scheduler=args.scheduler,
         sample_output_dir=f"./sample/{expname}",
-        sample_epoch_interval=args.sample_epoch_interval
+        sample_epoch_interval=args.sample_epoch_interval,
+        vae_pretrained_path = args.vae_ckpt
     )

     # callback for saving checkpoints
@@ -244,12 +286,7 @@ def parse_args():
         verbose=True
     )

-    # pretrain_path = "/data2/wuhaoyu/SimpleDiffusion/UnconditionalDiffusion/checkpoints/model-epoch=443-train_loss=0.00147.ckpt"
-    # pretrain_path = "/home/haoyu/research/simplemodels/SimpleDiffusion/UnconditionalDiffusion/checkpoints/model-epoch=159-val_loss=0.00454.ckpt"
-    pretrain_path = "/home/haoyu/research/simplemodels/SimpleDiffusion/UnconditionalDiffusion/checkpoints/randn/model-epoch=351-val_loss=0.01861.ckpt"
-    pretrain_path = "/home/haoyu/research/simplemodels/SimpleDiffusion/UnconditionalDiffusion/checkpoints/celeb_without_normal/model-epoch=37-val_loss=0.01920.ckpt"
-    pretrain_path = "/home/haoyu/research/simplemodels/SimpleDiffusion/UnconditionalDiffusion/checkpoints/celeb64/model-epoch=198-val_loss=0.01100.ckpt"
-    pretrain_path = "/home/haoyu/research/simplemodels/SimpleDiffusion/UnconditionalDiffusion/checkpoints/celeb64/model-epoch=557-val_loss=0.01012.ckpt"
+    pretrain_path = None
     trainer = pl.Trainer(
         accelerator="gpu",
         devices=args.devices, # train on one GPU
@@ -258,7 +295,6 @@ def parse_args():
         # progress_bar_refresh_rate=20, # progress-bar refresh rate
         callbacks=[checkpoint_callback], # register the checkpoint callback
     )
-
     trainer.fit(model,data_module,ckpt_path = pretrain_path)
 else:
     ckpt_folder = f"./checkpoints/{expname}"
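For orientation, the training path this commit wires together in main.py is: images → frozen VAE encoder (get_input) → flat latents → reshape to an image-like tensor → DDPM forward noising → UNet noise prediction (forward). Below is a condensed, self-contained sketch of that loop under assumed interfaces for `vae`, `ddpm`, and `denoiser`; it is an illustration, not code from the diff:

```python
import torch
import torch.nn.functional as F

def training_loss(vae, ddpm, denoiser, images, image_shape, n_steps):
    # The VAE is frozen in this commit, so no gradients flow through encode.
    with torch.no_grad():
        latents = vae.encode(images.flatten(1))   # [B, latent_dim]
    bs = latents.shape[0]
    latents = latents.reshape(bs, *image_shape)   # cast to an image-like shape
    t = torch.randint(0, n_steps, (bs,), device=latents.device)
    eps = torch.randn_like(latents)
    x_t = ddpm.sample_forward(latents, t, eps)    # forward-noise to step t
    eps_theta = denoiser(x_t, t.reshape(bs, 1))   # predict the injected noise
    return F.mse_loss(eps_theta, eps)
```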
Binary file not shown.
1 change: 1 addition & 0 deletions LatentDiffusion/models/unet.py
@@ -3,6 +3,7 @@

 import torch
 from torch import nn
+from torch.nn import functional as F

 class PositionalEncoding(nn.Module):
     def __init__(self, max_seq_len: int, d_model: int):
Binary file not shown.
7 changes: 5 additions & 2 deletions LatentDiffusion/schedulers/ddpm.py
@@ -1,3 +1,5 @@
+import torch
+from torch import nn

 ### DDPM scheduler
 class DDPM(nn.Module):
@@ -25,10 +27,11 @@ def sample_forward(self,x,t,eps=None):
     @torch.no_grad()
     def sample_backward(self, image_or_shape,net,device="cuda",simple_var=True):
         if isinstance(image_or_shape,torch.Tensor):
-            x = image_or_shape
+            x = image_or_shape.to(device)
         else:
             x = torch.randn(image_or_shape,device=device)
+        # debug
+        # print("sample_backward",x.device)
         # print(x.max(),x.min(),x.mean())
         # for t in range(self.N-1,-1,-1):
         #     self.sample_backward_step(net, x, t, simple_var)
@@ -48,7 +51,7 @@ def sample_backward_step(self,net,x_t, t,simple_var,use_noise=True,clip_denoised
         else:
             var = (1-self.alpha_bars_prev[t])/(1-self.alpha_bars[t]) * self.betas[t]
         # this really was a bug here: randn_like and rand_like are not the same
-        noise = torch.randn_like(x_t) * torch.sqrt(var)
+        noise = torch.randn_like(x_t,device=x_t.device) * torch.sqrt(var)
         eps = net(x_t,t_tensor)
         # with open("./cache.txt",'a') as f:
         #     f.write(f"{eps.mean().item()},{eps.max().item()},{eps.min().item()}\n")
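The `var` branch above implements the DDPM posterior variance, beta~_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t, while simple_var falls back to sigma_t^2 = beta_t. A small self-contained check of the two choices; the linear beta schedule here is an assumption for illustration, not read from the repo:

```python
import torch

N = 1000
betas = torch.linspace(1e-4, 0.02, N)              # assumed linear schedule
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)
alpha_bars_prev = torch.cat([torch.ones(1), alpha_bars[:-1]])

t = 500
simple_var = betas[t]                                                       # sigma_t^2 = beta_t
posterior_var = (1 - alpha_bars_prev[t]) / (1 - alpha_bars[t]) * betas[t]   # beta~_t

# Since alpha_bar_{t-1} >= alpha_bar_t, the posterior variance never exceeds beta_t.
assert posterior_var <= simple_var
```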
56 changes: 56 additions & 0 deletions LatentDiffusion/util.py
@@ -0,0 +1,56 @@
+## given one folder, concat the images in it and save them in the same folder
+import os,cv2
+import numpy as np
+from PIL import Image
+def concat_images(folder,name="generated_images.jpg",policy='square'):
+    files = os.listdir(folder)
+    files.sort()
+    images=[]
+    for f in files:
+        if f[-3:] in ['jpg','png'] and name not in f:
+            images.append(cv2.imread(os.path.join(folder,f)))
+    num = len(images)
+    height, width, layers = images[0].shape
+    if policy == "square":
+        # select 9,16,25 images
+        images = images[:9]
+        # images = np.concatenate(images,axis=1).reshape(height*3, width*num, layers)
+        big_image = np.zeros((height * 3, width * 3, 3), dtype=np.uint8)
+
+        # place each small image at its position in the big image
+        for i in range(3):
+            for j in range(3):
+                idx = i * 3 + j
+                big_image[i*height:(i+1)*height, j*width:(j+1)*width, :] = images[idx]
+
+        images = big_image
+        # print(images.shape,type(images))
+    else:
+        images = np.concatenate(images,axis=1).reshape(height, width*num, layers)
+    cv2.imwrite(os.path.join(folder,name),images)
+
+## images to gif
+import imageio
+def images2gif(image_files:list,save_path:str):
+    gif_frames = []
+    for file_name in image_files:
+        # print(file_name)
+        gif_frames.append(imageio.imread(file_name))
+    imageio.mimsave(save_path, gif_frames, duration=0.5)
+
+if __name__=="__main__":
+
+    folder = "./LatentDiffusion/samples/"
+    # folder = "./sample/randn/"
+    # name = "0000_0008.png"
+    name = "generated_images.png"
+    subfolders = sorted(os.listdir(folder))
+    # for sf in subfolders:
+    #     os.system(f"mv {os.path.join(folder,sf)} {os.path.join(folder,f'{int(sf):05d}')}")
+    # subfolders = sorted(os.listdir(folder))
+
+    # for sf in subfolders:
+    #     concat_images(os.path.join(folder,sf),name,'square')
+
+    image_files = sorted([os.path.join(folder,sf,name) for sf in subfolders])
+    images2gif(image_files,os.path.join(folder,"generated_images.gif"))
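One caveat with the new helper: on_fit_end builds its file list from raw os.listdir output, so images2gif will raise if any subfolder lacks generated_images.png. A slightly defensive variant that skips missing frames, offered as an illustrative sketch rather than part of the commit:

```python
import os
import imageio

def images2gif_safe(image_files: list, save_path: str, duration: float = 0.5):
    # Keep only paths that actually exist, so stray subfolders don't crash the run.
    frames = [imageio.imread(f) for f in image_files if os.path.isfile(f)]
    if frames:  # only write a gif when at least one frame was found
        imageio.mimsave(save_path, frames, duration=duration)
```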