
Update tdd_scheduler.py
WangCunzheng authored Aug 26, 2024
1 parent 5275b09 commit 5721fc6
Showing 1 changed file with 133 additions and 124 deletions: tdd_scheduler.py
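For orientation, a minimal usage sketch (editorial; the checkpoint, prompt, and sampler settings are placeholders, and in practice the matching TDD-distilled LoRA weights would also be loaded):

import torch
from diffusers import StableDiffusionXLPipeline
from tdd_scheduler import TDDScheduler

# TDDScheduler subclasses DPMSolverSinglestepScheduler, so it drops into an
# existing pipeline via from_config; extra kwargs such as tdd_train_step
# override values inherited from the pipeline's scheduler config.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = TDDScheduler.from_config(pipe.scheduler.config, tdd_train_step=250)
image = pipe("a photo of a cat", num_inference_steps=6, guidance_scale=0.0).images[0]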
@@ -2,18 +2,118 @@
from diffusers.schedulers.scheduling_tcd import *
from diffusers.schedulers.scheduling_dpmsolver_singlestep import *

-class TDDSchedulerPlus(DPMSolverSinglestepScheduler):
+class TDDScheduler(DPMSolverSinglestepScheduler):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        solver_order: int = 1,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = False,
        use_karras_sigmas: Optional[bool] = False,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        tdd_train_step: int = 250,
        special_jump: bool = False,
        t_l: int = -1,
    ):
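        # New TDD parameters (assumed semantics, inferred from their usage below):
        # tdd_train_step is the discretization the distilled model was trained
        # with, special_jump switches to hand-picked timestep tables, and t_l
        # bounds the step indices on which thresholding runs in convert_model_output.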
        self.t_l = t_l
        self.special_jump = special_jump
        self.tdd_train_step = tdd_train_step
        if algorithm_type == "dpmsolver":
            deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
            deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)

        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settings for DPM-Solver
        if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
            if algorithm_type == "deis":
                self.register_to_config(algorithm_type="dpmsolver++")
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                self.register_to_config(solver_type="midpoint")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
            raise ValueError(
                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
            )

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.sample = None
        self.order_list = self.get_order_list(num_train_timesteps)
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        self.num_inference_steps = num_inference_steps
        # Clipping the minimum of all lambda(t) for numerical stability.
        # This is critical for cosine (squaredcos_cap_v2) noise schedule.
        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
        timesteps = (
            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
            .round()[::-1][:-1]
            .copy()
            .astype(np.int64)
        )
        # original_steps = self.config.original_inference_steps
        if True:
            original_steps = self.tdd_train_step
            k = 1000 / original_steps
            tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps) + 1))) * k - 1
        else:
            tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps))))
        # TCD Inference Steps Schedule
        tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
        # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
        inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        timesteps = tcd_origin_timesteps[inference_indices]
        if self.special_jump:
            if self.tdd_train_step == 50:
                # timesteps = np.array([999., 879., 759., 499., 259.])
                print(timesteps)
            elif self.tdd_train_step == 250:
                if num_inference_steps == 5:
                    timesteps = np.array([999., 875., 751., 499., 251.])
                elif num_inference_steps == 6:
                    timesteps = np.array([999., 875., 751., 627., 499., 251.])
                elif num_inference_steps == 7:
                    timesteps = np.array([999., 875., 751., 627., 499., 375., 251.])

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
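To make the schedule construction above concrete, here is a standalone sketch of the same arithmetic with assumed values (tdd_train_step=250, num_inference_steps=5); the special_jump tables then replace this default result with hand-picked timesteps such as [999., 875., 751., 499., 251.]:

import numpy as np

# Same arithmetic as set_timesteps above, with assumed values.
# The 250-step student grid sits on the 1000-step training grid with
# spacing k = 4, giving candidates [3, 7, ..., 999]; these are reversed
# and then subsampled at (approximately) even indices.
tdd_train_step, num_inference_steps = 250, 5
k = 1000 / tdd_train_step
tcd_origin_timesteps = (np.arange(1, tdd_train_step + 1) * k - 1)[::-1]
idx = np.floor(
    np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
).astype(np.int64)
print(tcd_origin_timesteps[idx])  # [999. 799. 599. 399. 199.]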
@@ -64,13 +164,29 @@ def set_timesteps_s(self, eta: float = 0.0):
        # This is critical for cosine (squaredcos_cap_v2) noise schedule.
        num_inference_steps = self.num_inference_steps
        device = self.timesteps.device
        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
        timesteps = (
            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
            .round()[::-1][:-1]
            .copy()
            .astype(np.int64)
        )
        if True:
            original_steps = self.tdd_train_step
            k = 1000 / original_steps
            tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps) + 1))) * k - 1
        else:
            tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps))))
        # TCD Inference Steps Schedule
        tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
        # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
        inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        timesteps = tcd_origin_timesteps[inference_indices]
        if self.special_jump:
            if self.tdd_train_step == 50:
                timesteps = np.array([999., 879., 759., 499., 259.])
            elif self.tdd_train_step == 250:
                if num_inference_steps == 5:
                    timesteps = np.array([999., 875., 751., 499., 251.])
                elif num_inference_steps == 6:
                    timesteps = np.array([999., 875., 751., 627., 499., 251.])
                elif num_inference_steps == 7:
                    timesteps = np.array([999., 875., 751., 627., 499., 375., 251.])

        timesteps_s = np.floor((1 - eta) * timesteps).astype(np.int64)

        sigmas_s = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
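The (1 - eta) factor in timesteps_s pulls each timestep toward zero to form the earlier "target" timesteps; a quick numeric illustration with an assumed eta:

import numpy as np

# Assumed eta applied to the 5-step special_jump schedule from above.
eta = 0.3
timesteps = np.array([999., 875., 751., 499., 251.])
timesteps_s = np.floor((1 - eta) * timesteps).astype(np.int64)
print(timesteps_s)  # [699 612 525 349 175]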
@@ -267,111 +383,6 @@ def singlestep_dpm_solver_second_order_update(
        )
        return x_t

-    def singlestep_dpm_solver_third_order_update(
-        self,
-        model_output_list: List[torch.FloatTensor],
-        *args,
-        sample: torch.FloatTensor = None,
-        **kwargs,
-    ) -> torch.FloatTensor:
-        """
-        One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
-        time `timestep_list[-3]`.
-
-        Args:
-            model_output_list (`List[torch.FloatTensor]`):
-                The direct outputs from learned diffusion model at current and latter timesteps.
-            timestep (`int`):
-                The current and latter discrete timestep in the diffusion chain.
-            prev_timestep (`int`):
-                The previous discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
-                A current instance of a sample created by diffusion process.
-
-        Returns:
-            `torch.FloatTensor`:
-                The sample tensor at the previous timestep.
-        """
-
-        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
-        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
-        if sample is None:
-            if len(args) > 2:
-                sample = args[2]
-            else:
-                raise ValueError(" missing`sample` as a required keyward argument")
-        if timestep_list is not None:
-            deprecate(
-                "timestep_list",
-                "1.0.0",
-                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
-            )
-
-        if prev_timestep is not None:
-            deprecate(
-                "prev_timestep",
-                "1.0.0",
-                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
-            )
-
-        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
-            self.sigmas[self.step_index + 1],
-            self.sigmas[self.step_index],
-            self.sigmas[self.step_index - 1],
-            self.sigmas[self.step_index - 2],
-        )
-
-        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
-        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
-        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
-        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
-
-        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
-        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
-        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
-        lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
-
-        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
-
-        h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2
-        r0, r1 = h_0 / h, h_1 / h
-        D0 = m2
-        D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2)
-        D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1)
-        D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1)
-        if self.config.algorithm_type == "dpmsolver++":
-            # See https://arxiv.org/abs/2206.00927 for detailed derivations
-            if self.config.solver_type == "midpoint":
-                x_t = (
-                    (sigma_t / sigma_s2) * sample
-                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
-                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1
-                )
-            elif self.config.solver_type == "heun":
-                x_t = (
-                    (sigma_t / sigma_s2) * sample
-                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
-                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
-                    - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
-                )
-        elif self.config.algorithm_type == "dpmsolver":
-            # See https://arxiv.org/abs/2206.00927 for detailed derivations
-            if self.config.solver_type == "midpoint":
-                x_t = (
-                    (alpha_t / alpha_s2) * sample
-                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
-                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1
-                )
-            elif self.config.solver_type == "heun":
-                x_t = (
-                    (alpha_t / alpha_s2) * sample
-                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
-                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
-                    - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
-                )
-        return x_t


    def singlestep_dpm_solver_update(
        self,
        model_output_list: List[torch.FloatTensor],
@@ -410,10 +421,8 @@ def singlestep_dpm_solver_update(
            return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
        elif order == 2:
            return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
-        elif order == 3:
-            return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
        else:
-            raise ValueError(f"Order must be 1, 2, 3, got {order}")
+            raise ValueError(f"Order must be 1, 2, got {order}")

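With the third-order update removed, only orders 1 and 2 are reachable. A quick sanity sketch (editorial; get_order_list is inherited from DPMSolverSinglestepScheduler):

# With solver_order=2 the inherited order schedule alternates 1, 2, 1, 2, ...
# so the deleted third-order update is never requested.
sched = TDDScheduler(solver_order=2)
print(set(sched.get_order_list(8)))  # expected: {1, 2}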
    def convert_model_output(
        self,
@@ -477,7 +486,7 @@ def convert_model_output(
                " `v_prediction` for the DPMSolverSinglestepScheduler."
            )

-        if self.step_index == 0:
+        if self.step_index <= self.t_l:
            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

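Previously the gate was step_index == 0, so thresholding could fire only on the first step; the new t_l parameter widens that window. A tiny illustration with an assumed t_l:

# With t_l = 2, dynamic thresholding of x0_pred can run at step indices 0-2;
# the default t_l = -1 means it never runs, since step_index is non-negative.
t_l = 2
for step_index in range(5):
    print(step_index, step_index <= t_l)  # True, True, True, False, False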
