Skip to content

Commit

Permalink
Merge branch 'main' of github.com:huggingface/autotrain-advanced into…
Browse files Browse the repository at this point in the history
… main
  • Loading branch information
abhishekkrthakur committed Sep 28, 2023
2 parents 6928c10 + 89ac6b7 commit 63c356d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/autotrain/trainers/dreambooth/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _get_model_pred(self, batch, channels, noisy_model_input, timesteps, bsz):
prompt=None,
text_input_ids_list=[self.tokens_one, self.tokens_two],
)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(bsz, 1)})
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
model_pred = self.unet(
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
Expand Down

0 comments on commit 63c356d

Please sign in to comment.