Skip to content

Commit

Permalink
Client (#801)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur authored Nov 7, 2024
1 parent 2a39441 commit 1761163
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 88 deletions.
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@

TESTS_REQUIRE = ["pytest"]

CLIENT_REQUIRES = ["requests", "loguru"]


EXTRAS_REQUIRE = {
"base": INSTALL_REQUIRES,
"dev": INSTALL_REQUIRES + QUALITY_REQUIRE + TESTS_REQUIRE,
"quality": INSTALL_REQUIRES + QUALITY_REQUIRE,
"docs": INSTALL_REQUIRES
Expand All @@ -45,6 +48,7 @@
"sphinx-rtd-theme==0.4.3",
"sphinx-copybutton",
],
"client": CLIENT_REQUIRES,
}

setup(
Expand Down
10 changes: 7 additions & 3 deletions src/autotrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@

import warnings

import torch._dynamo

from autotrain.logging import Logger
try:
import torch._dynamo

torch._dynamo.config.suppress_errors = True
except ImportError:
pass

from autotrain.logging import Logger

torch._dynamo.config.suppress_errors = True

warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow")
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
Expand Down
80 changes: 59 additions & 21 deletions src/autotrain/app/api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import JSONResponse
from huggingface_hub import HfApi
from huggingface_hub import HfApi, constants
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
from pydantic import BaseModel, create_model, model_validator

from autotrain import __version__, logger
Expand Down Expand Up @@ -569,6 +570,10 @@ def validate_params(cls, values):
return values


class JobIDModel(BaseModel):
    """Request body carrying a job identifier.

    The endpoints below treat `jid` as the repo_id of the Hugging Face
    Space running the job (see the `pause_space` / jwt calls).
    """

    # Job identifier; used as the Space repo_id.
    jid: str


# Router aggregating the API endpoints defined in this module.
api_router = APIRouter()


Expand Down Expand Up @@ -690,33 +695,16 @@ async def api_version():
return {"version": __version__}


@api_router.get("/logs", response_class=JSONResponse)
async def api_logs(job_id: str, token: bool = Depends(api_auth)):
    """
    Placeholder endpoint for fetching logs of a specific job.

    Log retrieval is not implemented yet; this always returns a failure
    payload so clients can detect the missing feature explicitly.

    Args:
        job_id (str): The ID of the job for which logs are to be fetched.
        token (bool, optional): Authentication token, resolved by the
            `api_auth` dependency.

    Returns:
        dict: ``{"logs": ..., "success": False, "message": ...}``.
    """
    # NOTE(review): intentionally unimplemented — kept as an explicit stub
    # (previous commented-out prototype code removed).
    return {"logs": "Not implemented yet", "success": False, "message": "Not implemented yet"}


@api_router.get("/stop_training", response_class=JSONResponse)
async def api_stop_training(job_id: str, token: bool = Depends(api_auth)):
@api_router.post("/stop_training", response_class=JSONResponse)
async def api_stop_training(job: JobIDModel, token: bool = Depends(api_auth)):
    """
    Stop the training job with the given job ID.

    Pauses the Hugging Face Space associated with the job via the Hub API;
    the job ID is used directly as the Space repo_id.

    Args:
        job (JobIDModel): Request body containing the job ID (`jid`).
        token (bool, optional): Authentication token, provided by the
            `api_auth` dependency injection.

    Returns:
        dict: ``{"message": ..., "success": bool}`` describing the outcome.
            Failures are reported in the payload rather than raised.
    """
    hf_api = HfApi(token=token)
    job_id = job.jid
    try:
        hf_api.pause_space(repo_id=job_id)
    except Exception as e:
        logger.error(f"Failed to stop training: {e}")
        return {"message": f"Failed to stop training for {job_id}: {e}", "success": False}
    return {"message": f"Training stopped for {job_id}", "success": True}


@api_router.post("/logs", response_class=JSONResponse)
async def api_logs(job: JobIDModel, token: bool = Depends(api_auth)):
    """
    Fetch logs for a given job.

    Obtains a short-lived JWT for the job's Space, then streams the Space's
    run-log endpoint (server-sent events), collecting the `data:` events
    into a single newline-joined string.

    Args:
        job (JobIDModel): The job model containing the job ID.
        token (bool, optional): Dependency injection for API authentication.
            Defaults to Depends(api_auth).

    Returns:
        JSONResponse: ``{"logs": ..., "success": bool, "message": ...}``.
            On any failure — including JWT acquisition — the exception text
            is returned in ``logs`` with ``success=False`` instead of
            propagating an unhandled 500.
    """
    job_id = job.jid
    try:
        # Fix: JWT acquisition now inside the try, so a failure here returns
        # the same graceful error payload as a failure while streaming logs.
        jwt_url = f"{constants.ENDPOINT}/api/spaces/{job_id}/jwt"
        response = get_session().get(jwt_url, headers=build_hf_headers())
        hf_raise_for_status(response)
        jwt_token = response.json()["token"]  # works for 24h (see "exp" field)

        # fetch the logs as a server-sent-event stream
        logs_url = f"https://api.hf.space/v1/{job_id}/logs/run"

        _logs = []
        with get_session().get(logs_url, headers=build_hf_headers(token=jwt_token), stream=True) as response:
            hf_raise_for_status(response)
            for line in response.iter_lines():
                # SSE payload lines start with b"data: "; skip everything else.
                if not line.startswith(b"data: "):
                    continue
                line_data = line[len(b"data: ") :]
                try:
                    event = json.loads(line_data.decode())
                except json.JSONDecodeError:
                    continue  # ignore (for example, empty lines or `b': keep-alive'`)
                _logs.append((event["timestamp"], event["data"]))

        # convert logs to a string
        _logs = "\n".join([f"{timestamp}: {data}" for timestamp, data in _logs])

        return {"logs": _logs, "success": True, "message": "Logs fetched successfully"}
    except Exception as e:
        return {"logs": str(e), "success": False, "message": "Failed to fetch logs"}
33 changes: 22 additions & 11 deletions src/autotrain/app/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ def _munge_common_params(self):
def _munge_params_sent_transformers(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
_params["log"] = "tensorboard"
if "log" not in _params:
_params["log"] = "tensorboard"
if not self.using_hub_dataset:
_params["sentence1_column"] = "autotrain_sentence1"
_params["sentence2_column"] = "autotrain_sentence2"
Expand Down Expand Up @@ -291,7 +292,8 @@ def _munge_params_llm(self):
"rejected_text" if not self.api else "rejected_text_column", "rejected_text"
)
_params["train_split"] = self.train_split
_params["log"] = "tensorboard"
if "log" not in _params:
_params["log"] = "tensorboard"

trainer = self.task.split(":")[1]
if trainer != "generic":
Expand Down Expand Up @@ -321,7 +323,8 @@ def _munge_params_vlm(self):
)
_params["train_split"] = self.train_split
_params["valid_split"] = self.valid_split
_params["log"] = "tensorboard"
if "log" not in _params:
_params["log"] = "tensorboard"

