Skip to content

Commit

Permalink
add colab arg
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Oct 21, 2023
1 parent 461186a commit d5d6f30
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion colabs/AutoTrain_Dreambooth.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"\n",
"import os\n",
"!pip install -U autotrain-advanced > install_logs.txt\n",
"!autotrain setup > setup_logs.txt"
"!autotrain setup --colab > setup_logs.txt"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion colabs/AutoTrain_LLM.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"\n",
"import os\n",
"!pip install -U autotrain-advanced > install_logs.txt\n",
"!autotrain setup > setup_logs.txt"
"!autotrain setup --colab > setup_logs.txt"
]
},
{
Expand Down
15 changes: 12 additions & 3 deletions src/autotrain/cli/run_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def run_app_command_factory(args):
return RunSetupCommand(args.update_torch)
return RunSetupCommand(args.update_torch, args.colab)


class RunSetupCommand(BaseAutoTrainCommand):
Expand All @@ -22,10 +22,16 @@ def register_subcommand(parser: ArgumentParser):
action="store_true",
help="Update PyTorch to latest version",
)
run_setup_parser.add_argument(
"--colab",
action="store_true",
help="Run setup for Google Colab",
)
run_setup_parser.set_defaults(func=run_app_command_factory)

def __init__(self, update_torch: bool):
def __init__(self, update_torch: bool, colab: bool = False):
self.update_torch = update_torch
self.colab = colab

def run(self):
# install latest transformers
Expand Down Expand Up @@ -53,7 +59,10 @@ def run(self):
_, _ = pipe.communicate()
logger.info("Successfully installed latest trl")

cmd = "pip install -U xformers"
if self.colab:
cmd = "pip install -U xformers"
else:
cmd = "pip uninstall -U xformers==0.0.22"
pipe = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
logger.info("Installing latest xformers")
_, _ = pipe.communicate()
Expand Down

0 comments on commit d5d6f30

Please sign in to comment.