Skip to content

Commit

Permalink
Latent Diffusion Framework
Browse files Browse the repository at this point in the history
This is based on unconditional diffusion. This version is not yet finished or tested.
  • Loading branch information
CIntellifusion committed May 11, 2024
1 parent ce9a888 commit d418b6a
Show file tree
Hide file tree
Showing 13 changed files with 1,001 additions and 25 deletions.
112 changes: 88 additions & 24 deletions LatentDiffusion/VAE/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def setup(self, stage=None,transform=None):
if transform is None :
transform = transforms.Compose([
transforms.ToTensor(),
# transforms.Normalize((0.1307,), (0.3081,))
# transforms.Normalize((0.1307,), (0.3081,)) #
])

if stage == 'fit' or stage is None:
Expand All @@ -53,6 +53,72 @@ def train_dataloader(self):

def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size,num_workers=self.num_workers,pin_memory=True)

"""
celeba dataset
"""
from datasets import load_dataset
class CelebDataModule(pl.LightningDataModule):
    """LightningDataModule that serves the 'nielsr/CelebA-faces' dataset.

    Args:
        batch_size: Number of images per batch.
        num_workers: Number of worker processes handed to each DataLoader.
        imsize: Side length, in pixels, that every image is resized to.
    """

    def __init__(self, batch_size=64, num_workers=63, imsize=32):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Use the constructor argument; there is no `args` object in scope here.
        self.imsize = imsize

    def split_dataset(self, dataset, split_ratio=0.2):
        """Divide the dataset into training and validation sets.

        Args:
            dataset (datasets.Dataset): The dataset to be divided.
            split_ratio (float): The proportion of the validation set,
                default is 0.2.

        Returns:
            tuple: ``(train_dataset, val_dataset)``.
        """
        num_val_samples = int(len(dataset) * split_ratio)

        # Shuffle once and cut both subsets from the same permutation so the
        # two splits are guaranteed to be disjoint.
        shuffled = dataset.shuffle(seed=42)
        val_dataset = shuffled.select(range(num_val_samples))
        train_dataset = shuffled.select(range(num_val_samples, len(dataset)))

        return train_dataset, val_dataset

    def prepare_data(self):
        """Download (and cache) the dataset.

        No state is assigned here: Lightning only runs this hook on a single
        process, so attributes set here would be missing on the other ranks.
        """
        load_dataset('nielsr/CelebA-faces')

    def setup(self, stage=None, transform=None):
        """Load the dataset and build the train/validation split.

        Args:
            stage: Lightning stage name; the split is built for ``'fit'`` or
                ``None``.
            transform: Unused; transforms are applied in ``collate_fn``.
        """
        if stage == 'fit' or stage is None:
            self.dataset = load_dataset('nielsr/CelebA-faces')
            self.train_dataset, self.val_dataset = self.split_dataset(
                self.dataset['train'], split_ratio=0.2)

    @staticmethod
    def collate_fn(batch, imsize=32):
        """Resize and stack a list of examples into one image tensor.

        Args:
            batch: Sequence of examples, each exposing its image under the
                ``'image'`` key (assumed to be a PIL image -- TODO confirm).
            imsize: Target side length passed to ``transforms.Resize``.

        Returns:
            tuple: ``(images, None)``; the second slot is a label placeholder.
        """
        # No normalization is applied after ToTensor.
        transform = transforms.Compose([
            transforms.Resize((imsize, imsize)),
            transforms.ToTensor(),
        ])
        transformed_batch = torch.stack(
            [transform(example['image']) for example in batch])
        return transformed_batch, None

    def _collate(self, batch):
        """Collate ``batch`` using this module's configured image size."""
        return self.collate_fn(batch, imsize=self.imsize)

    def train_dataloader(self):
        """Return the shuffled DataLoader over the training split."""
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self._collate,
                          shuffle=True, num_workers=self.num_workers,
                          pin_memory=True)

    def val_dataloader(self):
        """Return the DataLoader over the validation split."""
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self._collate, num_workers=self.num_workers,
                          pin_memory=True)


"""
A simple implementation of Gaussian MLP Encoder and Decoder
"""
Expand Down Expand Up @@ -123,25 +189,6 @@ def loss_fn(self,x, x_hat, mean, log_var):
KLD = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())
return reproduction_loss + KLD

# class VAELoss(nn.Module):
# def __init__(self):
# super(VAELoss, self).__init__()
# self.BCE_loss = nn.BCELoss()

# def forward(self, x, x_hat, mean, log_var):
# reproduction_loss = self.BCE_loss(x_hat, x)
# KLD = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
# return reproduction_loss + KLD

# class VAEModel(nn.Module):
# def __init__(self, x_dim, hidden_dim, latent_dim):
# super(VAEModel, self).__init__()
# self.model = VAE(x_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
# self.loss_fn = VAELoss()

# def forward(self, x):
# return self.model(x)


class VAETrainer(pl.LightningModule):
def __init__(self,
Expand Down Expand Up @@ -199,7 +246,7 @@ def sample_images(self, output_dir, n_sample=10, device="cuda", simple_var=True)
save_image(generated_images.view(n_sample,*self.image_shape),output_file, nrow=5, normalize=True)

def on_train_epoch_end(self):
    """Prepare the per-epoch sample directory every ``sample_epoch_interval`` epochs.

    Only the part of this hook visible in the excerpt is covered here; any
    statements that follow in the full file are not shown.
    """
    # `%` binds tighter than `+`, so the 1-based epoch number must be
    # parenthesised; unparenthesised, the test is never true for intervals > 1.
    if (self.current_epoch + 1) % self.sample_epoch_interval == 0:
        output_dir = os.path.join(self.sample_output_dir, f'{self.current_epoch}')
        # exist_ok makes re-runs safe without a separate existence check.
        os.makedirs(output_dir, exist_ok=True)
Expand All @@ -221,10 +268,27 @@ def on_train_epoch_end(self):
epochs = 30
sample_epoch_interval = 10
sample_output_dir = "./samples"
expname = "vae"
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the three checkpoints with the lowest validation loss.
# NOTE(review): monitoring "val_loss" requires the LightningModule to log a
# metric under that exact name -- confirm VAETrainer does so.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"./checkpoints/{expname}",  # directory the checkpoints are written to
    filename="model-{epoch:02d}-{val_loss:.5f}",  # checkpoint file name pattern
    monitor="val_loss",  # metric to monitor: the validation loss
    mode="min",  # lower values of the monitored metric are better
    save_top_k=3,  # keep the 3 best checkpoints
    verbose=True,
)

