support ileco
laksjdjf authored Jan 19, 2024
1 parent 8fd54a1 commit a13c1ab
Showing 1 changed file with 19 additions and 14 deletions.
leco.py: 33 changes (19 additions & 14 deletions)
@@ -131,7 +131,8 @@ def main(config):

    sampling_step = config.leco.sampling_step
    num_samples = config.leco.num_samples
-   noise_scheduler.set_timesteps(sampling_step, device=device)
+   if sampling_step > 0:
+       noise_scheduler.set_timesteps(sampling_step, device=device)

    generate_guidance_scale = config.leco.generate_guidance_scale

@@ -142,6 +143,7 @@ def main(config):
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, collate_fn=collate_fn)

    del text_model
+   torch.cuda.empty_cache()

    total_steps = config.leco.epochs * len(dataloader)
    save_steps = config.leco.save_steps
@@ -178,19 +180,22 @@ def main(config):
        if len(latents_and_times[idx]) == 0:
            # x_T
            latents = torch.randn(batch_size, 4, resolution, resolution, device=device, dtype=weight_dtype)
-
-           timestep_to = random.sample(range(sampling_step), num_samples) # sample num_samples values of t
-           timestep_to.sort()
-           timedelta = random.choice(range(1000//sampling_step-1)) # shift slightly so that all steps can be trained
-           target_index = 0
-           for i, t in tqdm(enumerate(noise_scheduler.timesteps[0:timestep_to[-1]+1])):
-               timestep = t + timedelta
-               latents_input = noise_scheduler.scale_model_input(latents, timestep)
-               noise_pred = cfg(unet, latents_input, timestep, torch.cat([target, negative],dim=0), generate_guidance_scale, target_proj, negative_proj, size_condition)
-               latents = noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
-               if i == timestep_to[target_index]:
-                   target_index += 1
-                   latents_and_times[idx].append((latents, timestep))
+           if sampling_step > 0:
+               timestep_to = random.sample(range(sampling_step), num_samples) # sample num_samples values of t
+               timestep_to.sort()
+               timedelta = random.choice(range(1000//sampling_step-1)) # shift slightly so that all steps can be trained
+               target_index = 0
+               for i, t in tqdm(enumerate(noise_scheduler.timesteps[0:timestep_to[-1]+1])):
+                   timestep = t + timedelta
+                   latents_input = noise_scheduler.scale_model_input(latents, timestep)
+                   noise_pred = cfg(unet, latents_input, timestep, torch.cat([target, negative],dim=0), generate_guidance_scale, target_proj, negative_proj, size_condition)
+                   latents = noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
+                   if i == timestep_to[target_index]:
+                       target_index += 1
+                       latents_and_times[idx].append((latents, timestep))
+           else:
+               timestep = torch.tensor([999]).to(latents).repeat(batch_size)
+               latents_and_times[idx].append((latents, timestep))

        with torch.autocast("cuda", enabled=not config.train.amp == False):
            latents, timesteps = latents_and_times[idx].pop()
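For reference, a minimal standalone sketch of the control flow this commit adds. The scheduler, UNet and cfg() call from leco.py are replaced with stubs, and the helper names (fake_denoise_step, collect_training_latents) are hypothetical, so this only illustrates the branching, not the repository's actual API: with sampling_step > 0 the latents are partially denoised and collected at a few sampled intermediate timesteps, while sampling_step == 0 (presumably the "ileco" mode named in the commit title) pairs pure noise directly with timestep 999.

# Hypothetical sketch of the sampling_step branching in this commit;
# the real denoising loop in leco.py uses noise_scheduler, unet and cfg().
import random
import torch

def fake_denoise_step(latents: torch.Tensor) -> torch.Tensor:
    # Stand-in for one scheduler/UNet denoising step (no real model here).
    return latents * 0.99

def collect_training_latents(sampling_step: int, num_samples: int,
                             batch_size: int = 1, resolution: int = 64):
    """Collect (latents, timestep) pairs the way leco.py queues them."""
    latents = torch.randn(batch_size, 4, resolution, resolution)
    pairs = []

    if sampling_step > 0:
        # Original path: partially denoise x_T and keep the latents at
        # num_samples randomly chosen intermediate steps.
        timesteps = torch.linspace(999, 0, sampling_step).long()  # stand-in schedule
        timestep_to = sorted(random.sample(range(sampling_step), num_samples))
        timedelta = random.choice(range(1000 // sampling_step - 1))
        target_index = 0
        for i, t in enumerate(timesteps[: timestep_to[-1] + 1]):
            timestep = t + timedelta
            latents = fake_denoise_step(latents)
            if i == timestep_to[target_index]:
                target_index += 1
                pairs.append((latents, timestep))
    else:
        # Path added by this commit: skip generation entirely and train
        # directly on pure noise at the last timestep (999).
        timestep = torch.tensor([999]).to(latents).repeat(batch_size)
        pairs.append((latents, timestep))

    return pairs

if __name__ == "__main__":
    print(len(collect_training_latents(sampling_step=30, num_samples=3)))  # 3 pairs
    print(len(collect_training_latents(sampling_step=0, num_samples=3)))   # 1 pair at t=999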
