Skip to content

Commit

Permalink
ngc backend
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Nov 7, 2023
1 parent adbde1d commit bf5d928
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 6 deletions.
2 changes: 0 additions & 2 deletions Dockerfile.api

This file was deleted.

5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ docker:
docker tag autotrain-advanced:latest huggingface/autotrain-advanced:latest
docker push huggingface/autotrain-advanced:latest

ngc:
docker build -t autotrain-advanced:latest .
docker tag autotrain-advanced:latest nvcr.io/ycymhzotssoi/autotrain-advanced:latest
docker push nvcr.io/ycymhzotssoi/autotrain-advanced:latest

pip:
rm -rf build/
rm -rf dist/
Expand Down
5 changes: 4 additions & 1 deletion src/autotrain/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def run_training():
params = json.loads(PARAMS)
logger.info(params)
if TASK_ID == 9:
params = LLMTrainingParams.parse_raw(params)
try:
params = LLMTrainingParams.parse_raw(params)
except Exception:
params = LLMTrainingParams.parse_obj(params)
params.project_name = "/tmp/model"
params.save(output_dir=params.project_name)
cmd = ["accelerate", "launch", "--num_machines", "1", "--num_processes", "1"]
Expand Down
65 changes: 65 additions & 0 deletions src/autotrain/backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import json
import os
import subprocess
from dataclasses import dataclass
from typing import Union

Expand Down Expand Up @@ -262,6 +263,7 @@ def __post_init__(self):
"t4s": "t4-small",
"cpu": "cpu-upgrade",
"cpuf": "cpu-basic",
"dgx-a100": "dgxa100.80g.1.norm",
}
if not isinstance(self.params, GenericParams):
if self.params.repo_id is not None:
Expand Down Expand Up @@ -361,6 +363,30 @@ def _add_secrets(self, api, repo_id):
api.add_space_secret(repo_id=repo_id, key="OUTPUT_MODEL_REPO", value=self.params.repo_id)

def _create_space(self):
if self.backend.startswith("dgx-"):
env_vars = {
"HF_TOKEN": self.params.token,
"AUTOTRAIN_USERNAME": self.username,
"PROJECT_NAME": self.params.project_name,
"TASK_ID": str(self.task_id),
"PARAMS": json.dumps(self.params.json()),
}
if isinstance(self.params, DreamBoothTrainingParams):
env_vars["DATA_PATH"] = self.params.image_path
else:
env_vars["DATA_PATH"] = self.params.data_path

if not isinstance(self.params, GenericParams):
env_vars["MODEL"] = self.params.model
env_vars["OUTPUT_MODEL_REPO"] = self.params.repo_id

ngc_runner = NGCRunner(
job_name=self.params.repo_id.replace("/", "-"),
env_vars=env_vars,
backend=self.backend,
)
ngc_runner.create()
return
api = HfApi(token=self.params.token)
repo_id = f"{self.username}/autotrain-{self.params.project_name}"
api.create_repo(
Expand All @@ -387,3 +413,42 @@ def _create_space(self):
repo_type="space",
)
return repo_id


@dataclass
class NGCRunner:
    """Submit a training job to NGC by invoking the ``ngc`` command-line tool.

    The NGC ACE and org are read from the ``NGC_ACE`` and ``NGC_ORG``
    environment variables when the instance is created; they are ``None`` if
    the variables are unset.
    """

    job_name: str  # passed to ``--name``
    env_vars: dict  # each item is forwarded as ``--env-var KEY:VALUE``
    backend: str  # key into ``instance_map`` (e.g. "dgx-a100")

    def __post_init__(self):
        """Read NGC settings from the environment and log the job details."""
        self.ngc_ace = os.environ.get("NGC_ACE")
        self.ngc_org = os.environ.get("NGC_ORG")
        # Maps the backend name to the value passed to ``--instance``.
        self.instance_map = {
            "dgx-a100": "dgxa100.80g.1.norm",
        }
        logger.info("Creating NGC Job")
        logger.info(f"NGC_ACE: {self.ngc_ace}")
        logger.info(f"NGC_ORG: {self.ngc_org}")
        logger.info(f"job_name: {self.job_name}")
        logger.info(f"backend: {self.backend}")

    def _build_command(self):
        """Return the ``ngc base-command job run`` invocation as an argv list.

        Each argument is its own list element so that values containing
        spaces, quotes or shell metacharacters (for example the JSON in
        ``PARAMS``) reach ``ngc`` intact instead of being re-parsed by a shell.

        Raises:
            KeyError: if ``self.backend`` is not a key of ``instance_map``.
        """
        argv = [
            "ngc", "base-command", "job", "run",
            "--name", str(self.job_name),
            "--priority", "NORMAL",
            "--order", "50",
            "--preempt", "RUNONCE",
            "--min-timeslice", "0s",
            "--total-runtime", "3600s",
            "--ace", str(self.ngc_ace),
            "--org", str(self.ngc_org),
            "--instance", self.instance_map[self.backend],
            "--commandline",
            "set -x; conda run --no-capture-output -p /app/env autotrain api --port 7860 --host 0.0.0.0",
            "-p", "7860",
            "--result", "/results",
            "--image", f"{self.ngc_org}/autotrain-advanced:latest",
        ]

        for k, v in self.env_vars.items():
            argv += ["--env-var", f"{k}:{v}"]

        return argv

    def create(self):
        """Run the ``ngc`` command and wait for it to complete.

        Raises:
            KeyError: if ``self.backend`` is not a key of ``instance_map``.
            subprocess.CalledProcessError: if ``ngc`` exits with a non-zero
                status.
        """
        # No shell: the argv list is handed to the process directly, so
        # env-var values cannot be split or interpreted by a shell.
        subprocess.run(self._build_command(), check=True)
6 changes: 3 additions & 3 deletions src/autotrain/cli/run_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,8 @@ def __init__(self, args):
if self.args.backend.startswith("spaces") or self.args.backend.startswith("ep-"):
if not self.args.push_to_hub:
raise ValueError("Push to hub must be specified for spaces backend")
if self.args.repo_id is None:
raise ValueError("Repo id must be specified for spaces backend")
if self.args.username is None and self.args.repo_id is None:
raise ValueError("Repo id or username must be specified for spaces backend")
if self.args.token is None:
raise ValueError("Token must be specified for spaces backend")

Expand Down Expand Up @@ -534,7 +534,7 @@ def run(self):
)

# space training
if self.args.backend.startswith("spaces"):
if self.args.backend.startswith("spaces") or self.args.backend.startswith("dgx"):
logger.info("Creating space...")
sr = SpaceRunner(
params=params,
Expand Down

0 comments on commit bf5d928

Please sign in to comment.