model = VAETrainer(
    batch_size=100,
    lr=lr,
    imsize=28,
    num_workers=63,
    sample_output_dir=sample_output_dir,
    sample_epoch_interval=sample_epoch_interval,
)

# The checkpoint callback only takes effect when it is handed to the Trainer.
trainer = pl.Trainer(
    gpus=1 if cuda else 0,
    max_epochs=epochs,
    callbacks=[checkpoint_callback],
)
trainer.fit(model=model, datamodule=data_module)




93 changes: 93 additions & 0 deletions LatentDiffusion/data/data_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@

### data
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the torchvision MNIST dataset.

    Args:
        data_dir: Root directory passed to ``MNIST``.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes handed to each DataLoader.
    """

    def __init__(self, data_dir="./", batch_size=64, num_workers=63):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        """Download both MNIST splits; no state is stored on the module."""
        for is_train in (True, False):
            MNIST(self.data_dir, train=is_train, download=True)

    def setup(self, stage=None, transform=None):
        """Create the train/validation datasets for the ``'fit'`` stage.

        Args:
            stage: Lightning stage name; datasets are built for ``'fit'`` or
                ``None``.
            transform: Optional transform applied to every sample; defaults
                to a bare ``ToTensor``.
        """
        if transform is None:
            transform = transforms.Compose([transforms.ToTensor()])

        if stage not in ('fit', None):
            return

        # The MNIST test split (train=False) is used as the validation set.
        self.train_dataset = MNIST(self.data_dir, train=True, transform=transform)
        self.val_dataset = MNIST(self.data_dir, train=False, transform=transform)

    def train_dataloader(self):
        """Return the shuffled DataLoader over the training set."""
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )
        return loader

    def val_dataloader(self):
        """Return the DataLoader over the validation set."""
        loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )
        return loader

class CelebDataModule(pl.LightningDataModule):
    """LightningDataModule that serves the 'nielsr/CelebA-faces' dataset.

    Args:
        batch_size: Number of images per batch.
        num_workers: Number of worker processes handed to each DataLoader.
        imsize: Side length, in pixels, that every image is resized to.
    """

    def __init__(self, batch_size=64, num_workers=63, imsize=32):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Use the constructor argument; there is no `args` object in scope here.
        self.imsize = imsize

    def split_dataset(self, dataset, split_ratio=0.2):
        """Divide the dataset into training and validation sets.

        Args:
            dataset (datasets.Dataset): The dataset to be divided.
            split_ratio (float): The proportion of the validation set,
                default is 0.2.

        Returns:
            tuple: ``(train_dataset, val_dataset)``.
        """
        num_val_samples = int(len(dataset) * split_ratio)

        # Shuffle once and cut both subsets from the same permutation so the
        # two splits are guaranteed to be disjoint.
        shuffled = dataset.shuffle(seed=42)
        val_dataset = shuffled.select(range(num_val_samples))
        train_dataset = shuffled.select(range(num_val_samples, len(dataset)))

        return train_dataset, val_dataset

    def prepare_data(self):
        """Download (and cache) the dataset.

        No state is assigned here: Lightning only runs this hook on a single
        process, so attributes set here would be missing on the other ranks.
        """
        load_dataset('nielsr/CelebA-faces')

    def setup(self, stage=None, transform=None):
        """Load the dataset and build the train/validation split.

        Args:
            stage: Lightning stage name; the split is built for ``'fit'`` or
                ``None``.
            transform: Unused; transforms are applied in ``collate_fn``.
        """
        if stage == 'fit' or stage is None:
            self.dataset = load_dataset('nielsr/CelebA-faces')
            self.train_dataset, self.val_dataset = self.split_dataset(
                self.dataset['train'], split_ratio=0.2)

    @staticmethod
    def collate_fn(batch, imsize=32):
        """Resize and stack a list of examples into one image tensor.

        Args:
            batch: Sequence of examples, each exposing its image under the
                ``'image'`` key (assumed to be a PIL image -- TODO confirm).
            imsize: Target side length passed to ``transforms.Resize``.

        Returns:
            tuple: ``(images, None)``; the second slot is a label placeholder.
        """
        # No normalization is applied after ToTensor.
        transform = transforms.Compose([
            transforms.Resize((imsize, imsize)),
            transforms.ToTensor(),
        ])
        transformed_batch = torch.stack(
            [transform(example['image']) for example in batch])
        return transformed_batch, None

    def _collate(self, batch):
        """Collate ``batch`` using this module's configured image size."""
        return self.collate_fn(batch, imsize=self.imsize)

    def train_dataloader(self):
        """Return the shuffled DataLoader over the training split."""
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self._collate,
                          shuffle=True, num_workers=self.num_workers,
                          pin_memory=True)

    def val_dataloader(self):
        """Return the DataLoader over the validation split."""
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self._collate, num_workers=self.num_workers,
                          pin_memory=True)
Loading

0 comments on commit d418b6a

Please sign in to comment.