diff --git a/src/autotrain/trainers/dreambooth/trainer.py b/src/autotrain/trainers/dreambooth/trainer.py index 2983aa2131..cc845aa3c9 100644 --- a/src/autotrain/trainers/dreambooth/trainer.py +++ b/src/autotrain/trainers/dreambooth/trainer.py @@ -302,7 +302,7 @@ def _get_model_pred(self, batch, channels, noisy_model_input, timesteps, bsz): prompt=None, text_input_ids_list=[self.tokens_one, self.tokens_two], ) - unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(bsz, 1)}) + unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)}) prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1) model_pred = self.unet( noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions