diff --git a/tdd_scheduler.py b/tdd_scheduler.py
index 4c8964d..ec17346 100644
--- a/tdd_scheduler.py
+++ b/tdd_scheduler.py
@@ -2,18 +2,118 @@
 from diffusers.schedulers.scheduling_tcd import *
 from diffusers.schedulers.scheduling_dpmsolver_singlestep import *
 
-class TDDSchedulerPlus(DPMSolverSinglestepScheduler):
+class TDDScheduler(DPMSolverSinglestepScheduler):
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+        solver_order: int = 1,
+        prediction_type: str = "epsilon",
+        thresholding: bool = False,
+        dynamic_thresholding_ratio: float = 0.995,
+        sample_max_value: float = 1.0,
+        algorithm_type: str = "dpmsolver++",
+        solver_type: str = "midpoint",
+        lower_order_final: bool = False,
+        use_karras_sigmas: Optional[bool] = False,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
+        lambda_min_clipped: float = -float("inf"),
+        variance_type: Optional[str] = None,
+        tdd_train_step: int = 250,
+        special_jump: bool = False,
+        t_l: int = -1,
+    ):
+        self.t_l = t_l
+        self.special_jump = special_jump
+        self.tdd_train_step = tdd_train_step
+        if algorithm_type == "dpmsolver":
+            deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+            deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
+
+        if trained_betas is not None:
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+        # Currently we only support VP-type noise schedule
+        self.alpha_t = torch.sqrt(self.alphas_cumprod)
+        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
+        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        # settings for DPM-Solver
+        if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
+            if algorithm_type == "deis":
+                self.register_to_config(algorithm_type="dpmsolver++")
+            else:
+                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
+        if solver_type not in ["midpoint", "heun"]:
+            if solver_type in ["logrho", "bh1", "bh2"]:
+                self.register_to_config(solver_type="midpoint")
+            else:
+                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+
+        if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
+            raise ValueError(
+                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+            )
+
+        # setable values
+        self.num_inference_steps = None
+        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps)
+        self.model_outputs = [None] * solver_order
+        self.sample = None
+        self.order_list = self.get_order_list(num_train_timesteps)
+        self._step_index = None
+        self._begin_index = None
+        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         self.num_inference_steps = num_inference_steps
         # Clipping the minimum of all lambda(t) for numerical stability.
         # This is critical for cosine (squaredcos_cap_v2) noise schedule.
-        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
-        timesteps = (
-            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
-            .round()[::-1][:-1]
-            .copy()
-            .astype(np.int64)
-        )
+        # original_steps = self.config.original_inference_steps
+        if True:
+            original_steps = self.tdd_train_step
+            k = 1000 / original_steps
+            tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps) + 1))) * k - 1
+        else:
+            tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps))))
+        # TCD Inference Steps Schedule
+        tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
+        # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
+        inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
+        inference_indices = np.floor(inference_indices).astype(np.int64)
+        timesteps = tcd_origin_timesteps[inference_indices]
+        if self.special_jump:
+            if self.tdd_train_step == 50:
+                # timesteps = np.array([999., 879., 759., 499., 259.])
+                print(timesteps)
+            elif self.tdd_train_step == 250:
+                if num_inference_steps == 5:
+                    timesteps = np.array([999., 875., 751., 499., 251.])
+                elif num_inference_steps == 6:
+                    timesteps = np.array([999., 875., 751., 627., 499., 251.])
+                elif num_inference_steps == 7:
+                    timesteps = np.array([999., 875., 751., 627., 499., 375., 251.])
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         if self.config.use_karras_sigmas:
@@ -64,13 +164,29 @@
     def set_timesteps_s(self, eta: float = 0.0):
         # This is critical for cosine (squaredcos_cap_v2) noise schedule.
         num_inference_steps = self.num_inference_steps
         device = self.timesteps.device
-        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
-        timesteps = (
-            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
-            .round()[::-1][:-1]
-            .copy()
-            .astype(np.int64)
-        )
+        if True:
+            original_steps = self.tdd_train_step
+            k = 1000 / original_steps
+            tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps) + 1))) * k - 1
+        else:
+            tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps))))
+        # TCD Inference Steps Schedule
+        tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
+        # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
+        inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
+        inference_indices = np.floor(inference_indices).astype(np.int64)
+        timesteps = tcd_origin_timesteps[inference_indices]
+        if self.special_jump:
+            if self.tdd_train_step == 50:
+                timesteps = np.array([999., 879., 759., 499., 259.])
+            elif self.tdd_train_step == 250:
+                if num_inference_steps == 5:
+                    timesteps = np.array([999., 875., 751., 499., 251.])
+                elif num_inference_steps == 6:
+                    timesteps = np.array([999., 875., 751., 627., 499., 251.])
+                elif num_inference_steps == 7:
+                    timesteps = np.array([999., 875., 751., 627., 499., 375., 251.])
+
+        timesteps_s = np.floor((1 - eta) * timesteps).astype(np.int64)
 
         sigmas_s = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
@@ -267,111 +383,6 @@ def singlestep_dpm_solver_second_order_update(
         )
         return x_t
 
-    def singlestep_dpm_solver_third_order_update(
-        self,
-        model_output_list: List[torch.FloatTensor],
-        *args,
-        sample: torch.FloatTensor = None,
-        **kwargs,
-    ) -> torch.FloatTensor:
-        """
-        One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
-        time `timestep_list[-3]`.
-
-        Args:
-            model_output_list (`List[torch.FloatTensor]`):
-                The direct outputs from learned diffusion model at current and latter timesteps.
-            timestep (`int`):
-                The current and latter discrete timestep in the diffusion chain.
-            prev_timestep (`int`):
-                The previous discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
-                A current instance of a sample created by diffusion process.
-
-        Returns:
-            `torch.FloatTensor`:
-                The sample tensor at the previous timestep.
-        """
-
-        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
-        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
-        if sample is None:
-            if len(args) > 2:
-                sample = args[2]
-            else:
-                raise ValueError(" missing`sample` as a required keyward argument")
-        if timestep_list is not None:
-            deprecate(
-                "timestep_list",
-                "1.0.0",
-                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
-            )
-
-        if prev_timestep is not None:
-            deprecate(
-                "prev_timestep",
-                "1.0.0",
-                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
-            )
-
-        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
-            self.sigmas[self.step_index + 1],
-            self.sigmas[self.step_index],
-            self.sigmas[self.step_index - 1],
-            self.sigmas[self.step_index - 2],
-        )
-
-        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
-        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
-        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
-        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
-
-        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
-        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
-        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
-        lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
-
-        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
-
-        h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2
-        r0, r1 = h_0 / h, h_1 / h
-        D0 = m2
-        D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2)
-        D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1)
-        D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1)
-        if self.config.algorithm_type == "dpmsolver++":
-            # See https://arxiv.org/abs/2206.00927 for detailed derivations
-            if self.config.solver_type == "midpoint":
-                x_t = (
-                    (sigma_t / sigma_s2) * sample
-                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
-                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1
-                )
-            elif self.config.solver_type == "heun":
-                x_t = (
-                    (sigma_t / sigma_s2) * sample
-                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
-                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
-                    - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
-                )
-        elif self.config.algorithm_type == "dpmsolver":
-            # See https://arxiv.org/abs/2206.00927 for detailed derivations
-            if self.config.solver_type == "midpoint":
-                x_t = (
-                    (alpha_t / alpha_s2) * sample
-                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
-                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1
-                )
-            elif self.config.solver_type == "heun":
-                x_t = (
-                    (alpha_t / alpha_s2) * sample
-                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
-                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
-                    - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
-                )
-        return x_t
-
-
     def singlestep_dpm_solver_update(
         self,
         model_output_list: List[torch.FloatTensor],
@@ -410,10 +421,8 @@ def singlestep_dpm_solver_update(
             return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
         elif order == 2:
            return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
-        elif order == 3:
-            return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
         else:
-            raise ValueError(f"Order must be 1, 2, 3, got {order}")
+            raise ValueError(f"Order must be 1 or 2, got {order}")
 
     def convert_model_output(
         self,
@@ -477,7 +486,7 @@ def convert_model_output(
                     " `v_prediction` for the DPMSolverSinglestepScheduler."
                 )
 
-            if self.step_index == 0:
+            if self.step_index <= self.t_l:
                 if self.config.thresholding:
                     x0_pred = self._threshold_sample(x0_pred)
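---

Notes on the schedule changes above: `set_timesteps` now derives its timestep grid TCD-style from `tdd_train_step` (the number of timesteps the model was distilled on) instead of clipping `lambda_t`; `special_jump` swaps in hand-tuned grids for the 250-step case; `set_timesteps_s` additionally shrinks each timestep toward zero by a factor of `(1 - eta)`; and `convert_model_output` applies thresholding only while `step_index <= t_l`. The snippet below reproduces the default (non-`special_jump`) selection outside the scheduler, purely for illustration; the helper name `tdd_timesteps` is not part of this patch.

    import numpy as np

    def tdd_timesteps(num_inference_steps: int, tdd_train_step: int = 250) -> np.ndarray:
        # Grid of the `tdd_train_step` timesteps the student was distilled on:
        # k - 1, 2k - 1, ..., 999, with k = 1000 / tdd_train_step.
        k = 1000 / tdd_train_step
        tcd_origin_timesteps = np.asarray(list(range(1, tdd_train_step + 1))) * k - 1
        tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
        # Pick (approximately) evenly spaced entries, largest timestep first.
        inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        return tcd_origin_timesteps[inference_indices]

    print(tdd_timesteps(5))  # [999. 799. 599. 399. 199.]

With `special_jump=True` and `tdd_train_step=250`, the five-step grid becomes [999, 875, 751, 499, 251] instead: the hand-tuned values override the evenly spaced ones.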
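For context, a minimal usage sketch, assuming a diffusers pipeline whose weights were distilled with 250 TDD timesteps (the base model ID below is a stand-in; this patch does not ship a checkpoint):

    import torch
    from diffusers import StableDiffusionXLPipeline
    from tdd_scheduler import TDDScheduler

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    # `tdd_train_step` should match the timestep count used during distillation;
    # `special_jump=True` enables the hand-tuned 5/6/7-step grids from this patch.
    pipe.scheduler = TDDScheduler(tdd_train_step=250, special_jump=True)
    image = pipe("a photo of an astronaut riding a horse", num_inference_steps=5).images[0]

Since `solver_order` now defaults to 1 and the third-order update was removed, the scheduler runs first-order (optionally second-order) DPM-Solver++ steps over these few timesteps.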