From efdaf855a65bcf658213d2984262854720ab4c4f Mon Sep 17 00:00:00 2001
From: Abhishek Thakur
Date: Sat, 2 Dec 2023 13:24:34 +0100
Subject: [PATCH] enable local training

---
 .gitignore                 |  1 +
 src/autotrain/api.py       | 12 ++-----
 src/autotrain/app.py       | 33 +++++++++++++++++--
 src/autotrain/app_utils.py | 67 ++++++++++++++++++++++++++++++++++----
 src/autotrain/backend.py   | 40 +++++------------------
 src/autotrain/db.py        | 32 ++++++++++++++++++
 templates/index.html       |  5 ++-
 7 files changed, 139 insertions(+), 51 deletions(-)
 create mode 100644 src/autotrain/db.py

diff --git a/.gitignore b/.gitignore
index 1fc6dd0f60..e1fdb5d93d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ output/
 output2/
 logs/
 op_*/
+autotrain.db
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/src/autotrain/api.py b/src/autotrain/api.py
index b1feada77b..8e78a525c9 100644
--- a/src/autotrain/api.py
+++ b/src/autotrain/api.py
@@ -1,7 +1,6 @@
 import asyncio
 import os
 import signal
-import subprocess
 import time
 
 from contextlib import asynccontextmanager
@@ -21,8 +20,6 @@
 MODEL = os.environ.get("MODEL")
 OUTPUT_MODEL_REPO = os.environ.get("OUTPUT_MODEL_REPO")
 PID = None
-API_PORT = os.environ.get("API_PORT", None)
-logger.info(f"API_PORT: {API_PORT}")
 
 
 class BackgroundRunner:
@@ -32,11 +29,7 @@ async def run_main(self):
             status = status.strip().lower()
             if status in ("completed", "error", "zombie"):
                 logger.info("Training process finished. Shutting down the server.")
-                time.sleep(5)
-                if API_PORT is not None:
-                    subprocess.run(f"fuser -k {API_PORT}/tcp", shell=True, check=True)
-                else:
-                    kill_process(os.getpid())
+                kill_process(os.getpid())
                 break
             time.sleep(5)
@@ -99,7 +92,6 @@ async def lifespan(app: FastAPI):
     global PID
     PID = process_pid
     asyncio.create_task(runner.run_main())
-    # background_tasks.add_task(monitor_training_process, PID)
     yield
@@ -118,7 +110,7 @@ async def root():
 
 
 @api.get("/status")
-async def status():
+async def app_status():
     return get_process_status(pid=PID)
diff --git a/src/autotrain/app.py b/src/autotrain/app.py
index b97f3cfda7..f2f308b9e6 100644
--- a/src/autotrain/app.py
+++ b/src/autotrain/app.py
@@ -3,13 +3,14 @@
 from typing import List
 
 import pandas as pd
-from fastapi import FastAPI, File, Form, Request, UploadFile
+from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
 from fastapi.responses import HTMLResponse, JSONResponse
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
 
 from autotrain import app_utils, logger
 from autotrain.dataset import AutoTrainDataset, AutoTrainDreamboothDataset, AutoTrainImageClassificationDataset
+from autotrain.db import AutoTrainDB
 from autotrain.project import AutoTrainProject
 from autotrain.trainers.clm.params import LLMTrainingParams
 from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
@@ -22,6 +23,8 @@
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 _, _, USERS = app_utils.user_validation()
 ENABLE_NGC = int(os.environ.get("ENABLE_NGC", 0))
+DB = AutoTrainDB("autotrain.db")
+AUTOTRAIN_LOCAL = int(os.environ.get("AUTOTRAIN_LOCAL", 0))
 
 
 HIDDEN_PARAMS = [
@@ -137,6 +140,7 @@ async def read_form(request: Request):
         "request": request,
         "valid_users": USERS,
         "enable_ngc": ENABLE_NGC,
+        "enable_local": AUTOTRAIN_LOCAL,
     }
     return templates.TemplateResponse("index.html", context)
@@ -202,8 +206,30 @@ async def handle_form(
     """
     This function is used to handle the form submission
     """
+    logger.info(f"hardware: {hardware}")
+    if hardware == "Local":
+        running_jobs = DB.get_running_jobs()
+ logger.info(f"Running jobs: {running_jobs}") + if running_jobs: + for _pid in running_jobs: + logger.info(f"Killing PID: {_pid}") + proc_status = app_utils.get_process_status(_pid) + proc_status = proc_status.strip().lower() + if proc_status in ("completed", "error", "zombie"): + logger.info(f"Process {_pid} is already completed. Skipping...") + try: + app_utils.kill_process_by_pid(_pid) + except Exception as e: + logger.info(f"Error while killing process: {e}") + DB.delete_job(_pid) + + running_jobs = DB.get_running_jobs() + if running_jobs: + logger.info(f"Running jobs: {running_jobs}") + raise HTTPException( + status_code=409, detail="Another job is already running. Please wait for it to finish." + ) - # if HF_TOKEN is None is None, return error if HF_TOKEN is None: return {"error": "HF_TOKEN not set"} @@ -304,4 +330,7 @@ async def handle_form( jobs_df = pd.DataFrame([params]) project = AutoTrainProject(dataset=dset, job_params=jobs_df) ids = project.create() + if hardware == "Local": + for _id in ids: + DB.add_job(_id) return {"success": "true", "space_ids": ids} diff --git a/src/autotrain/app_utils.py b/src/autotrain/app_utils.py index 995f5cfca1..e94e5e3e06 100644 --- a/src/autotrain/app_utils.py +++ b/src/autotrain/app_utils.py @@ -1,7 +1,10 @@ import json import os +import signal +import socket import subprocess +import psutil import requests from autotrain import config, logger @@ -13,6 +16,41 @@ from autotrain.trainers.text_classification.params import TextClassificationParams +def get_process_status(pid): + try: + process = psutil.Process(pid) + proc_status = process.status() + logger.info(f"Process status: {proc_status}") + return proc_status + except psutil.NoSuchProcess: + logger.info(f"No process found with PID: {pid}") + return "Completed" + + +def find_pid_by_port(port): + """Find PID by port number.""" + try: + result = subprocess.run(["lsof", "-i", f":{port}", "-t"], capture_output=True, text=True, check=True) + pids = result.stdout.strip().split("\n") + return [int(pid) for pid in pids if pid.isdigit()] + except subprocess.CalledProcessError: + return [] + + +def kill_process_by_pid(pid): + """Kill process by PID.""" + os.kill(pid, signal.SIGTERM) + + +def is_port_in_use(port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + +def kill_process_on_port(port): + os.system(f"fuser -k {port}/tcp") + + def user_authentication(token): logger.info("Authenticating user...") headers = {} @@ -64,12 +102,12 @@ def user_validation(): return user_token, valid_can_pay, who_is_training -def run_training(params, task_id): +def run_training(params, task_id, local=False): params = json.loads(params) logger.info(params) if task_id == 9: params = LLMTrainingParams(**params) - if os.environ.get("API_PORT") is None: + if not local: params.project_name = "/tmp/model" else: params.project_name = os.path.join("output", params.project_name) @@ -91,7 +129,10 @@ def run_training(params, task_id): ) elif task_id == 28: params = Seq2SeqParams(**params) - params.project_name = "/tmp/model" + if not local: + params.project_name = "/tmp/model" + else: + params.project_name = os.path.join("output", params.project_name) params.save(output_dir=params.project_name) cmd = ["accelerate", "launch", "--num_machines", "1", "--num_processes", "1"] cmd.append("--mixed_precision") @@ -110,7 +151,10 @@ def run_training(params, task_id): ) elif task_id in (1, 2): params = TextClassificationParams(**params) - params.project_name = "/tmp/model" + if 
+            params.project_name = "/tmp/model"
+        else:
+            params.project_name = os.path.join("output", params.project_name)
         params.save(output_dir=params.project_name)
         cmd = ["accelerate", "launch", "--num_machines", "1", "--num_processes", "1"]
         cmd.append("--mixed_precision")
@@ -129,7 +173,10 @@ def run_training(params, task_id):
         )
     elif task_id in (13, 14, 15, 16, 26):
         params = TabularParams(**params)
-        params.project_name = "/tmp/model"
+        if not local:
+            params.project_name = "/tmp/model"
+        else:
+            params.project_name = os.path.join("output", params.project_name)
         params.save(output_dir=params.project_name)
         cmd = [
             "python",
@@ -140,7 +187,10 @@ def run_training(params, task_id):
         ]
     elif task_id == 27:
         params = GenericParams(**params)
-        params.project_name = "/tmp/model"
+        if not local:
+            params.project_name = "/tmp/model"
+        else:
+            params.project_name = os.path.join("output", params.project_name)
         params.save(output_dir=params.project_name)
         cmd = [
             "python",
@@ -151,7 +201,10 @@ def run_training(params, task_id):
         ]
     elif task_id == 25:
         params = DreamBoothTrainingParams(**params)
-        params.project_name = "/tmp/model"
+        if not local:
+            params.project_name = "/tmp/model"
+        else:
+            params.project_name = os.path.join("output", params.project_name)
         params.save(output_dir=params.project_name)
         cmd = [
             "python",
diff --git a/src/autotrain/backend.py b/src/autotrain/backend.py
index 5e01d5d299..95cb96f01f 100644
--- a/src/autotrain/backend.py
+++ b/src/autotrain/backend.py
@@ -8,6 +8,7 @@
 from huggingface_hub import HfApi
 
 from autotrain import logger
+from autotrain.app_utils import run_training
 from autotrain.dataset import AutoTrainDataset, AutoTrainDreamboothDataset
 from autotrain.trainers.clm.params import LLMTrainingParams
 from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
@@ -391,10 +392,12 @@ def _create_space(self):
                 backend=self.backend,
             )
             ngc_runner.create()
+            return
         else:
             local_runner = LocalRunner(env_vars=env_vars)
-            local_runner.create()
-            return
+            pid = local_runner.create()
+            return pid
+
         api = HfApi(token=self.params.token)
         repo_id = f"{self.username}/autotrain-{self.params.project_name}"
         api.create_repo(
@@ -423,41 +426,16 @@ def _create_space(self):
         return repo_id
 
 
-def run_server():
-    import uvicorn
-
-    from autotrain.api import api
-
-    uvicorn.run(api, host="0.0.0.0", port=17860)
-
-
 @dataclass
 class LocalRunner:
     env_vars: dict
 
     def create(self):
         logger.info("Starting server")
-        for key, value in self.env_vars.items():
-            os.environ[key] = value
-        os.environ["API_PORT"] = "17860"
-        import threading
-
-        thread = threading.Thread(target=run_server)
-        thread.start()
-        print("Server is running in a separate thread")
-        # cmd = "autotrain api --port 17860 --host 0.0.0.0"
-        # # start the server in the background as a new process
-        # logger.info("Starting server")
-        # proc = subprocess.Popen(cmd, shell=True, env=self.env_vars)
-        # proc.wait()
-        # if proc.returncode == 0:
-        #     logger.info("Server started successfully")
-        # else:
-        #     logger.error("Failed to start server")
-        #     # print full output
-        #     logger.error(proc.stdout.read())
-        #     logger.error(proc.stderr.read())
-        #     raise Exception("Failed to start server")
+        params = self.env_vars["PARAMS"]
+        task_id = int(self.env_vars["TASK_ID"])
+        training_pid = run_training(params, task_id, local=True)
+        return training_pid
 
 
 @dataclass
diff --git a/src/autotrain/db.py b/src/autotrain/db.py
new file mode 100644
index 0000000000..cd8a4e5298
--- /dev/null
+++ b/src/autotrain/db.py
@@ -0,0 +1,32 @@
+import sqlite3
+
+
+class AutoTrainDB:
+    def __init__(self, db_path):
+        self.db_path = db_path
+        self.conn = sqlite3.connect(db_path)
+        self.c = self.conn.cursor()
+        self.create_jobs_table()
+
+    def create_jobs_table(self):
+        self.c.execute(
+            """CREATE TABLE IF NOT EXISTS jobs
+            (id INTEGER PRIMARY KEY, pid INTEGER)"""
+        )
+        self.conn.commit()
+
+    def add_job(self, pid):
+        sql = f"INSERT INTO jobs (pid) VALUES ({pid})"
+        self.c.execute(sql)
+        self.conn.commit()
+
+    def get_running_jobs(self):
+        self.c.execute("""SELECT pid FROM jobs""")
+        running_pids = self.c.fetchall()
+        running_pids = [pid[0] for pid in running_pids]
+        return running_pids
+
+    def delete_job(self, pid):
+        sql = f"DELETE FROM jobs WHERE pid={pid}"
+        self.c.execute(sql)
+        self.conn.commit()
diff --git a/templates/index.html b/templates/index.html
index f670de0440..ef706736ea 100644
--- a/templates/index.html
+++ b/templates/index.html
@@ -108,6 +108,9 @@
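
Note (not part of the patch): a minimal sketch, assuming this patch is applied, of how the new AutoTrainDB bookkeeping in src/autotrain/db.py works with the process helpers added to src/autotrain/app_utils.py. It mirrors the "Local" branch of handle_form above: PIDs recorded in autotrain.db are polled with get_process_status(), finished ones are reaped and removed, and anything still running blocks a new submission.

# Sketch only (not part of the patch): local-job cleanup, as handle_form does it.
from autotrain import app_utils
from autotrain.db import AutoTrainDB

db = AutoTrainDB("autotrain.db")

for _pid in db.get_running_jobs():
    # get_process_status() returns the psutil status string, or "Completed"
    # when no process with that PID exists anymore.
    status = app_utils.get_process_status(_pid).strip().lower()
    if status in ("completed", "error", "zombie"):
        try:
            app_utils.kill_process_by_pid(_pid)  # sends SIGTERM; may fail if already gone
        except Exception as e:
            print(f"Error while killing process: {e}")
        db.delete_job(_pid)

# Whatever is left in the table is a genuinely running job, so a new
# submission is rejected (handle_form raises HTTP 409 at this point).
if db.get_running_jobs():
    raise RuntimeError("Another job is already running. Please wait for it to finish.")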
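A second sketch (also not part of the patch): how the rewritten LocalRunner drives a local run. Per the backend.py change, create() reads PARAMS (a JSON-serialized params object) and TASK_ID from env_vars, calls run_training(..., local=True) so outputs land under output/<project_name> instead of /tmp/model, and returns the training PID, which the app records via AutoTrainDB.add_job(). The PARAMS payload below is a hypothetical, abbreviated example, not a complete training configuration.

# Sketch only: launch a local job the way _create_space() does for a local backend.
from autotrain.backend import LocalRunner
from autotrain.db import AutoTrainDB

env_vars = {
    "PARAMS": '{"model": "gpt2", "project_name": "my-local-run"}',  # hypothetical, abbreviated payload
    "TASK_ID": "9",  # task_id 9 -> LLMTrainingParams branch in run_training()
}

pid = LocalRunner(env_vars=env_vars).create()  # calls run_training(PARAMS, TASK_ID, local=True)
AutoTrainDB("autotrain.db").add_job(pid)  # same bookkeeping handle_form does after project.create()
print(f"Local training started with PID {pid}")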