diff --git a/leco.py b/leco.py index 87e295d..7f46827 100644 --- a/leco.py +++ b/leco.py @@ -131,7 +131,8 @@ def main(config): sampling_step = config.leco.sampling_step num_samples = config.leco.num_samples - noise_scheduler.set_timesteps(sampling_step, device=device) + if sampling_step > 0: + noise_scheduler.set_timesteps(sampling_step, device=device) generate_guidance_scale = config.leco.generate_guidance_scale @@ -142,6 +143,7 @@ def main(config): dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, collate_fn=collate_fn) del text_model + torch.cuda.empty_cache() total_steps = config.leco.epochs * len(dataloader) save_steps = config.leco.save_steps @@ -178,19 +180,22 @@ def main(config): if len(latents_and_times[idx]) == 0: # x_T latents = torch.randn(batch_size, 4, resolution, resolution, device=device, dtype=weight_dtype) - - timestep_to = random.sample(range(sampling_step), num_samples) # tをnum_samples個サンプリング - timestep_to.sort() - timedelta = random.choice(range(1000//sampling_step-1)) # 全ステップ学習できるようちょっとずらす - target_index = 0 - for i, t in tqdm(enumerate(noise_scheduler.timesteps[0:timestep_to[-1]+1])): - timestep = t + timedelta - latents_input = noise_scheduler.scale_model_input(latents, timestep) - noise_pred = cfg(unet, latents_input, timestep, torch.cat([target, negative],dim=0), generate_guidance_scale, target_proj, negative_proj, size_condition) - latents = noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] - if i == timestep_to[target_index]: - target_index += 1 - latents_and_times[idx].append((latents, timestep)) + if sampling_step > 0: + timestep_to = random.sample(range(sampling_step), num_samples) # tをnum_samples個サンプリング + timestep_to.sort() + timedelta = random.choice(range(1000//sampling_step-1)) # 全ステップ学習できるようちょっとずらす + target_index = 0 + for i, t in tqdm(enumerate(noise_scheduler.timesteps[0:timestep_to[-1]+1])): + timestep = t + timedelta + latents_input = noise_scheduler.scale_model_input(latents, timestep) + noise_pred = cfg(unet, latents_input, timestep, torch.cat([target, negative],dim=0), generate_guidance_scale, target_proj, negative_proj, size_condition) + latents = noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + if i == timestep_to[target_index]: + target_index += 1 + latents_and_times[idx].append((latents, timestep)) + else: + timestep = torch.tensor([999]).to(latents).repeat(batch_size) + latents_and_times[idx].append((latents, timestep)) with torch.autocast("cuda", enabled=not config.train.amp == False): latents, timesteps = latents_and_times[idx].pop()