Skip to content

Commit

Permalink
enable local training
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Dec 2, 2023
1 parent a595cff commit efdaf85
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 51 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ output/
output2/
logs/
op_*/
autotrain.db

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
12 changes: 2 additions & 10 deletions src/autotrain/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import os
import signal
import subprocess
import time
from contextlib import asynccontextmanager

Expand All @@ -21,8 +20,6 @@
MODEL = os.environ.get("MODEL")
OUTPUT_MODEL_REPO = os.environ.get("OUTPUT_MODEL_REPO")
PID = None
API_PORT = os.environ.get("API_PORT", None)
logger.info(f"API_PORT: {API_PORT}")


class BackgroundRunner:
Expand All @@ -32,11 +29,7 @@ async def run_main(self):
status = status.strip().lower()
if status in ("completed", "error", "zombie"):
logger.info("Training process finished. Shutting down the server.")
time.sleep(5)
if API_PORT is not None:
subprocess.run(f"fuser -k {API_PORT}/tcp", shell=True, check=True)
else:
kill_process(os.getpid())
kill_process(os.getpid())
break
time.sleep(5)

Expand Down Expand Up @@ -99,7 +92,6 @@ async def lifespan(app: FastAPI):
global PID
PID = process_pid
asyncio.create_task(runner.run_main())
# background_tasks.add_task(monitor_training_process, PID)
yield


Expand All @@ -118,7 +110,7 @@ async def root():


@api.get("/status")
async def status():
async def app_status():
return get_process_status(pid=PID)


Expand Down
33 changes: 31 additions & 2 deletions src/autotrain/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from typing import List

import pandas as pd
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from autotrain import app_utils, logger
from autotrain.dataset import AutoTrainDataset, AutoTrainDreamboothDataset, AutoTrainImageClassificationDataset
from autotrain.db import AutoTrainDB
from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
Expand All @@ -22,6 +23,8 @@
HF_TOKEN = os.environ.get("HF_TOKEN", None)
_, _, USERS = app_utils.user_validation()
ENABLE_NGC = int(os.environ.get("ENABLE_NGC", 0))
DB = AutoTrainDB("autotrain.db")
AUTOTRAIN_LOCAL = int(os.environ.get("AUTOTRAIN_LOCAL", 0))


HIDDEN_PARAMS = [
Expand Down Expand Up @@ -137,6 +140,7 @@ async def read_form(request: Request):
"request": request,
"valid_users": USERS,
"enable_ngc": ENABLE_NGC,
"enable_local": AUTOTRAIN_LOCAL,
}
return templates.TemplateResponse("index.html", context)

