Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix dpo model_ref / ngc #368

Merged
merged 2 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ ENV DEBIAN_FRONTEND=noninteractive \

# Put the Miniconda bin directory under ${HOME} ahead of the existing PATH entries.
ENV PATH="${HOME}/miniconda3/bin:${PATH}"
ARG PATH="${HOME}/miniconda3/bin:${PATH}"
# Put /app/ngc-cli ahead of the existing PATH entries as well.
# NOTE(review): presumably this is where the NGC CLI is unpacked later in this file - confirm WORKDIR.
ENV PATH="/app/ngc-cli:${PATH}"
ARG PATH="/app/ngc-cli:${PATH}"

RUN mkdir -p /tmp/model && \
chown -R 1000:1000 /tmp/model && \
Expand All @@ -28,6 +30,7 @@ RUN apt-get update && \
git \
git-lfs \
libgl1 \
unzip \
&& rm -rf /var/lib/apt/lists/* && \
apt-get clean

Expand Down Expand Up @@ -63,6 +66,10 @@ RUN conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c
conda clean -ya && \
conda install -c "nvidia/label/cuda-12.1.0" cuda-nvcc && conda clean -ya

# install NGC CLI
# Download, unpack and delete the archive in a single RUN so the zip file
# does not remain in the resulting image layer.
RUN wget --content-disposition https://api.ngc.nvidia.com/v2/resources/nvidia/ngc-apps/ngc_cli/versions/3.34.1/files/ngccli_linux.zip -O ngccli_linux.zip && \
    unzip ngccli_linux.zip && \
    rm ngccli_linux.zip && \
    chmod u+x ngc-cli/ngc

COPY --chown=1000:1000 . /app/
RUN pip install -e . && \
python -m nltk.downloader punkt && \
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ packaging==23.1
# latest versions
tensorboard
peft==0.6.2
trl==0.7.2
trl==0.7.4
tiktoken==0.5.1
transformers==4.35.1
accelerate==0.24.0
Expand Down
138 changes: 91 additions & 47 deletions src/autotrain/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import asyncio
import json
import os
import signal
import subprocess
import time
from contextlib import asynccontextmanager

import psutil
from fastapi import FastAPI
Expand All @@ -25,13 +29,71 @@
PID = None


api = FastAPI()
logger.info(f"AUTOTRAIN_USERNAME: {AUTOTRAIN_USERNAME}")
logger.info(f"PROJECT_NAME: {PROJECT_NAME}")
logger.info(f"TASK_ID: {TASK_ID}")
logger.info(f"DATA_PATH: {DATA_PATH}")
logger.info(f"MODEL: {MODEL}")
logger.info(f"OUTPUT_MODEL_REPO: {OUTPUT_MODEL_REPO}")
class BackgroundRunner:
    """Poll the training process and stop this server once training ends."""

    def __init__(self):
        # Not read anywhere in this class; kept so the attribute stays available.
        self.value = 0

    async def run_main(self):
        """Check the training process every five seconds until it is done.

        When the process identified by the module-level ``PID`` is reported
        as completed, errored or zombie, the current server process is
        killed via ``kill_process(os.getpid())`` and the loop ends.

        The waits use ``asyncio.sleep`` instead of ``time.sleep``: a blocking
        sleep inside a coroutine would stall the event loop, so nothing else
        scheduled on it could run while this loop is polling.
        """
        while True:
            # get_process_status returns "Completed" for a missing PID and
            # psutil's own status string otherwise, so normalise before comparing.
            status = get_process_status(PID).strip().lower()
            if status in ("completed", "error", "zombie"):
                logger.info("Training process finished. Shutting down the server.")
                # Short grace period before the server process is terminated.
                await asyncio.sleep(5)
                kill_process(os.getpid())
                break
            await asyncio.sleep(5)


runner = BackgroundRunner()


def get_process_status(pid):
    """Return the psutil status string for ``pid``.

    The status is logged before it is returned. If psutil reports that no
    such process exists, that is logged and "Completed" is returned instead.
    """
    try:
        current_status = psutil.Process(pid).status()
    except psutil.NoSuchProcess:
        logger.info(f"No process found with PID: {pid}")
        return "Completed"

    logger.info(f"Process status: {current_status}")
    return current_status


def kill_process(pid):
    """Terminate the process ``pid`` together with all of its descendants.

    Children are asked to terminate first and are force-killed if they are
    still alive after three seconds. The parent is then terminated and
    waited on for up to five seconds.

    Returns:
        A message describing the outcome (the same text is logged): success,
        no process with that pid, or a timeout while waiting for the parent.
    """
    try:
        parent_process = psutil.Process(pid)
        # All descendants, not only direct children.
        children = parent_process.children(recursive=True)

        # First, terminate the child processes.
        for child in children:
            try:
                child.terminate()
            except psutil.NoSuchProcess:
                # The child exited on its own after it was listed. Ignoring it
                # here keeps the rest of the tree (and the parent) from being
                # left running because of one vanished child.
                continue

        # Wait for the child processes to terminate, and kill them if they don't.
        _, still_alive = psutil.wait_procs(children, timeout=3)
        for child in still_alive:
            try:
                child.kill()
            except psutil.NoSuchProcess:
                continue

        # Now, terminate the parent process.
        parent_process.terminate()
        parent_process.wait(timeout=5)

        logger.info(f"Process with pid {pid} and its children have been killed")
        return f"Process with pid {pid} and its children have been killed"

    except psutil.NoSuchProcess:
        logger.info(f"No process found with pid {pid}")
        return f"No process found with pid {pid}"

    except psutil.TimeoutExpired:
        logger.info(f"Process {pid} or one of its children has not terminated in time")
        return f"Process {pid} or one of its children has not terminated in time"


def monitor_training_process(pid: int):
    """Send SIGINT to this server process if the training process has ended.

    The status is normalised and compared against the same set of finished
    states that ``BackgroundRunner.run_main`` uses. psutil reports statuses
    in lower case, so the previous exact comparison with "Error" could not
    match and a zombie training process was never treated as finished.

    Args:
        pid: PID of the training process to inspect.
    """
    status = get_process_status(pid).strip().lower()
    if status in ("completed", "error", "zombie"):
        logger.info("Training process finished. Shutting down the server.")
        os.kill(os.getpid(), signal.SIGINT)


def run_training():
Expand Down Expand Up @@ -138,50 +200,32 @@ def run_training():
return process.pid


def get_process_status(pid):
    """Return the psutil status string for ``pid``.

    If psutil reports that no such process exists, a "No process found"
    message containing the PID is returned instead.
    """
    try:
        return psutil.Process(pid).status()
    except psutil.NoSuchProcess:
        return f"No process found with PID: {pid}"


def kill_process(pid):
    """Terminate the process ``pid`` together with all of its descendants.

    Children are asked to terminate first and are force-killed if they are
    still alive after three seconds. The parent is then terminated and
    waited on for up to five seconds.

    Returns:
        A message describing the outcome (the same text is logged): success,
        no process with that pid, or a timeout while waiting for the parent.
    """
    try:
        parent_process = psutil.Process(pid)
        # All descendants, not only direct children.
        children = parent_process.children(recursive=True)

        # First, terminate the child processes.
        for child in children:
            try:
                child.terminate()
            except psutil.NoSuchProcess:
                # The child exited on its own after it was listed. Ignoring it
                # here keeps the rest of the tree (and the parent) from being
                # left running because of one vanished child.
                continue

        # Wait for the child processes to terminate, and kill them if they don't.
        _, still_alive = psutil.wait_procs(children, timeout=3)
        for child in still_alive:
            try:
                child.kill()
            except psutil.NoSuchProcess:
                continue

        # Now, terminate the parent process.
        parent_process.terminate()
        parent_process.wait(timeout=5)

        logger.info(f"Process with pid {pid} and its children have been killed")
        return f"Process with pid {pid} and its children have been killed"

    except psutil.NoSuchProcess:
        logger.info(f"No process found with pid {pid}")
        return f"No process found with pid {pid}"

    except psutil.TimeoutExpired:
        logger.info(f"Process {pid} or one of its children has not terminated in time")
        return f"Process {pid} or one of its children has not terminated in time"


@api.on_event("startup")
async def startup_event():
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start training at application startup and watch it in the background.

    On startup the training process is launched, its PID is stored in the
    module-level ``PID``, and ``runner.run_main()`` is scheduled as a task.
    On shutdown that task is cancelled.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global PID
    process_pid = run_training()
    logger.info(f"Started training with PID {process_pid}")
    PID = process_pid

    # Keep a reference to the task: the event loop itself only holds a weak
    # reference, so an unreferenced task may be garbage-collected mid-run.
    monitor_task = asyncio.create_task(runner.run_main())
    try:
        yield
    finally:
        # Stop polling once the application is shutting down.
        monitor_task.cancel()


api = FastAPI(lifespan=lifespan)
logger.info(f"AUTOTRAIN_USERNAME: {AUTOTRAIN_USERNAME}")
logger.info(f"PROJECT_NAME: {PROJECT_NAME}")
logger.info(f"TASK_ID: {TASK_ID}")
logger.info(f"DATA_PATH: {DATA_PATH}")
logger.info(f"MODEL: {MODEL}")
logger.info(f"OUTPUT_MODEL_REPO: {OUTPUT_MODEL_REPO}")


# @api.on_event("startup")
# async def startup_event():
# process_pid = run_training()
# logger.info(f"Started training with PID {process_pid}")
# global PID
# PID = process_pid


@api.get("/")
Expand Down
7 changes: 6 additions & 1 deletion src/autotrain/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

HF_TOKEN = os.environ.get("HF_TOKEN", None)
_, _, USERS = app_utils.user_validation()
ENABLE_NGC = int(os.environ.get("ENABLE_NGC", 0))


HIDDEN_PARAMS = [
Expand Down Expand Up @@ -132,7 +133,11 @@ async def read_form(request: Request):
"""
if HF_TOKEN is None:
return templates.TemplateResponse("error.html", {"request": request})
context = {"request": request, "valid_users": USERS}
context = {
"request": request,
"valid_users": USERS,
"enable_ngc": ENABLE_NGC,
}
return templates.TemplateResponse("index.html", context)


Expand Down
34 changes: 32 additions & 2 deletions src/autotrain/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,13 @@ class NGCRunner:
def __post_init__(self):
self.ngc_ace = os.environ.get("NGC_ACE")
self.ngc_org = os.environ.get("NGC_ORG")
self.ngc_api_key = os.environ.get("NGC_CLI_API_KEY")
self.ngc_team = os.environ.get("NGC_TEAM")
self.instance_map = {
"dgx-a100": "dgxa100.80g.1.norm",
"dgx-2a100": "dgxa100.80g.2.norm",
"dgx-4a100": "dgxa100.80g.4.norm",
"dgx-8a100": "dgxa100.80g.8.norm",
}
logger.info("Creating NGC Job")
logger.info(f"NGC_ACE: {self.ngc_ace}")
Expand All @@ -455,5 +460,30 @@ def create(self):
for k, v in self.env_vars.items():
cmd += f" --env-var {k}:{v}"

# run using subprocess, wait for completion
subprocess.run(cmd, shell=True, check=True)
# Point the NGC CLI at the configured team / org / ACE before creating the job.
# The API key is deliberately not placed on the command line.
# NOTE(review): presumably the CLI picks it up from NGC_CLI_API_KEY - confirm.
ngc_config_cmd = "ngc config set"
ngc_config_cmd += " --team {ngc_team} --org {ngc_org} --ace {ngc_ace}"
ngc_config_cmd = ngc_config_cmd.format(
    # ngc_api_key=self.ngc_api_key,
    ngc_team=self.ngc_team,
    ngc_org=self.ngc_org,
    ngc_ace=self.ngc_ace,
)
logger.info("Setting NGC API key")
# Capture stdout/stderr so they can be logged on failure. Without pipes the
# process object's stdout/stderr are None, and reading them in the error
# branch raised AttributeError instead of reporting the real problem.
ngc_config_process = subprocess.run(
    ngc_config_cmd,
    shell=True,
    capture_output=True,
    text=True,
)

if ngc_config_process.returncode == 0:
    logger.info("NGC API key set successfully")
else:
    logger.error("Failed to set NGC API key")
    # print full output
    logger.error(ngc_config_process.stdout)
    logger.error(ngc_config_process.stderr)
    raise Exception("Failed to set NGC API key")

logger.info("Creating NGC Job")
subprocess.run(
cmd,
shell=True,
check=True,
)
4 changes: 4 additions & 0 deletions src/autotrain/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def __post_init__(self):
"T4 Small": "spaces-t4s",
"CPU Upgrade": "spaces-cpu",
"CPU (Free)": "spaces-cpuf",
"DGX 1xA100": "dgx-a100",
"DGX 2xA100": "dgx-2a100",
"DGX 4xA100": "dgx-4a100",
"DGX 8xA100": "dgx-8a100",
# "Local": "local",
# "AutoTrain": "autotrain",
}
Expand Down
4 changes: 4 additions & 0 deletions src/autotrain/trainers/clm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,12 @@ def train(config):
trust_remote_code=True,
use_flash_attention_2=config.use_flash_attention_2,
)
else:
model_ref = None

model.resize_token_embeddings(len(tokenizer))
if model_ref is not None:
model_ref.resize_token_embeddings(len(tokenizer))

if config.use_peft:
if config.use_int8 or config.use_int4:
Expand Down
3 changes: 2 additions & 1 deletion src/autotrain/trainers/clm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
# Usage

```python

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "PATH_TO_THIS_REPO"
Expand All @@ -58,7 +59,7 @@

# Model response: "Hello! How can I assist you today?"
print(response)

```

"""

Expand Down
12 changes: 8 additions & 4 deletions templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,14 @@
<option value="CPU Upgrade">CPU Upgrade</option>
<option value="CPU (Free)">CPU (Free)</option>
</optgroup>
<!-- <optgroup label="DGX Cloud">
<option value="A100 Large">4xA100 DGX</option>
<option value="A10G Large">2xA100 DGX</option>
</optgroup> -->
{% if enable_ngc == 1 %}
<optgroup label="DGX Cloud">
<option value="DGX 1xA100">1xA100 DGX</option>
<option value="DGX 2xA100">2xA100 DGX</option>
<option value="DGX 4xA100">4xA100 DGX</option>
<option value="DGX 8xA100">8xA100 DGX</option>
</optgroup>
{% endif %}
</select>
</div>
</div>
Expand Down
Loading