From 89ac6b7eac955fcc311ef2b431283a341fa1b9a8 Mon Sep 17 00:00:00 2001 From: Marc DeMory Date: Thu, 28 Sep 2023 08:31:01 -0500 Subject: [PATCH] Enable using both --train-text-encoder and --use-prior-preservation (#245) --- src/autotrain/trainers/dreambooth/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/autotrain/trainers/dreambooth/trainer.py b/src/autotrain/trainers/dreambooth/trainer.py index dfcb19a931..ba22caad59 100644 --- a/src/autotrain/trainers/dreambooth/trainer.py +++ b/src/autotrain/trainers/dreambooth/trainer.py @@ -302,7 +302,7 @@ def _get_model_pred(self, batch, channels, noisy_model_input, timesteps, bsz): prompt=None, text_input_ids_list=[self.tokens_one, self.tokens_two], ) - unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(bsz, 1)}) + unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)}) prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1) model_pred = self.unet( noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions