Skip to content

Commit

Permalink
formal vae encoder decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
CIntellifusion committed May 24, 2024
1 parent a342a6a commit 38a55d6
Show file tree
Hide file tree
Showing 3 changed files with 585 additions and 77 deletions.
26 changes: 21 additions & 5 deletions LatentDiffusion/data/data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torchvision.datasets import MNIST
import torch
from datasets import load_dataset
imsize = 64
### data
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir="./", batch_size=64,num_workers=63):
Expand Down Expand Up @@ -62,12 +63,26 @@ def split_dataset(self, dataset, split_ratio=0.2):
return train_dataset, val_dataset

def prepare_data(self):
self.dataset = load_dataset('nielsr/CelebA-faces')

self.dataset = load_dataset('nielsr/CelebA-faces')#.map(self.apply_transform)
# self.dataset = self.dataset.with_transform(self.apply_transform)
def setup(self, stage=None, transform=None):
if stage == 'fit' or stage is None:
self.train_dataset, self.val_dataset = self.split_dataset(self.dataset['train'], split_ratio=0.2)

@staticmethod
def collate_fn(batch):
# for example in batch:
# image = example['image']
# image.save("/home/haoyu/research/simplemodels/cache/test.jpg")
transform = transforms.Compose([
transforms.Resize((imsize,imsize)),
transforms.ToTensor(),
# transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) # Normalize images
# transforms.Lambda(lambda x: (x - 0.5) * 2) # unconment
])
transformed_batch = torch.stack([transform(example['image']) for example in batch])
# print("transformerd",transformed_batch.mean(),transformed_batch.min(),transformed_batch.max())

return transformed_batch,None

def apply_transform(self, example):
transform = transforms.Compose([
Expand All @@ -76,7 +91,7 @@ def apply_transform(self, example):
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize images
# transforms.Lambda(lambda x: (x - 0.5) * 2) # Uncomment to normalize
])
return transform(example['image'])
return {"image":transform(example['image'])}

def train_dataloader(self):
return DataLoader(self.train_dataset,
Expand All @@ -88,5 +103,6 @@ def train_dataloader(self):
def val_dataloader(self):
return DataLoader(self.val_dataset,
batch_size=self.batch_size,
collate_fn=self.collate_fn, num_workers=self.num_workers,
collate_fn=self.collate_fn,
num_workers=self.num_workers,
pin_memory=True)
Loading

0 comments on commit 38a55d6

Please sign in to comment.