diff --git a/Dockerfile b/Dockerfile index ba7bf78e59..862a466ecd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,6 +5,8 @@ ENV DEBIAN_FRONTEND=noninteractive \ ENV PATH="${HOME}/miniconda3/bin:${PATH}" ARG PATH="${HOME}/miniconda3/bin:${PATH}" +ENV PATH="/app/ngc-cli:${PATH}" +ARG PATH="/app/ngc-cli:${PATH}" RUN mkdir -p /tmp/model && \ chown -R 1000:1000 /tmp/model && \ @@ -28,6 +30,7 @@ RUN apt-get update && \ git \ git-lfs \ libgl1 \ + unzip \ && rm -rf /var/lib/apt/lists/* && \ apt-get clean @@ -63,6 +66,10 @@ RUN conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c conda clean -ya && \ conda install -c "nvidia/label/cuda-12.1.0" cuda-nvcc && conda clean -ya +# install NGC CLI +RUN wget --content-disposition https://api.ngc.nvidia.com/v2/resources/nvidia/ngc-apps/ngc_cli/versions/3.34.1/files/ngccli_linux.zip -O ngccli_linux.zip && unzip ngccli_linux.zip && \ + chmod u+x ngc-cli/ngc + COPY --chown=1000:1000 . /app/ RUN pip install -e . && \ python -m nltk.downloader punkt && \ diff --git a/requirements.txt b/requirements.txt index 394b6f949d..ae06ff4d55 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ packaging==23.1 # latest versions tensorboard peft==0.6.2 -trl==0.7.2 +trl==0.7.4 tiktoken==0.5.1 transformers==4.35.1 accelerate==0.24.0 diff --git a/src/autotrain/api.py b/src/autotrain/api.py index ed442dc5f9..30af17c7ae 100644 --- a/src/autotrain/api.py +++ b/src/autotrain/api.py @@ -3,6 +3,7 @@ import os import signal import subprocess +import time from contextlib import asynccontextmanager import psutil @@ -34,8 +35,14 @@ def __init__(self): async def run_main(self): while True: - await monitor_training_process(PID) - await asyncio.sleep(0.1) + status = get_process_status(PID) + status = status.strip().lower() + if status in ("completed", "error", "zombie"): + logger.info("Training process finished. Shutting down the server.") + time.sleep(5) + kill_process(os.getpid()) + break + time.sleep(5) runner = BackgroundRunner() @@ -44,7 +51,9 @@ async def run_main(self): def get_process_status(pid): try: process = psutil.Process(pid) - return process.status() + proc_status = process.status() + logger.info(f"Process status: {proc_status}") + return proc_status except psutil.NoSuchProcess: logger.info(f"No process found with PID: {pid}") return "Completed" @@ -80,13 +89,11 @@ def kill_process(pid): return f"Process {pid} or one of its children has not terminated in time" -async def monitor_training_process(pid: int): - while True: - status = get_process_status(pid) - if status == "Completed" or status == "Error": - logger.info("Training process finished. Shutting down the server.") - os.kill(os.getpid(), signal.SIGINT) - break +def monitor_training_process(pid: int): + status = get_process_status(pid) + if status == "Completed" or status == "Error": + logger.info("Training process finished. Shutting down the server.") + os.kill(os.getpid(), signal.SIGINT) def run_training(): diff --git a/src/autotrain/app.py b/src/autotrain/app.py index f6726ffe17..b97f3cfda7 100644 --- a/src/autotrain/app.py +++ b/src/autotrain/app.py @@ -21,6 +21,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None) _, _, USERS = app_utils.user_validation() +ENABLE_NGC = int(os.environ.get("ENABLE_NGC", 0)) HIDDEN_PARAMS = [ @@ -132,7 +133,11 @@ async def read_form(request: Request): """ if HF_TOKEN is None: return templates.TemplateResponse("error.html", {"request": request}) - context = {"request": request, "valid_users": USERS} + context = { + "request": request, + "valid_users": USERS, + "enable_ngc": ENABLE_NGC, + } return templates.TemplateResponse("index.html", context) diff --git a/src/autotrain/backend.py b/src/autotrain/backend.py index 606394ae1b..f231dd03c7 100644 --- a/src/autotrain/backend.py +++ b/src/autotrain/backend.py @@ -429,8 +429,13 @@ class NGCRunner: def __post_init__(self): self.ngc_ace = os.environ.get("NGC_ACE") self.ngc_org = os.environ.get("NGC_ORG") + self.ngc_api_key = os.environ.get("NGC_CLI_API_KEY") + self.ngc_team = os.environ.get("NGC_TEAM") self.instance_map = { "dgx-a100": "dgxa100.80g.1.norm", + "dgx-2a100": "dgxa100.80g.2.norm", + "dgx-4a100": "dgxa100.80g.4.norm", + "dgx-8a100": "dgxa100.80g.8.norm", } logger.info("Creating NGC Job") logger.info(f"NGC_ACE: {self.ngc_ace}") @@ -455,5 +460,30 @@ def create(self): for k, v in self.env_vars.items(): cmd += f" --env-var {k}:{v}" - # run using subprocess, wait for completion - subprocess.run(cmd, shell=True, check=True) + ngc_config_cmd = "ngc config set" + ngc_config_cmd += " --team {ngc_team} --org {ngc_org} --ace {ngc_ace}" + ngc_config_cmd = ngc_config_cmd.format( + # ngc_api_key=self.ngc_api_key, + ngc_team=self.ngc_team, + ngc_org=self.ngc_org, + ngc_ace=self.ngc_ace, + ) + logger.info("Setting NGC API key") + ngc_config_process = subprocess.Popen(ngc_config_cmd, shell=True) + 
ngc_config_process.wait() + + if ngc_config_process.returncode == 0: + logger.info("NGC API key set successfully") + else: + logger.error("Failed to set NGC API key") + # print full output + logger.error(ngc_config_process.stdout.read()) + logger.error(ngc_config_process.stderr.read()) + raise Exception("Failed to set NGC API key") + + logger.info("Creating NGC Job") + subprocess.run( + cmd, + shell=True, + check=True, + ) diff --git a/src/autotrain/project.py b/src/autotrain/project.py index 3c24bd1dce..7668108b19 100644 --- a/src/autotrain/project.py +++ b/src/autotrain/project.py @@ -61,6 +61,10 @@ def __post_init__(self): "T4 Small": "spaces-t4s", "CPU Upgrade": "spaces-cpu", "CPU (Free)": "spaces-cpuf", + "DGX 1xA100": "dgx-a100", + "DGX 2xA100": "dgx-2a100", + "DGX 4xA100": "dgx-4a100", + "DGX 8xA100": "dgx-8a100", # "Local": "local", # "AutoTrain": "autotrain", } diff --git a/src/autotrain/trainers/clm/__main__.py b/src/autotrain/trainers/clm/__main__.py index d88e7a91bb..850c5ce6fa 100644 --- a/src/autotrain/trainers/clm/__main__.py +++ b/src/autotrain/trainers/clm/__main__.py @@ -200,8 +200,12 @@ def train(config): trust_remote_code=True, use_flash_attention_2=config.use_flash_attention_2, ) + else: + model_ref = None model.resize_token_embeddings(len(tokenizer)) + if model_ref is not None: + model_ref.resize_token_embeddings(len(tokenizer)) if config.use_peft: if config.use_int8 or config.use_int4: diff --git a/src/autotrain/trainers/clm/utils.py b/src/autotrain/trainers/clm/utils.py index f2abae074a..8c3e6eefe2 100644 --- a/src/autotrain/trainers/clm/utils.py +++ b/src/autotrain/trainers/clm/utils.py @@ -36,6 +36,7 @@ # Usage ```python + from transformers import AutoModelForCausalLM, AutoTokenizer model_path = "PATH_TO_THIS_REPO" @@ -58,7 +59,7 @@ # Model response: "Hello! How can I assist you today?" 
print(response) - +``` """ diff --git a/templates/index.html b/templates/index.html index 34f766e0a9..de46bb26e5 100644 --- a/templates/index.html +++ b/templates/index.html @@ -130,10 +130,14 @@ - + {% if enable_ngc == 1 %} + + + + + + + {% endif %}