diff --git a/Dockerfile b/Dockerfile
index ba7bf78e59..862a466ecd 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -5,6 +5,8 @@ ENV DEBIAN_FRONTEND=noninteractive \
ENV PATH="${HOME}/miniconda3/bin:${PATH}"
ARG PATH="${HOME}/miniconda3/bin:${PATH}"
+ENV PATH="/app/ngc-cli:${PATH}"
+ARG PATH="/app/ngc-cli:${PATH}"
RUN mkdir -p /tmp/model && \
chown -R 1000:1000 /tmp/model && \
@@ -28,6 +30,7 @@ RUN apt-get update && \
git \
git-lfs \
libgl1 \
+ unzip \
&& rm -rf /var/lib/apt/lists/* && \
apt-get clean
@@ -63,6 +66,10 @@ RUN conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c
conda clean -ya && \
conda install -c "nvidia/label/cuda-12.1.0" cuda-nvcc && conda clean -ya
+# install NGC CLI
+RUN wget --content-disposition https://api.ngc.nvidia.com/v2/resources/nvidia/ngc-apps/ngc_cli/versions/3.34.1/files/ngccli_linux.zip -O ngccli_linux.zip && unzip ngccli_linux.zip && \
+    chmod u+x ngc-cli/ngc && rm -f ngccli_linux.zip
+
COPY --chown=1000:1000 . /app/
RUN pip install -e . && \
python -m nltk.downloader punkt && \
diff --git a/requirements.txt b/requirements.txt
index 394b6f949d..ae06ff4d55 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,7 +25,7 @@ packaging==23.1
# latest versions
tensorboard
peft==0.6.2
-trl==0.7.2
+trl==0.7.4
tiktoken==0.5.1
transformers==4.35.1
accelerate==0.24.0
diff --git a/src/autotrain/api.py b/src/autotrain/api.py
index ed442dc5f9..30af17c7ae 100644
--- a/src/autotrain/api.py
+++ b/src/autotrain/api.py
@@ -3,6 +3,7 @@
import os
import signal
import subprocess
+import time
from contextlib import asynccontextmanager
import psutil
@@ -34,8 +35,14 @@ def __init__(self):
async def run_main(self):
while True:
- await monitor_training_process(PID)
- await asyncio.sleep(0.1)
+ status = get_process_status(PID)
+ status = status.strip().lower()
+ if status in ("completed", "error", "zombie"):
+ logger.info("Training process finished. Shutting down the server.")
+                await asyncio.sleep(5)
+ kill_process(os.getpid())
+ break
+            await asyncio.sleep(5)
runner = BackgroundRunner()
@@ -44,7 +51,9 @@ async def run_main(self):
def get_process_status(pid):
try:
process = psutil.Process(pid)
- return process.status()
+ proc_status = process.status()
+ logger.info(f"Process status: {proc_status}")
+ return proc_status
except psutil.NoSuchProcess:
logger.info(f"No process found with PID: {pid}")
return "Completed"
@@ -80,13 +89,11 @@ def kill_process(pid):
return f"Process {pid} or one of its children has not terminated in time"
-async def monitor_training_process(pid: int):
- while True:
- status = get_process_status(pid)
- if status == "Completed" or status == "Error":
- logger.info("Training process finished. Shutting down the server.")
- os.kill(os.getpid(), signal.SIGINT)
- break
+def monitor_training_process(pid: int):
+ status = get_process_status(pid)
+ if status == "Completed" or status == "Error":
+ logger.info("Training process finished. Shutting down the server.")
+ os.kill(os.getpid(), signal.SIGINT)
def run_training():
diff --git a/src/autotrain/app.py b/src/autotrain/app.py
index f6726ffe17..b97f3cfda7 100644
--- a/src/autotrain/app.py
+++ b/src/autotrain/app.py
@@ -21,6 +21,7 @@
HF_TOKEN = os.environ.get("HF_TOKEN", None)
_, _, USERS = app_utils.user_validation()
+ENABLE_NGC = int(os.environ.get("ENABLE_NGC", 0))
HIDDEN_PARAMS = [
@@ -132,7 +133,11 @@ async def read_form(request: Request):
"""
if HF_TOKEN is None:
return templates.TemplateResponse("error.html", {"request": request})
- context = {"request": request, "valid_users": USERS}
+ context = {
+ "request": request,
+ "valid_users": USERS,
+ "enable_ngc": ENABLE_NGC,
+ }
return templates.TemplateResponse("index.html", context)
diff --git a/src/autotrain/backend.py b/src/autotrain/backend.py
index 606394ae1b..f231dd03c7 100644
--- a/src/autotrain/backend.py
+++ b/src/autotrain/backend.py
@@ -429,8 +429,13 @@ class NGCRunner:
def __post_init__(self):
self.ngc_ace = os.environ.get("NGC_ACE")
self.ngc_org = os.environ.get("NGC_ORG")
+ self.ngc_api_key = os.environ.get("NGC_CLI_API_KEY")
+ self.ngc_team = os.environ.get("NGC_TEAM")
self.instance_map = {
"dgx-a100": "dgxa100.80g.1.norm",
+ "dgx-2a100": "dgxa100.80g.2.norm",
+ "dgx-4a100": "dgxa100.80g.4.norm",
+ "dgx-8a100": "dgxa100.80g.8.norm",
}
logger.info("Creating NGC Job")
logger.info(f"NGC_ACE: {self.ngc_ace}")
@@ -455,5 +460,30 @@ def create(self):
for k, v in self.env_vars.items():
cmd += f" --env-var {k}:{v}"
- # run using subprocess, wait for completion
- subprocess.run(cmd, shell=True, check=True)
+ ngc_config_cmd = "ngc config set"
+ ngc_config_cmd += " --team {ngc_team} --org {ngc_org} --ace {ngc_ace}"
+ ngc_config_cmd = ngc_config_cmd.format(
+ # ngc_api_key=self.ngc_api_key,
+ ngc_team=self.ngc_team,
+ ngc_org=self.ngc_org,
+ ngc_ace=self.ngc_ace,
+ )
+ logger.info("Setting NGC API key")
+        ngc_config_process = subprocess.Popen(ngc_config_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ ngc_config_process.wait()
+
+ if ngc_config_process.returncode == 0:
+ logger.info("NGC API key set successfully")
+ else:
+ logger.error("Failed to set NGC API key")
+ # print full output
+ logger.error(ngc_config_process.stdout.read())
+ logger.error(ngc_config_process.stderr.read())
+ raise Exception("Failed to set NGC API key")
+
+ logger.info("Creating NGC Job")
+ subprocess.run(
+ cmd,
+ shell=True,
+ check=True,
+ )
diff --git a/src/autotrain/project.py b/src/autotrain/project.py
index 3c24bd1dce..7668108b19 100644
--- a/src/autotrain/project.py
+++ b/src/autotrain/project.py
@@ -61,6 +61,10 @@ def __post_init__(self):
"T4 Small": "spaces-t4s",
"CPU Upgrade": "spaces-cpu",
"CPU (Free)": "spaces-cpuf",
+ "DGX 1xA100": "dgx-a100",
+ "DGX 2xA100": "dgx-2a100",
+ "DGX 4xA100": "dgx-4a100",
+ "DGX 8xA100": "dgx-8a100",
# "Local": "local",
# "AutoTrain": "autotrain",
}
diff --git a/src/autotrain/trainers/clm/__main__.py b/src/autotrain/trainers/clm/__main__.py
index d88e7a91bb..850c5ce6fa 100644
--- a/src/autotrain/trainers/clm/__main__.py
+++ b/src/autotrain/trainers/clm/__main__.py
@@ -200,8 +200,12 @@ def train(config):
trust_remote_code=True,
use_flash_attention_2=config.use_flash_attention_2,
)
+ else:
+ model_ref = None
model.resize_token_embeddings(len(tokenizer))
+ if model_ref is not None:
+ model_ref.resize_token_embeddings(len(tokenizer))
if config.use_peft:
if config.use_int8 or config.use_int4:
diff --git a/src/autotrain/trainers/clm/utils.py b/src/autotrain/trainers/clm/utils.py
index f2abae074a..8c3e6eefe2 100644
--- a/src/autotrain/trainers/clm/utils.py
+++ b/src/autotrain/trainers/clm/utils.py
@@ -36,6 +36,7 @@
# Usage
```python
+
from transformers import AutoModelForCausalLM, AutoTokenizer
model_path = "PATH_TO_THIS_REPO"
@@ -58,7 +59,7 @@
# Model response: "Hello! How can I assist you today?"
print(response)
-
+```
"""
diff --git a/templates/index.html b/templates/index.html
index 34f766e0a9..de46bb26e5 100644
--- a/templates/index.html
+++ b/templates/index.html
@@ -130,10 +130,14 @@
-
+ {% if enable_ngc == 1 %}
+
+ {% endif %}