From 38a55d6a8b1ccbfdd8a924f0fee04b09d959250a Mon Sep 17 00:00:00 2001 From: CIntellifusion <96377955+CIntellifusion@users.noreply.github.com> Date: Fri, 24 May 2024 22:49:27 +0800 Subject: [PATCH] formal vae encoder decoder --- LatentDiffusion/data/data_wrapper.py | 26 +- LatentDiffusion/models/ae_module.py | 491 +++++++++++++++++++++++++++ LatentDiffusion/vae.py | 145 ++++---- 3 files changed, 585 insertions(+), 77 deletions(-) create mode 100644 LatentDiffusion/models/ae_module.py diff --git a/LatentDiffusion/data/data_wrapper.py b/LatentDiffusion/data/data_wrapper.py index c8dc663..95e9a6b 100644 --- a/LatentDiffusion/data/data_wrapper.py +++ b/LatentDiffusion/data/data_wrapper.py @@ -4,6 +4,7 @@ from torchvision.datasets import MNIST import torch from datasets import load_dataset +imsize = 64 ### data class MNISTDataModule(pl.LightningDataModule): def __init__(self, data_dir="./", batch_size=64,num_workers=63): @@ -62,12 +63,26 @@ def split_dataset(self, dataset, split_ratio=0.2): return train_dataset, val_dataset def prepare_data(self): - self.dataset = load_dataset('nielsr/CelebA-faces') - + self.dataset = load_dataset('nielsr/CelebA-faces')#.map(self.apply_transform) + # self.dataset = self.dataset.with_transform(self.apply_transform) def setup(self, stage=None, transform=None): if stage == 'fit' or stage is None: self.train_dataset, self.val_dataset = self.split_dataset(self.dataset['train'], split_ratio=0.2) - + @staticmethod + def collate_fn(batch): + # for example in batch: + # image = example['image'] + # image.save("/home/haoyu/research/simplemodels/cache/test.jpg") + transform = transforms.Compose([ + transforms.Resize((imsize,imsize)), + transforms.ToTensor(), + # transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) # Normalize images + # transforms.Lambda(lambda x: (x - 0.5) * 2) # unconment + ]) + transformed_batch = torch.stack([transform(example['image']) for example in batch]) + # print("transformerd",transformed_batch.mean(),transformed_batch.min(),transformed_batch.max()) + + return transformed_batch,None def apply_transform(self, example): transform = transforms.Compose([ @@ -76,7 +91,7 @@ def apply_transform(self, example): # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize images # transforms.Lambda(lambda x: (x - 0.5) * 2) # Uncomment to normalize ]) - return transform(example['image']) + return {"image":transform(example['image'])} def train_dataloader(self): return DataLoader(self.train_dataset, @@ -88,5 +103,6 @@ def train_dataloader(self): def val_dataloader(self): return DataLoader(self.val_dataset, batch_size=self.batch_size, - collate_fn=self.collate_fn, num_workers=self.num_workers, + collate_fn=self.collate_fn, + num_workers=self.num_workers, pin_memory=True) \ No newline at end of file diff --git a/LatentDiffusion/models/ae_module.py b/LatentDiffusion/models/ae_module.py new file mode 100644 index 0000000..1ad2a7b --- /dev/null +++ b/LatentDiffusion/models/ae_module.py @@ -0,0 +1,491 @@ +import torch +from torch import nn +import math +import numpy as np +""" + A simple implementation of Gaussian MLP Encoder and Decoder +""" + +class SimpleEncoder(nn.Module): + + def __init__(self, input_dim, hidden_dim, latent_dim): + super(SimpleEncoder, self).__init__() + + self.FC_input = nn.Linear(input_dim, hidden_dim) + self.FC_input2 = nn.Linear(hidden_dim, hidden_dim) + self.FC_mean = nn.Linear(hidden_dim, latent_dim) + self.FC_var = nn.Linear(hidden_dim, latent_dim) + + self.LeakyReLU = nn.LeakyReLU(0.2) + + + def forward(self, x): + h_ = 
self.LeakyReLU(self.FC_input(x))
+        h_ = self.LeakyReLU(self.FC_input2(h_))
+        mean = self.FC_mean(h_)
+        log_var = self.FC_var(h_)  # encoder produces the mean and log-variance
+                                   # (i.e., the parameters of the simple tractable normal distribution "q")
+
+        return mean, log_var
+
+
+class SimpleDecoder(nn.Module):
+    def __init__(self, latent_dim, hidden_dim, output_dim):
+        super(SimpleDecoder, self).__init__()
+        self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
+        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
+        self.FC_output = nn.Linear(hidden_dim, output_dim)
+
+        self.LeakyReLU = nn.LeakyReLU(0.2)
+
+    def forward(self, x):
+        h = self.LeakyReLU(self.FC_hidden(x))
+        h = self.LeakyReLU(self.FC_hidden2(h))
+
+        x_hat = torch.sigmoid(self.FC_output(h))
+        return x_hat
+
+
+"""
+    A more complex implementation of a ResNet encoder with attention
+    reference: lvdm/modules/ae_modules.py
+"""
+def Normalize(in_channels, num_groups=32):
+    return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-5)
+
+def nonlinearity(x):
+    # swish
+    return x*torch.sigmoid(x)
+
+class AttnBlock(nn.Module):
+    # self-attention implemented with conv layers
+    def __init__(self,in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+        self.norm = Normalize(in_channels)
+        # 1x1 convolutions (kernel_size=1, stride=1, padding=0) act as
+        # per-position linear layers over the channel dimension
+        self.q = nn.Conv2d(in_channels,
+                           in_channels,
+                           kernel_size=1,
+                           stride=1,
+                           padding=0)
+        self.k = nn.Conv2d(in_channels,
+                           in_channels,
+                           kernel_size=1,
+                           stride=1,
+                           padding=0)
+        self.v = nn.Conv2d(in_channels,
+                           in_channels,
+                           kernel_size=1,
+                           stride=1,
+                           padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)  # output projection
+
+    def forward(self,x):
+        h_ = x
+        h_ = self.norm(h_)  # equivalent to h_ = self.norm(x); x is kept for the residual below
+
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute scaled dot-product attention
+        b,c,h,w = q.shape
+        q = q.reshape(b,c,-1).permute(0,2,1)  # b,h*w,c
+        k = k.reshape(b,c,-1)                 # b,c,h*w
+        w_ = torch.bmm(q,k)                   # b,h*w,c @ b,c,h*w -> b,h*w,h*w
+        w_ = w_ * (int(c)**(-0.5))            # scale by 1/sqrt(c)
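+        # without the 1/sqrt(c) factor the q.k logits grow with the channel count and the softmax below saturates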
+ w_ = torch.nn.functional.softmax(w_, dim=2)#b,h*w,h*w + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + + v = v.reshape(b,c,-1) + h_ = torch.bmm(v,w_) + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + return h_ + x + + +def make_attn(in_channels,attn_type="vanilla"): + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type=="none": + return nn.Identity(in_channels) + else: + raise NotImplementedError(f"Attention type {attn_type} is not implemented") + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + self.in_channels = in_channels + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + self.in_channels = in_channels + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
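+    Each timestep t is mapped to [sin(t * f_i), cos(t * f_i)] for i in [0, half_dim),
+    with f_i = 10000 ** (-i / (half_dim - 1)) and half_dim = embedding_dim // 2.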
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + +class Encoder(nn.Module): + # downsample blocks : resblock+attention + # mid : resblock+attention + def __init__(self, *, ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # down sampling + self.conv_in = nn.Conv2d(in_channels, + self.ch, + kernel_size = 3, + stride = 1, + padding=1) + + cur_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + attn = nn.ModuleList() + block = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if cur_res in attn_resolutions: + attn.append(make_attn(block_in,attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in,resamp_with_conv) + cur_res = cur_res // 2 + self.down.append(down) + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = 
make_attn(block_in,attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.norm_out = Normalize(block_in) + self.conv_mean = nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + self.conv_var = nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + def forward(self,x): + temb = None + + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + h = hs[-1] + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h,temb) + if len(self.down[i_level].attn)>0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions-1: + h = self.down[i_level].downsample(h) + hs.append(h) + # mid + h = hs[-1] + h = self.mid.block_1(h,temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h,temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + mean = self.conv_mean(h) + log_var = self.conv_var(h) + return mean,log_var + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + + + # compute + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("AE working on z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # print(f'decoder-input={z.shape}') + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + # print(f'decoder-conv in feat={h.shape}') 
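+        # h: (B, ch*ch_mult[-1], curr_res, curr_res) with curr_res = resolution // 2**(num_resolutions-1);
+        # the mid blocks keep this shape and each upsampling level doubles the spatial size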
+ + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + # print(f'decoder-mid feat={h.shape}') + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + # print(f'decoder-up feat={h.shape}') + if i_level != 0: + h = self.up[i_level].upsample(h) + # print(f'decoder-upsample feat={h.shape}') + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + # print(f'decoder-conv_out feat={h.shape}') + if self.tanh_out: + h = torch.tanh(h) + else: + h = torch.sigmoid(h) + return h + def sample(self,n_sample): + z = torch.randn(n_sample,*self.z_shape[1:]) + return self.forward(z) + + +if __name__ == '__main__': + resolution = 64 + in_channels = 3 + encoder = Encoder(ch=256, + resolution=resolution, + in_channels=in_channels, + ch_mult=(1,2,4,8), + num_res_blocks=2, + attn_resolutions=(16,), + dropout=0.0, + resamp_with_conv=True, + z_channels=128, + double_z=True, + use_linear_attn=False, + use_checkpoint=False).to("cuda") + decoder = Decoder(ch=256, + out_ch=3, + resolution=resolution, + in_channels=in_channels, + ch_mult=(1,2,4,8), + num_res_blocks=2, + attn_resolutions=(16,), + dropout=0.0, + resamp_with_conv=True, + z_channels=256, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + use_checkpoint=False).to("cuda") + x = torch.randn(1,3,64,64).to("cuda") + z = torch.randn(1,256,8,8).to("cuda") + print(encoder(x)[0].shape) + print(encoder(x)[1].shape) + print(decoder(z).shape) \ No newline at end of file diff --git a/LatentDiffusion/vae.py b/LatentDiffusion/vae.py index 85d41b4..70a5b39 100644 --- a/LatentDiffusion/vae.py +++ b/LatentDiffusion/vae.py @@ -14,59 +14,43 @@ import pytorch_lightning as pl import torch.optim as optim from data.data_wrapper import MNISTDataModule,CelebDataModule +from models.ae_module import SimpleEncoder,SimpleDecoder +from models.ae_module import Encoder,Decoder - - -""" - A simple implementation of Gaussian MLP Encoder and Decoder -""" - -class SimpleEncoder(nn.Module): - - def __init__(self, input_dim, hidden_dim, latent_dim): - super(SimpleEncoder, self).__init__() - - self.FC_input = nn.Linear(input_dim, hidden_dim) - self.FC_input2 = nn.Linear(hidden_dim, hidden_dim) - self.FC_mean = nn.Linear(hidden_dim, latent_dim) - self.FC_var = nn.Linear (hidden_dim, latent_dim) - - self.LeakyReLU = nn.LeakyReLU(0.2) - - - def forward(self, x): - h_ = self.LeakyReLU(self.FC_input(x)) - h_ = self.LeakyReLU(self.FC_input2(h_)) - mean = self.FC_mean(h_) - log_var = self.FC_var(h_) # encoder produces mean and log of variance - # (i.e., parateters of simple tractable normal distribution "q" - - return mean, log_var - - -class SimpleDecoder(nn.Module): - def __init__(self, latent_dim, hidden_dim, output_dim): - super(SimpleDecoder, self).__init__() - self.FC_hidden = nn.Linear(latent_dim, hidden_dim) - self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim) - self.FC_output = nn.Linear(hidden_dim, output_dim) - - self.LeakyReLU = nn.LeakyReLU(0.2) - - def forward(self, x): - h = self.LeakyReLU(self.FC_hidden(x)) - h = self.LeakyReLU(self.FC_hidden2(h)) - - x_hat = torch.sigmoid(self.FC_output(h)) - return x_hat - class VAE(nn.Module): - def __init__(self, x_dim, hidden_dim, latent_dim,device): + def __init__(self, resolution,in_channels,device): super(VAE, self).__init__() - 
self.encoder = SimpleEncoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim) - self.decoder = SimpleDecoder(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=x_dim) + print("resolution:",resolution,"in_channels:",in_channels) + self.encoder = Encoder( + ch=128, + resolution=resolution, + in_channels=in_channels, + ch_mult=(1,2,4,8), + num_res_blocks=2, + attn_resolutions=(16,), + dropout=0.0, + resamp_with_conv=True, + z_channels=128, + double_z=False, + use_linear_attn=False, + use_checkpoint=False) + self.decoder = Decoder(ch=128, + out_ch=in_channels, + resolution=resolution, + in_channels=in_channels, + ch_mult=(1,2,4,8), + num_res_blocks=2, + attn_resolutions=(16,), + dropout=0.0, + resamp_with_conv=True, + z_channels=128, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + use_checkpoint=False) + self.device = device def reparameterization(self, mean, var): @@ -76,11 +60,15 @@ def reparameterization(self, mean, var): def encode(self,x): mean, log_var = self.encoder(x) + # print("reparameter: ",mean.shape,log_var.shape,x.shape); z = self.reparameterization(mean, torch.exp(0.5 * log_var)) + # print("z latent:", z.shape);exit() return z def decode(self,z): return self.decoder(z) - + + def sample(self,n_sample): + self.decoder.sample(n_sample) def forward(self, x): mean, log_var = self.encoder(x) z = self.reparameterization(mean, torch.exp(0.5 * log_var)) # takes exponential function (log var -> var) @@ -89,9 +77,17 @@ def forward(self, x): return x_hat, mean, log_var def loss_fn(self,x, x_hat, mean, log_var): + # print("loss fn",x.shape,x_hat.shape,log_var.shape,mean.shape) + x = x.view(x.shape[0],-1) + x_hat = x_hat.view(x_hat.shape[0],-1) + log_var = log_var.view(log_var.shape[0],-1) + mean = mean.view(mean.shape[0],-1) + # print("loss fn",x.shape,x_hat.shape,log_var.shape,mean.shape);#exit() + # print("x max min",x.max(),x.min(),x_hat.max(),x_hat.min());exit() reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum') KLD = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp()) return reproduction_loss + KLD + # return KLD class VAETrainer(pl.LightningModule): @@ -104,10 +100,10 @@ def __init__(self, scheduler = "CosineAnnealingLR", sample_output_dir = "./samples", sample_epoch_interval = 20, + device = "cuda" ): super(VAETrainer, self).__init__() - self.model = VAE(x_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim) - self.lr = lr + self.model = VAE(resolution=imsize,in_channels=channels,device=device) self.save_hyperparameters() # Save hyperparameters for logging image_shape = [channels,imsize,imsize] self.lr = lr @@ -121,7 +117,7 @@ def __init__(self, def forward(self, batch): x, _ = batch - x = x.view(self.batch_size, x_dim) + # x = x.view(self.batch_size, x_dim) x_hat, mean, log_var = self.model(x) loss = self.model.loss_fn(x, x_hat, mean, log_var) return loss @@ -145,8 +141,8 @@ def training_step(self, batch, batch_idx): def sample_images(self, output_dir, n_sample=10, device="cuda", simple_var=True): output_file = os.path.join(output_dir , "generated_images.png") with torch.no_grad(): - noise = torch.randn(n_sample, latent_dim)#.to(DEVICE) - generated_images = self.model.decoder(noise) + # noise = torch.randn(n_sample, latent_dim)#.to(DEVICE) + generated_images = self.model.sample(n_sample) save_image(generated_images.view(n_sample,*self.image_shape),output_file, nrow=5, normalize=True) def on_train_epoch_end(self): @@ -161,17 +157,18 @@ def on_train_epoch_end(self): if __name__=="__main__": dataset_path = 
'../data' - data_module = MNISTDataModule(data_dir=dataset_path, batch_size=100, num_workers=63) + imsize = 64 + batch_size = 64 + # data_module = MNISTDataModule(data_dir=dataset_path, batch_size=100, num_workers=63) + data_module = CelebDataModule(batch_size=batch_size, + num_workers=63,imsize=imsize) cuda = True DEVICE = torch.device("cuda" if cuda else "cpu") - x_dim = 784 - hidden_dim = 400 - latent_dim = 784 lr = 1e-3 - epochs = 100 - sample_epoch_interval = 10 - sample_output_dir = "./samples" - expname = "vae" + epochs = 40 + sample_epoch_interval = 1 + sample_output_dir = "./sample_vae" + expname = "vae_aemodule_celeb64" from pytorch_lightning.callbacks import ModelCheckpoint checkpoint_callback = ModelCheckpoint( dirpath=f"./checkpoints/{expname}", # 保存 checkpoint 的目录 @@ -182,17 +179,21 @@ def on_train_epoch_end(self): verbose=True ) pretrain_path = "/home/haoyu/research/simplemodels/LatentDiffusion/checkpoints/vae/model-epoch=38-val_loss=10231.90039.ckpt" - model = VAETrainer(batch_size=100, - lr=lr,imsize=28, - num_workers=63, - sample_output_dir=sample_output_dir, - sample_epoch_interval=sample_epoch_interval) + pretrain_path = None + model = VAETrainer(batch_size=batch_size, + channels=3, + lr=lr,imsize=imsize, + num_workers=63, + sample_output_dir=sample_output_dir, + sample_epoch_interval=sample_epoch_interval) - trainer = pl.Trainer(gpus=1 if cuda else 0, - max_epochs=epochs, - logger=pl.loggers.TensorBoardLogger("logs/", name=expname), - callbacks=[checkpoint_callback], - ) + trainer = pl.Trainer( + accelerator = "gpu", + devices=1, + max_epochs=epochs, + logger=pl.loggers.TensorBoardLogger("logs/", name=expname), + callbacks=[checkpoint_callback], + ) trainer.fit(model=model,datamodule=data_module,ckpt_path=pretrain_path)