Skip to content

Commit

Permalink
DDPM can support training
Browse files Browse the repository at this point in the history
  • Loading branch information
CIntellifusion committed May 6, 2024
1 parent 2e0091a commit 79ff8d1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
# SimpleDiffusion
A simple diffusion framework for better understanding





# Task 1

Implement a DDPM and a DDIM that can generate CelebA 64x64 images



10 changes: 4 additions & 6 deletions UnconditionalDiffusion/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
author: haoyu
date: 20240501-20240506
date: 20240501-0506
a simplified unconditional diffusion model for image generation
"""
import os , cv2 ,argparse
Expand Down Expand Up @@ -43,7 +43,6 @@
### borrowed from https://github.com/SingleZombie/DL-Demos/blob/master/dldemos/ddim/network.py

class PositionalEncoding(nn.Module):

def __init__(self, max_seq_len: int, d_model: int):
super().__init__()

Expand Down Expand Up @@ -454,7 +453,7 @@ def sampling_backward(self, image_or_shape,net,device="cuda",simple_var=True):
x = self.sampling_step(net, x, t, simple_var)
return x
@torch.no_grad()
def sampling_step(self,net,x_t, t,simple_var,use_noise=False,clip_denoised=False):
def sampling_step(self,net,x_t, t,simple_var,use_noise=True,clip_denoised=False):
bs = x_t.shape[0]
t_tensor = t*torch.ones(bs,dtype=torch.long,device=x_t.device).reshape(-1,1)
if t == 0:
Expand All @@ -463,15 +462,14 @@ def sampling_step(self,net,x_t, t,simple_var,use_noise=False,clip_denoised=False
if simple_var:
var = self.betas[t]
else:
var = (1-self.alpha_bars_prev[t])/(1-self.alpha_bars[t])*self.betas[t]
var = (1-self.alpha_bars_prev[t])/(1-self.alpha_bars[t]) * self.betas[t]
# NOTE: this was actually a bug here — randn_like and rand_like are not the same
noise = torch.randn_like(x_t) * torch.sqrt(var)
eps = net(x_t,t_tensor)
# with open("./cache.txt",'a') as f:
# f.write(f"{eps.mean().item()},{eps.max().item()},{eps.min().item()}\n")
eps = ((1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t])) *eps
mean = (x_t - eps)
mean/= torch.sqrt(self.alphas[t])
mean = (x_t - eps) / torch.sqrt(self.alphas[t])
# eps = torch.sqrt(1-self.alpha_bars[t]) * eps
# print(1-self.alpha_bars[t])
# mean = (x_t-eps)/torch.sqrt(self.alpha_bars[t])
Expand Down

0 comments on commit 79ff8d1

Please sign in to comment.