From 8799c126fcf7eb09d9ea279d6c8377524bdc59dc Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Mon, 27 Nov 2023 20:30:22 +0100 Subject: [PATCH 1/2] stuff --- src/autotrain/api.py | 131 +++++++++++++++++++++++++++---------------- 1 file changed, 84 insertions(+), 47 deletions(-) diff --git a/src/autotrain/api.py b/src/autotrain/api.py index f35121b3bb..ed442dc5f9 100644 --- a/src/autotrain/api.py +++ b/src/autotrain/api.py @@ -1,6 +1,9 @@ +import asyncio import json import os +import signal import subprocess +from contextlib import asynccontextmanager import psutil from fastapi import FastAPI @@ -25,13 +28,65 @@ PID = None -api = FastAPI() -logger.info(f"AUTOTRAIN_USERNAME: {AUTOTRAIN_USERNAME}") -logger.info(f"PROJECT_NAME: {PROJECT_NAME}") -logger.info(f"TASK_ID: {TASK_ID}") -logger.info(f"DATA_PATH: {DATA_PATH}") -logger.info(f"MODEL: {MODEL}") -logger.info(f"OUTPUT_MODEL_REPO: {OUTPUT_MODEL_REPO}") +class BackgroundRunner: + def __init__(self): + self.value = 0 + + async def run_main(self): + while True: + await monitor_training_process(PID) + await asyncio.sleep(0.1) + + +runner = BackgroundRunner() + + +def get_process_status(pid): + try: + process = psutil.Process(pid) + return process.status() + except psutil.NoSuchProcess: + logger.info(f"No process found with PID: {pid}") + return "Completed" + + +def kill_process(pid): + try: + parent_process = psutil.Process(pid) + children = parent_process.children(recursive=True) # This will get all the child processes recursively + + # First, terminate the child processes + for child in children: + child.terminate() + + # Wait for the child processes to terminate, and kill them if they don't + gone, still_alive = psutil.wait_procs(children, timeout=3) + for child in still_alive: + child.kill() + + # Now, terminate the parent process + parent_process.terminate() + parent_process.wait(timeout=5) + + logger.info(f"Process with pid {pid} and its children have been killed") + return f"Process with pid {pid} and its children have been killed" + + except psutil.NoSuchProcess: + logger.info(f"No process found with pid {pid}") + return f"No process found with pid {pid}" + + except psutil.TimeoutExpired: + logger.info(f"Process {pid} or one of its children has not terminated in time") + return f"Process {pid} or one of its children has not terminated in time" + + +async def monitor_training_process(pid: int): + while True: + status = get_process_status(pid) + if status == "Completed" or status == "Error": + logger.info("Training process finished. Shutting down the server.") + os.kill(os.getpid(), signal.SIGINT) + break def run_training(): @@ -138,50 +193,32 @@ def run_training(): return process.pid -def get_process_status(pid): - try: - process = psutil.Process(pid) - return process.status() - except psutil.NoSuchProcess: - return "No process found with PID: {}".format(pid) - - -def kill_process(pid): - try: - parent_process = psutil.Process(pid) - children = parent_process.children(recursive=True) # This will get all the child processes recursively - - # First, terminate the child processes - for child in children: - child.terminate() - - # Wait for the child processes to terminate, and kill them if they don't - gone, still_alive = psutil.wait_procs(children, timeout=3) - for child in still_alive: - child.kill() - - # Now, terminate the parent process - parent_process.terminate() - parent_process.wait(timeout=5) - - logger.info(f"Process with pid {pid} and its children have been killed") - return f"Process with pid {pid} and its children have been killed" - - except psutil.NoSuchProcess: - logger.info(f"No process found with pid {pid}") - return f"No process found with pid {pid}" - - except psutil.TimeoutExpired: - logger.info(f"Process {pid} or one of its children has not terminated in time") - return f"Process {pid} or one of its children has not terminated in time" - - -@api.on_event("startup") -async def startup_event(): +@asynccontextmanager +async def lifespan(app: FastAPI): process_pid = run_training() logger.info(f"Started training with PID {process_pid}") global PID PID = process_pid + asyncio.create_task(runner.run_main()) + # background_tasks.add_task(monitor_training_process, PID) + yield + + +api = FastAPI(lifespan=lifespan) +logger.info(f"AUTOTRAIN_USERNAME: {AUTOTRAIN_USERNAME}") +logger.info(f"PROJECT_NAME: {PROJECT_NAME}") +logger.info(f"TASK_ID: {TASK_ID}") +logger.info(f"DATA_PATH: {DATA_PATH}") +logger.info(f"MODEL: {MODEL}") +logger.info(f"OUTPUT_MODEL_REPO: {OUTPUT_MODEL_REPO}") + + +# @api.on_event("startup") +# async def startup_event(): +# process_pid = run_training() +# logger.info(f"Started training with PID {process_pid}") +# global PID +# PID = process_pid @api.get("/") From 50346cd94133f570d00834db59a5f9602e05412b Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Thu, 30 Nov 2023 15:06:53 +0100 Subject: [PATCH 2/2] enable ngc --- Dockerfile | 7 ++++++ requirements.txt | 2 +- src/autotrain/api.py | 27 ++++++++++++-------- src/autotrain/app.py | 7 +++++- src/autotrain/backend.py | 34 ++++++++++++++++++++++++-- src/autotrain/project.py | 4 +++ src/autotrain/trainers/clm/__main__.py | 4 +++ src/autotrain/trainers/clm/utils.py | 3 ++- templates/index.html | 12 ++++++--- 9 files changed, 81 insertions(+), 19 deletions(-) diff --git a/Dockerfile b/Dockerfile index ba7bf78e59..862a466ecd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,6 +5,8 @@ ENV DEBIAN_FRONTEND=noninteractive \ ENV PATH="${HOME}/miniconda3/bin:${PATH}" ARG PATH="${HOME}/miniconda3/bin:${PATH}" +ENV PATH="/app/ngc-cli:${PATH}" +ARG PATH="/app/ngc-cli:${PATH}" RUN mkdir -p /tmp/model && \ chown -R 1000:1000 /tmp/model && \ @@ -28,6 +30,7 @@ RUN apt-get update && \ git \ git-lfs \ libgl1 \ + unzip \ && rm -rf /var/lib/apt/lists/* && \ apt-get clean @@ -63,6 +66,10 @@ RUN conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c conda clean -ya && \ conda install -c "nvidia/label/cuda-12.1.0" cuda-nvcc && conda clean -ya +# install NGC CLI +RUN wget --content-disposition https://api.ngc.nvidia.com/v2/resources/nvidia/ngc-apps/ngc_cli/versions/3.34.1/files/ngccli_linux.zip -O ngccli_linux.zip && unzip ngccli_linux.zip && \ + chmod u+x ngc-cli/ngc + COPY --chown=1000:1000 . /app/ RUN pip install -e . && \ python -m nltk.downloader punkt && \ diff --git a/requirements.txt b/requirements.txt index 394b6f949d..ae06ff4d55 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ packaging==23.1 # latest versions tensorboard peft==0.6.2 -trl==0.7.2 +trl==0.7.4 tiktoken==0.5.1 transformers==4.35.1 accelerate==0.24.0 diff --git a/src/autotrain/api.py b/src/autotrain/api.py index ed442dc5f9..30af17c7ae 100644 --- a/src/autotrain/api.py +++ b/src/autotrain/api.py @@ -3,6 +3,7 @@ import os import signal import subprocess +import time from contextlib import asynccontextmanager import psutil @@ -34,8 +35,14 @@ def __init__(self): async def run_main(self): while True: - await monitor_training_process(PID) - await asyncio.sleep(0.1) + status = get_process_status(PID) + status = status.strip().lower() + if status in ("completed", "error", "zombie"): + logger.info("Training process finished. Shutting down the server.") + time.sleep(5) + kill_process(os.getpid()) + break + time.sleep(5) runner = BackgroundRunner() @@ -44,7 +51,9 @@ async def run_main(self): def get_process_status(pid): try: process = psutil.Process(pid) - return process.status() + proc_status = process.status() + logger.info(f"Process status: {proc_status}") + return proc_status except psutil.NoSuchProcess: logger.info(f"No process found with PID: {pid}") return "Completed" @@ -80,13 +89,11 @@ def kill_process(pid): return f"Process {pid} or one of its children has not terminated in time" -async def monitor_training_process(pid: int): - while True: - status = get_process_status(pid) - if status == "Completed" or status == "Error": - logger.info("Training process finished. Shutting down the server.") - os.kill(os.getpid(), signal.SIGINT) - break +def monitor_training_process(pid: int): + status = get_process_status(pid) + if status == "Completed" or status == "Error": + logger.info("Training process finished. Shutting down the server.") + os.kill(os.getpid(), signal.SIGINT) def run_training(): diff --git a/src/autotrain/app.py b/src/autotrain/app.py index f6726ffe17..b97f3cfda7 100644 --- a/src/autotrain/app.py +++ b/src/autotrain/app.py @@ -21,6 +21,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None) _, _, USERS = app_utils.user_validation() +ENABLE_NGC = int(os.environ.get("ENABLE_NGC", 0)) HIDDEN_PARAMS = [ @@ -132,7 +133,11 @@ async def read_form(request: Request): """ if HF_TOKEN is None: return templates.TemplateResponse("error.html", {"request": request}) - context = {"request": request, "valid_users": USERS} + context = { + "request": request, + "valid_users": USERS, + "enable_ngc": ENABLE_NGC, + } return templates.TemplateResponse("index.html", context) diff --git a/src/autotrain/backend.py b/src/autotrain/backend.py index 606394ae1b..f231dd03c7 100644 --- a/src/autotrain/backend.py +++ b/src/autotrain/backend.py @@ -429,8 +429,13 @@ class NGCRunner: def __post_init__(self): self.ngc_ace = os.environ.get("NGC_ACE") self.ngc_org = os.environ.get("NGC_ORG") + self.ngc_api_key = os.environ.get("NGC_CLI_API_KEY") + self.ngc_team = os.environ.get("NGC_TEAM") self.instance_map = { "dgx-a100": "dgxa100.80g.1.norm", + "dgx-2a100": "dgxa100.80g.2.norm", + "dgx-4a100": "dgxa100.80g.4.norm", + "dgx-8a100": "dgxa100.80g.8.norm", } logger.info("Creating NGC Job") logger.info(f"NGC_ACE: {self.ngc_ace}") @@ -455,5 +460,30 @@ def create(self): for k, v in self.env_vars.items(): cmd += f" --env-var {k}:{v}" - # run using subprocess, wait for completion - subprocess.run(cmd, shell=True, check=True) + ngc_config_cmd = "ngc config set" + ngc_config_cmd += " --team {ngc_team} --org {ngc_org} --ace {ngc_ace}" + ngc_config_cmd = ngc_config_cmd.format( + # ngc_api_key=self.ngc_api_key, + ngc_team=self.ngc_team, + ngc_org=self.ngc_org, + ngc_ace=self.ngc_ace, + ) + logger.info("Setting NGC API key") + ngc_config_process = subprocess.Popen(ngc_config_cmd, shell=True) + ngc_config_process.wait() + + if ngc_config_process.returncode == 0: + logger.info("NGC API key set successfully") + else: + logger.error("Failed to set NGC API key") + # print full output + logger.error(ngc_config_process.stdout.read()) + logger.error(ngc_config_process.stderr.read()) + raise Exception("Failed to set NGC API key") + + logger.info("Creating NGC Job") + subprocess.run( + cmd, + shell=True, + check=True, + ) diff --git a/src/autotrain/project.py b/src/autotrain/project.py index 3c24bd1dce..7668108b19 100644 --- a/src/autotrain/project.py +++ b/src/autotrain/project.py @@ -61,6 +61,10 @@ def __post_init__(self): "T4 Small": "spaces-t4s", "CPU Upgrade": "spaces-cpu", "CPU (Free)": "spaces-cpuf", + "DGX 1xA100": "dgx-a100", + "DGX 2xA100": "dgx-2a100", + "DGX 4xA100": "dgx-4a100", + "DGX 8xA100": "dgx-8a100", # "Local": "local", # "AutoTrain": "autotrain", } diff --git a/src/autotrain/trainers/clm/__main__.py b/src/autotrain/trainers/clm/__main__.py index d88e7a91bb..850c5ce6fa 100644 --- a/src/autotrain/trainers/clm/__main__.py +++ b/src/autotrain/trainers/clm/__main__.py @@ -200,8 +200,12 @@ def train(config): trust_remote_code=True, use_flash_attention_2=config.use_flash_attention_2, ) + else: + model_ref = None model.resize_token_embeddings(len(tokenizer)) + if model_ref is not None: + model_ref.resize_token_embeddings(len(tokenizer)) if config.use_peft: if config.use_int8 or config.use_int4: diff --git a/src/autotrain/trainers/clm/utils.py b/src/autotrain/trainers/clm/utils.py index f2abae074a..8c3e6eefe2 100644 --- a/src/autotrain/trainers/clm/utils.py +++ b/src/autotrain/trainers/clm/utils.py @@ -36,6 +36,7 @@ # Usage ```python + from transformers import AutoModelForCausalLM, AutoTokenizer model_path = "PATH_TO_THIS_REPO" @@ -58,7 +59,7 @@ # Model response: "Hello! How can I assist you today?" print(response) - +``` """ diff --git a/templates/index.html b/templates/index.html index 34f766e0a9..de46bb26e5 100644 --- a/templates/index.html +++ b/templates/index.html @@ -130,10 +130,14 @@ - + {% if enable_ngc == 1 %} + + + + + + + {% endif %}