From d5d6f30a9c1b0acc2962a4ed08f06110c75b8e55 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Sat, 21 Oct 2023 12:32:50 +0200 Subject: [PATCH] add colab arg --- colabs/AutoTrain_Dreambooth.ipynb | 2 +- colabs/AutoTrain_LLM.ipynb | 2 +- src/autotrain/cli/run_setup.py | 15 ++++++++++++--- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/colabs/AutoTrain_Dreambooth.ipynb b/colabs/AutoTrain_Dreambooth.ipynb index 411d5e4b70..e6c5af460b 100644 --- a/colabs/AutoTrain_Dreambooth.ipynb +++ b/colabs/AutoTrain_Dreambooth.ipynb @@ -26,7 +26,7 @@ "\n", "import os\n", "!pip install -U autotrain-advanced > install_logs.txt\n", - "!autotrain setup > setup_logs.txt" + "!autotrain setup --colab > setup_logs.txt" ] }, { diff --git a/colabs/AutoTrain_LLM.ipynb b/colabs/AutoTrain_LLM.ipynb index 5578d52485..82925a148e 100644 --- a/colabs/AutoTrain_LLM.ipynb +++ b/colabs/AutoTrain_LLM.ipynb @@ -22,7 +22,7 @@ "\n", "import os\n", "!pip install -U autotrain-advanced > install_logs.txt\n", - "!autotrain setup > setup_logs.txt" + "!autotrain setup --colab > setup_logs.txt" ] }, { diff --git a/src/autotrain/cli/run_setup.py b/src/autotrain/cli/run_setup.py index ee51a76423..286caba176 100644 --- a/src/autotrain/cli/run_setup.py +++ b/src/autotrain/cli/run_setup.py @@ -7,7 +7,7 @@ def run_app_command_factory(args): - return RunSetupCommand(args.update_torch) + return RunSetupCommand(args.update_torch, args.colab) class RunSetupCommand(BaseAutoTrainCommand): @@ -22,10 +22,16 @@ def register_subcommand(parser: ArgumentParser): action="store_true", help="Update PyTorch to latest version", ) + run_setup_parser.add_argument( + "--colab", + action="store_true", + help="Run setup for Google Colab", + ) run_setup_parser.set_defaults(func=run_app_command_factory) - def __init__(self, update_torch: bool): + def __init__(self, update_torch: bool, colab: bool = False): self.update_torch = update_torch + self.colab = colab def run(self): # install latest transformers @@ -53,7 +59,10 @@ def run(self): _, _ = pipe.communicate() logger.info("Successfully installed latest trl") - cmd = "pip install -U xformers" + if self.colab: + cmd = "pip install -U xformers" + else: + cmd = "pip uninstall -U xformers==0.0.22" pipe = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) logger.info("Installing latest xformers") _, _ = pipe.communicate()