From 30e757bec72ef4ada03b2c6321eced51a1713041 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Wed, 22 May 2024 16:14:18 +0200 Subject: [PATCH] fix dreambooth colab --- src/autotrain/app/colab.py | 2 +- src/autotrain/commands.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/autotrain/app/colab.py b/src/autotrain/app/colab.py index 07a4c9cfc0..67d84329c2 100644 --- a/src/autotrain/app/colab.py +++ b/src/autotrain/app/colab.py @@ -351,7 +351,7 @@ def start_training(b): "backend": "local", "data": { "path": dataset_path.value, - "prompt": params_val["prompt"], + "prompt": prompt, }, "params": params_val, "hub": { diff --git a/src/autotrain/commands.py b/src/autotrain/commands.py index 71b81cd0b1..a5fbc3a7f5 100644 --- a/src/autotrain/commands.py +++ b/src/autotrain/commands.py @@ -17,6 +17,18 @@ def launch_command(params): + """ + Launches training command based on the given parameters. + + Args: + params: An instance of a parameter class (LLMTrainingParams, DreamBoothTrainingParams, GenericParams, TabularParams, + TextClassificationParams, TextRegressionParams, TokenClassificationParams, ImageClassificationParams, + ObjectDetectionParams, Seq2SeqParams). + + Returns: + None + """ + params.project_name = shlex.split(params.project_name)[0] cuda_available = torch.cuda.is_available() mps_available = torch.backends.mps.is_available()