Expand Down Expand Up @@ -202,8 +206,30 @@ async def handle_form(
"""
This function is used to handle the form submission
"""
logger.info(f"hardware: {hardware}")
if hardware == "Local":
running_jobs = DB.get_running_jobs()
logger.info(f"Running jobs: {running_jobs}")
if running_jobs:
for _pid in running_jobs:
logger.info(f"Killing PID: {_pid}")
proc_status = app_utils.get_process_status(_pid)
proc_status = proc_status.strip().lower()
if proc_status in ("completed", "error", "zombie"):
logger.info(f"Process {_pid} is already completed. Skipping...")
try:
app_utils.kill_process_by_pid(_pid)
except Exception as e:
logger.info(f"Error while killing process: {e}")
DB.delete_job(_pid)

running_jobs = DB.get_running_jobs()
if running_jobs:
logger.info(f"Running jobs: {running_jobs}")
raise HTTPException(
status_code=409, detail="Another job is already running. Please wait for it to finish."
)

# if HF_TOKEN is None, return error
if HF_TOKEN is None:
return {"error": "HF_TOKEN not set"}

Expand Down Expand Up @@ -304,4 +330,7 @@ async def handle_form(
jobs_df = pd.DataFrame([params])
project = AutoTrainProject(dataset=dset, job_params=jobs_df)
ids = project.create()
if hardware == "Local":
for _id in ids:
DB.add_job(_id)
return {"success": "true", "space_ids": ids}
67 changes: 60 additions & 7 deletions src/autotrain/app_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
import os
import signal
import socket
import subprocess

import psutil
import requests

from autotrain import config, logger
Expand All @@ -13,6 +16,41 @@
from autotrain.trainers.text_classification.params import TextClassificationParams


def get_process_status(pid):
    """Report the status of the process identified by ``pid``.

    Returns the status string reported by ``psutil`` for the process, or the
    literal ``"Completed"`` when psutil raises ``NoSuchProcess`` for that PID.
    The outcome is logged in both cases.
    """
    try:
        current_state = psutil.Process(pid).status()
    except psutil.NoSuchProcess:
        logger.info(f"No process found with PID: {pid}")
        return "Completed"
    logger.info(f"Process status: {current_state}")
    return current_state


def find_pid_by_port(port):
    """Find the PIDs of processes associated with a port.

    Runs ``lsof -i :<port> -t`` and parses every all-digit token of its
    standard output as a PID.

    Args:
        port: Port number to look up.

    Returns:
        A list of integer PIDs. The list is empty when ``lsof`` exits with a
        non-zero status or when the ``lsof`` executable cannot be found.
    """
    try:
        result = subprocess.run(["lsof", "-i", f":{port}", "-t"], capture_output=True, text=True, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError: lsof is not installed. Treat it like any other
        # failed lookup instead of crashing the caller.
        return []
    return [int(token) for token in result.stdout.split() if token.isdigit()]


def kill_process_by_pid(pid):
    """Send SIGTERM to the process identified by ``pid``.

    Any exception raised by ``os.kill`` propagates to the caller.
    """
    termination_signal = signal.SIGTERM
    os.kill(pid, termination_signal)


def is_port_in_use(port):
    """Return True when a TCP connection to ``localhost:port`` succeeds.

    A throwaway IPv4 stream socket is opened for the probe and closed again
    before returning.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        connect_result = probe.connect_ex(("localhost", port))
    return connect_result == 0


def kill_process_on_port(port):
    """Kill whatever is using TCP port ``port`` by running ``fuser -k``.

    The command is executed from an argument list without a shell, so the
    value of ``port`` can never be interpreted as shell syntax.

    Args:
        port: Port number whose users should be killed.

    Returns:
        None. The call is best-effort: the exit status of ``fuser`` is
        ignored.
    """
    try:
        subprocess.run(["fuser", "-k", f"{port}/tcp"], check=False)
    except FileNotFoundError:
        # fuser is not installed. Stay best-effort rather than raising: the
        # shell-based call this replaces also did not raise in that case.
        return


def user_authentication(token):
logger.info("Authenticating user...")
headers = {}
Expand Down Expand Up @@ -64,12 +102,12 @@ def user_validation():
return user_token, valid_can_pay, who_is_training


def run_training(params, task_id):
def run_training(params, task_id, local=False):
params = json.loads(params)
logger.info(params)
if task_id == 9:
params = LLMTrainingParams(**params)
if os.environ.get("API_PORT") is None:
if not local:
params.project_name = "/tmp/model"
else:
params.project_name = os.path.join("output", params.project_name)
Expand All @@ -91,7 +129,10 @@ def run_training(params, task_id):
)
elif task_id == 28:
params = Seq2SeqParams(**params)
params.project_name = "/tmp/model"
if not local:
params.project_name = "/tmp/model"
else:
params.project_name = os.path.join("output", params.project_name)
params.save(output_dir=params.project_name)
cmd = ["accelerate", "launch", "--num_machines", "1", "--num_processes", "1"]
cmd.append("--mixed_precision")
Expand All @@ -110,7 +151,10 @@ def run_training(params, task_id):
)
elif task_id in (1, 2):
params = TextClassificationParams(**params)
params.project_name = "/tmp/model"
if not local:
params.project_name = "/tmp/model"
else:
params.project_name = os.path.join("output", params.project_name)
params.save(output_dir=params.project_name)
cmd = ["accelerate", "launch", "--num_machines", "1", "--num_processes", "1"]
cmd.append("--mixed_precision")
Expand All @@ -129,7 +173,10 @@ def run_training(params, task_id):
)
elif task_id in (13, 14, 15, 16, 26):
params = TabularParams(**params)
params.project_name = "/tmp/model"
if not local:
params.project_name = "/tmp/model"
else:
params.project_name = os.path.join("output", params.project_name)
params.save(output_dir=params.project_name)
cmd = [
"python",
Expand All @@ -140,7 +187,10 @@ def run_training(params, task_id):
]
elif task_id == 27:
params = GenericParams(**params)
params.project_name = "/tmp/model"
if not local:
params.project_name = "/tmp/model"
else:
params.project_name = os.path.join("output", params.project_name)
params.save(output_dir=params.project_name)
cmd = [
"python",
Expand All @@ -151,7 +201,10 @@ def run_training(params, task_id):
]
elif task_id == 25:
params = DreamBoothTrainingParams(**params)
params.project_name = "/tmp/model"
if not local:
params.project_name = "/tmp/model"
else:
params.project_name = os.path.join("output", params.project_name)
params.save(output_dir=params.project_name)
cmd = [
"python",
Expand Down
40 changes: 9 additions & 31 deletions src/autotrain/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from huggingface_hub import HfApi

from autotrain import logger
from autotrain.app_utils import run_training
from autotrain.dataset import AutoTrainDataset, AutoTrainDreamboothDataset
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
Expand Down Expand Up @@ -391,10 +392,12 @@ def _create_space(self):
backend=self.backend,
)
ngc_runner.create()
return
else:
local_runner = LocalRunner(env_vars=env_vars)
local_runner.create()
return
pid = local_runner.create()
return pid

api = HfApi(token=self.params.token)
repo_id = f"{self.username}/autotrain-{self.params.project_name}"
api.create_repo(
Expand Down Expand Up @@ -423,41 +426,16 @@ def _create_space(self):
return repo_id


def run_server():
import uvicorn

from autotrain.api import api

uvicorn.run(api, host="0.0.0.0", port=17860)


@dataclass
class LocalRunner:
env_vars: dict

def create(self):
logger.info("Starting server")
for key, value in self.env_vars.items():
os.environ[key] = value
os.environ["API_PORT"] = "17860"
import threading

thread = threading.Thread(target=run_server)
thread.start()
print("Server is running in a separate thread")
# cmd = "autotrain api --port 17860 --host 0.0.0.0"
# # start the server in the background as a new process
# logger.info("Starting server")
# proc = subprocess.Popen(cmd, shell=True, env=self.env_vars)
# proc.wait()
# if proc.returncode == 0:
# logger.info("Server started successfully")
# else:
# logger.error("Failed to start server")
# # print full output
# logger.error(proc.stdout.read())
# logger.error(proc.stderr.read())
# raise Exception("Failed to start server")
params = self.env_vars["PARAMS"]
task_id = int(self.env_vars["TASK_ID"])
training_pid = run_training(params, task_id, local=True)
return training_pid


@dataclass
Expand Down
32 changes: 32 additions & 0 deletions src/autotrain/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import sqlite3


class AutoTrainDB:
    """Small SQLite-backed registry of job PIDs.

    Opens (or creates) the database at ``db_path`` and makes sure the
    ``jobs`` table exists. Every mutating method commits immediately.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, db_path):
        """Connect to the database at ``db_path`` and create the jobs table if needed."""
        # Path handed to sqlite3.connect.
        self.db_path = db_path
        # Open connection, plus a single cursor reused by every method.
        self.conn = sqlite3.connect(db_path)
        self.c = self.conn.cursor()
        self.create_jobs_table()

    def create_jobs_table(self):
        """Create the ``jobs`` table (``id``, ``pid``) when it does not exist yet."""
        self.c.execute(
            """CREATE TABLE IF NOT EXISTS jobs
            (id INTEGER PRIMARY KEY, pid INTEGER)"""
        )
        self.conn.commit()

    def add_job(self, pid):
        """Insert a row for ``pid`` into the jobs table."""
        # Bound parameter instead of string formatting: the value is never
        # parsed as SQL.
        self.c.execute("INSERT INTO jobs (pid) VALUES (?)", (pid,))
        self.conn.commit()

    def get_running_jobs(self):
        """Return the ``pid`` value of every row in the jobs table as a list."""
        self.c.execute("""SELECT pid FROM jobs""")
        return [row[0] for row in self.c.fetchall()]

    def delete_job(self, pid):
        """Delete every row whose ``pid`` column equals ``pid``."""
        # Bound parameter: a crafted value such as "1 OR 1=1" is compared as
        # data and cannot widen the WHERE clause.
        self.c.execute("DELETE FROM jobs WHERE pid = ?", (pid,))
        self.conn.commit()
5 changes: 4 additions & 1 deletion templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@
<label for="hardware" class="text-sm font-medium text-gray-700">Hardware</label>
<select id="hardware" name="hardware"
class="mt-1 block w-full border border-gray-300 px-3 py-2 bg-white rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500">
{% if enable_local == 1 %}
<option value="Local">Local</option>
{% else %}
<optgroup label="Hugging Face Spaces">
<option value="A10G Large">A10G Large</option>
<option value="A100 Large">A100 Large</option>
Expand All @@ -116,7 +119,6 @@
<option value="T4 Small">T4 Small</option>
<option value="CPU Upgrade">CPU Upgrade</option>
<option value="CPU (Free)">CPU (Free)</option>
<option value="Local">Local</option>
</optgroup>
{% if enable_ngc == 1 %}
<optgroup label="DGX Cloud">
Expand All @@ -126,6 +128,7 @@
<option value="DGX 8xA100">8xA100 DGX</option>
</optgroup>
{% endif %}
{% endif %}
</select>
</div>
<div class="form-group">
Expand Down

0 comments on commit efdaf85

Please sign in to comment.