From 69eceb53037118ec7dddc91814935cb92b48eb1d Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Thu, 5 Oct 2023 11:21:41 +0200 Subject: [PATCH] osx --- src/autotrain/cli/run_llm.py | 12 +++++++++--- src/autotrain/trainers/clm/__main__.py | 1 + 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/autotrain/cli/run_llm.py b/src/autotrain/cli/run_llm.py index c8816b5060..db209505d3 100644 --- a/src/autotrain/cli/run_llm.py +++ b/src/autotrain/cli/run_llm.py @@ -407,10 +407,16 @@ def __init__(self, args): break print(f"Bot: {tgi.chat(prompt)}") - if not torch.cuda.is_available(): - raise ValueError("No GPU found. Please install CUDA and try again.") + cuda_available = torch.cuda.is_available() + mps_available = torch.backends.mps.is_available() - self.num_gpus = torch.cuda.device_count() + if not cuda_available and not mps_available: + raise ValueError("No GPU/MPS device found. LLM training requires an accelerator") + + if cuda_available: + self.num_gpus = torch.cuda.device_count() + elif mps_available: + self.num_gpus = 1 def run(self): from autotrain.backend import EndpointsRunner, SpaceRunner diff --git a/src/autotrain/trainers/clm/__main__.py b/src/autotrain/trainers/clm/__main__.py index 853c12a1b6..1b59f219e1 100644 --- a/src/autotrain/trainers/clm/__main__.py +++ b/src/autotrain/trainers/clm/__main__.py @@ -129,6 +129,7 @@ def train(config): trust_remote_code=True, use_flash_attention_2=config.use_flash_attention_2, ) + else: model = AutoModelForCausalLM.from_pretrained( config.model,