trainer = self.task.split(":")[1]
_params["trainer"] = trainer.lower()
Expand All @@ -335,7 +338,8 @@ def _munge_params_vlm(self):
def _munge_params_text_clf(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
_params["log"] = "tensorboard"
if "log" not in _params:
_params["log"] = "tensorboard"
if not self.using_hub_dataset:
_params["text_column"] = "autotrain_text"
_params["target_column"] = "autotrain_label"
Expand All @@ -350,7 +354,8 @@ def _munge_params_text_clf(self):
def _munge_params_extractive_qa(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
_params["log"] = "tensorboard"
if "log" not in _params:
_params["log"] = "tensorboard"
if not self.using_hub_dataset:
_params["text_column"] = "autotrain_text"
_params["question_column"] = "autotrain_question"
Expand All @@ -369,7 +374,8 @@ def _munge_params_extractive_qa(self):
def _munge_params_text_reg(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
_params["log"] = "tensorboard"
if "log" not in _params:
_params["log"] = "tensorboard"
if not self.using_hub_dataset:
_params["text_column"] = "autotrain_text"
_params["target_column"] = "autotrain_label"
Expand All @@ -384,7 +390,8 @@ def _munge_params_text_reg(self):
def _munge_params_token_clf(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
_params["log"] = "tensorboard"
if "log" not in _params:
_params["log"] = "tensorboard"
if not self.using_hub_dataset:
_params["tokens_column"] = "autotrain_text"
_params["tags_column"] = "autotrain_label"
Expand All @@ -400,7 +407,8 @@ def _munge_params_token_clf(self):
def _munge_params_seq2seq(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
_params["log"] = "tensorboard"
if "log" not in _params:
_params["log"] = "tensorboard"
if not self.using_hub_dataset:
_params["text_column"] = "autotrain_text"
_params["target_column"] = "autotrain_label"
Expand All @@ -416,7 +424,8 @@ def _munge_params_seq2seq(self):
def _munge_params_img_clf(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
_params["log"] = "tensorboard"
if "log" not in _params:
_params["log"] = "tensorboard"
if not self.using_hub_dataset:
_params["image_column"] = "autotrain_image"
_params["target_column"] = "autotrain_label"
Expand All @@ -432,7 +441,8 @@ def _munge_params_img_clf(self):
def _munge_params_img_reg(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
_params["log"] = "tensorboard"
if "log" not in _params:
_params["log"] = "tensorboard"
if not self.using_hub_dataset:
_params["image_column"] = "autotrain_image"
_params["target_column"] = "autotrain_label"
Expand All @@ -448,7 +458,8 @@ def _munge_params_img_reg(self):
def _munge_params_img_obj_det(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
_params["log"] = "tensorboard"
if "log" not in _params:
_params["log"] = "tensorboard"
if not self.using_hub_dataset:
_params["image_column"] = "autotrain_image"
_params["objects_column"] = "autotrain_objects"
Expand Down
2 changes: 2 additions & 0 deletions src/autotrain/backends/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def _create_readme(self):
_readme += "colorTo: indigo\n"
_readme += "sdk: docker\n"
_readme += "pinned: false\n"
_readme += "tags:\n"
_readme += "- autotrain\n"
_readme += "duplicated_from: autotrain-projects/autotrain-advanced\n"
_readme += "---\n"
_readme = io.BytesIO(_readme.encode())
Expand Down
Loading

0 comments on commit 1761163

Please sign in to comment.