Skip to content

Commit

Permalink
fix dpo model_ref / ngc (#368)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur authored Nov 30, 2023
1 parent a37bc1e commit 4cc1ea4
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 56 deletions.
7 changes: 7 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ ENV DEBIAN_FRONTEND=noninteractive \

ENV PATH="${HOME}/miniconda3/bin:${PATH}"
ARG PATH="${HOME}/miniconda3/bin:${PATH}"
ENV PATH="/app/ngc-cli:${PATH}"
ARG PATH="/app/ngc-cli:${PATH}"

RUN mkdir -p /tmp/model && \
chown -R 1000:1000 /tmp/model && \
Expand All @@ -28,6 +30,7 @@ RUN apt-get update && \
git \
git-lfs \
libgl1 \
unzip \
&& rm -rf /var/lib/apt/lists/* && \
apt-get clean

Expand Down Expand Up @@ -63,6 +66,10 @@ RUN conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c
conda clean -ya && \
conda install -c "nvidia/label/cuda-12.1.0" cuda-nvcc && conda clean -ya

# install NGC CLI
RUN wget --content-disposition https://api.ngc.nvidia.com/v2/resources/nvidia/ngc-apps/ngc_cli/versions/3.34.1/files/ngccli_linux.zip -O ngccli_linux.zip && unzip ngccli_linux.zip && \
chmod u+x ngc-cli/ngc

COPY --chown=1000:1000 . /app/
RUN pip install -e . && \
python -m nltk.downloader punkt && \
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ packaging==23.1
# latest versions
tensorboard
peft==0.6.2
trl==0.7.2
trl==0.7.4
tiktoken==0.5.1
transformers==4.35.1
accelerate==0.24.0
Expand Down
138 changes: 91 additions & 47 deletions src/autotrain/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import asyncio
import json
import os
import signal
import subprocess
import time
from contextlib import asynccontextmanager

import psutil
from fastapi import FastAPI
Expand All @@ -25,13 +29,71 @@
PID = None


api = FastAPI()
logger.info(f"AUTOTRAIN_USERNAME: {AUTOTRAIN_USERNAME}")
logger.info(f"PROJECT_NAME: {PROJECT_NAME}")
logger.info(f"TASK_ID: {TASK_ID}")
logger.info(f"DATA_PATH: {DATA_PATH}")
logger.info(f"MODEL: {MODEL}")
logger.info(f"OUTPUT_MODEL_REPO: {OUTPUT_MODEL_REPO}")
class BackgroundRunner:
def __init__(self):
self.value = 0

async def run_main(self):
while True:
status = get_process_status(PID)
status = status.strip().lower()
if status in ("completed", "error", "zombie"):
logger.info("Training process finished. Shutting down the server.")
time.sleep(5)
kill_process(os.getpid())
break
time.sleep(5)


runner = BackgroundRunner()


def get_process_status(pid):
try:
process = psutil.Process(pid)
proc_status = process.status()
logger.info(f"Process status: {proc_status}")
return proc_status
except psutil.NoSuchProcess:
logger.info(f"No process found with PID: {pid}")
return "Completed"


def kill_process(pid):
try:
parent_process = psutil.Process(pid)
children = parent_process.children(recursive=True) # This will get all the child processes recursively

# First, terminate the child processes
for child in children:
child.terminate()

# Wait for the child processes to terminate, and kill them if they don't
gone, still_alive = psutil.wait_procs(children, timeout=3)
for child in still_alive:
child.kill()

# Now, terminate the parent process
parent_process.terminate()
parent_process.wait(timeout=5)

logger.info(f"Process with pid {pid} and its children have been killed")
return f"Process with pid {pid} and its children have been killed"

except psutil.NoSuchProcess:
logger.info(f"No process found with pid {pid}")
return f"No process found with pid {pid}"

except psutil.TimeoutExpired:
logger.info(f"Process {pid} or one of its children has not terminated in time")
return f"Process {pid} or one of its children has not terminated in time"


def monitor_training_process(pid: int):
status = get_process_status(pid)
if status == "Completed" or status == "Error":
logger.info("Training process finished. Shutting down the server.")
os.kill(os.getpid(), signal.SIGINT)


def run_training():
Expand Down Expand Up @@ -138,50 +200,32 @@ def run_training():
return process.pid


def get_process_status(pid):
try:
process = psutil.Process(pid)
return process.status()
except psutil.NoSuchProcess:
return "No process found with PID: {}".format(pid)


def kill_process(pid):
try:
parent_process = psutil.Process(pid)
children = parent_process.children(recursive=True) # This will get all the child processes recursively

# First, terminate the child processes
for child in children:
child.terminate()

# Wait for the child processes to terminate, and kill them if they don't
gone, still_alive = psutil.wait_procs(children, timeout=3)
for child in still_alive:
child.kill()

# Now, terminate the parent process
parent_process.terminate()
parent_process.wait(timeout=5)

logger.info(f"Process with pid {pid} and its children have been killed")
return f"Process with pid {pid} and its children have been killed"

except psutil.NoSuchProcess:
logger.info(f"No process found with pid {pid}")
return f"No process found with pid {pid}"

except psutil.TimeoutExpired:
logger.info(f"Process {pid} or one of its children has not terminated in time")
return f"Process {pid} or one of its children has not terminated in time"


@api.on_event("startup")
async def startup_event():
@asynccontextmanager
async def lifespan(app: FastAPI):
process_pid = run_training()
logger.info(f"Started training with PID {process_pid}")
global PID
PID = process_pid
asyncio.create_task(runner.run_main())
# background_tasks.add_task(monitor_training_process, PID)
yield


api = FastAPI(lifespan=lifespan)
logger.info(f"AUTOTRAIN_USERNAME: {AUTOTRAIN_USERNAME}")
logger.info(f"PROJECT_NAME: {PROJECT_NAME}")
logger.info(f"TASK_ID: {TASK_ID}")
logger.info(f"DATA_PATH: {DATA_PATH}")
logger.info(f"MODEL: {MODEL}")
logger.info(f"OUTPUT_MODEL_REPO: {OUTPUT_MODEL_REPO}")


# @api.on_event("startup")
# async def startup_event():
# process_pid = run_training()
# logger.info(f"Started training with PID {process_pid}")
# global PID
# PID = process_pid


@api.get("/")
Expand Down
7 changes: 6 additions & 1 deletion src/autotrain/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

HF_TOKEN = os.environ.get("HF_TOKEN", None)
_, _, USERS = app_utils.user_validation()
ENABLE_NGC = int(os.environ.get("ENABLE_NGC", 0))


HIDDEN_PARAMS = [
Expand Down Expand Up @@ -132,7 +133,11 @@ async def read_form(request: Request):
"""
if HF_TOKEN is None:
return templates.TemplateResponse("error.html", {"request": request})
context = {"request": request, "valid_users": USERS}
context = {
"request": request,
"valid_users": USERS,
"enable_ngc": ENABLE_NGC,
}
return templates.TemplateResponse("index.html", context)


Expand Down
34 changes: 32 additions & 2 deletions src/autotrain/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,13 @@ class NGCRunner:
def __post_init__(self):
self.ngc_ace = os.environ.get("NGC_ACE")
self.ngc_org = os.environ.get("NGC_ORG")
self.ngc_api_key = os.environ.get("NGC_CLI_API_KEY")
self.ngc_team = os.environ.get("NGC_TEAM")
self.instance_map = {
"dgx-a100": "dgxa100.80g.1.norm",
"dgx-2a100": "dgxa100.80g.2.norm",
"dgx-4a100": "dgxa100.80g.4.norm",
"dgx-8a100": "dgxa100.80g.8.norm",
}
logger.info("Creating NGC Job")
logger.info(f"NGC_ACE: {self.ngc_ace}")
Expand All @@ -455,5 +460,30 @@ def create(self):
for k, v in self.env_vars.items():
cmd += f" --env-var {k}:{v}"

# run using subprocess, wait for completion
subprocess.run(cmd, shell=True, check=True)
ngc_config_cmd = "ngc config set"
ngc_config_cmd += " --team {ngc_team} --org {ngc_org} --ace {ngc_ace}"
ngc_config_cmd = ngc_config_cmd.format(
# ngc_api_key=self.ngc_api_key,
ngc_team=self.ngc_team,
ngc_org=self.ngc_org,
ngc_ace=self.ngc_ace,
)
logger.info("Setting NGC API key")
ngc_config_process = subprocess.Popen(ngc_config_cmd, shell=True)
ngc_config_process.wait()

if ngc_config_process.returncode == 0:
logger.info("NGC API key set successfully")
else:
logger.error("Failed to set NGC API key")
# print full output
logger.error(ngc_config_process.stdout.read())
logger.error(ngc_config_process.stderr.read())
raise Exception("Failed to set NGC API key")

logger.info("Creating NGC Job")
subprocess.run(
cmd,
shell=True,
check=True,
)
4 changes: 4 additions & 0 deletions src/autotrain/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def __post_init__(self):
"T4 Small": "spaces-t4s",
"CPU Upgrade": "spaces-cpu",
"CPU (Free)": "spaces-cpuf",
"DGX 1xA100": "dgx-a100",
"DGX 2xA100": "dgx-2a100",
"DGX 4xA100": "dgx-4a100",
"DGX 8xA100": "dgx-8a100",
# "Local": "local",
# "AutoTrain": "autotrain",
}
Expand Down
4 changes: 4 additions & 0 deletions src/autotrain/trainers/clm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,12 @@ def train(config):
trust_remote_code=True,
use_flash_attention_2=config.use_flash_attention_2,
)
else:
model_ref = None

model.resize_token_embeddings(len(tokenizer))
if model_ref is not None:
model_ref.resize_token_embeddings(len(tokenizer))

if config.use_peft:
if config.use_int8 or config.use_int4:
Expand Down
3 changes: 2 additions & 1 deletion src/autotrain/trainers/clm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
# Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_path = "PATH_TO_THIS_REPO"
Expand All @@ -58,7 +59,7 @@
# Model response: "Hello! How can I assist you today?"
print(response)
```
"""

Expand Down
12 changes: 8 additions & 4 deletions templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,14 @@
<option value="CPU Upgrade">CPU Upgrade</option>
<option value="CPU (Free)">CPU (Free)</option>
</optgroup>
<!-- <optgroup label="DGX Cloud">
<option value="A100 Large">4xA100 DGX</option>
<option value="A10G Large">2xA100 DGX</option>
</optgroup> -->
{% if enable_ngc == 1 %}
<optgroup label="DGX Cloud">
<option value="DGX 1xA100">1xA100 DGX</option>
<option value="DGX 2xA100">2xA100 DGX</option>
<option value="DGX 4xA100">4xA100 DGX</option>
<option value="DGX 8xA100">8xA100 DGX</option>
</optgroup>
{% endif %}
</select>
</div>
</div>
Expand Down

0 comments on commit 4cc1ea4

Please sign in to comment.