Skip to content

Commit

Permalink
update docs
Browse files: browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Sep 27, 2024
1 parent 1d1ce2b commit e5a3b02
Show file tree
Hide file tree
Showing 66 changed files with 3,173 additions and 231 deletions.
115 changes: 104 additions & 11 deletions src/autotrain/app/api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@


def create_api_base_model(base_class, class_name):
"""
Creates a new Pydantic model based on a given base class and class name,
excluding specified fields.
Args:
base_class (Type): The base Pydantic model class to extend.
class_name (str): The name of the new model class to create.
Returns:
Type: A new Pydantic model class with the specified modifications.
Notes:
- The function uses type hints from the base class to define the new model's fields.
- Certain fields are excluded from the new model based on the class name.
- The function supports different sets of hidden parameters for different class names.
- The new model's configuration is set to have no protected namespaces.
"""
annotations = get_type_hints(base_class)
if class_name in ("LLMSFTTrainingParamsAPI", "LLMRewardTrainingParamsAPI"):
more_hidden_params = [
Expand Down Expand Up @@ -206,6 +223,32 @@ class ExtractiveQuestionAnsweringColumnMapping(BaseModel):


class APICreateProjectModel(BaseModel):
"""
APICreateProjectModel is a Pydantic model that defines the schema for creating a project.
Attributes:
project_name (str): The name of the project.
task (Literal): The type of task for the project. Supported tasks include various LLM tasks,
image classification, dreambooth, seq2seq, token classification, text classification,
text regression, tabular classification, tabular regression, image regression, VLM tasks,
and extractive question answering.
base_model (str): The base model to be used for the project.
hardware (Literal): The type of hardware to be used for the project. Supported hardware options
include various configurations of spaces and local.
params (Union): The training parameters for the project. The type of parameters depends on the
task selected.
username (str): The username of the person creating the project.
column_mapping (Optional[Union]): The column mapping for the project. The type of column mapping
depends on the task selected.
hub_dataset (str): The dataset to be used for the project.
train_split (str): The training split of the dataset.
valid_split (Optional[str]): The validation split of the dataset.
Methods:
validate_column_mapping(cls, values): Validates the column mapping based on the task selected.
validate_params(cls, values): Validates the training parameters based on the task selected.
"""

project_name: str
task: Literal[
"llm:sft",
Expand Down Expand Up @@ -530,6 +573,18 @@ def validate_params(cls, values):


def api_auth(request: Request):
"""
Authenticates the API request using a Bearer token.
Args:
request (Request): The incoming HTTP request object.
Returns:
str: The verified Bearer token if authentication is successful.
Raises:
HTTPException: If the token is invalid, expired, or missing.
"""
authorization = request.headers.get("Authorization")
if authorization:
schema, _, token = authorization.partition(" ")
Expand All @@ -553,9 +608,24 @@ def api_auth(request: Request):
@api_router.post("/create_project", response_class=JSONResponse)
async def api_create_project(project: APICreateProjectModel, token: bool = Depends(api_auth)):
"""
This function is used to create a new project
:param project: APICreateProjectModel
:return: JSONResponse
Asynchronously creates a new project based on the provided parameters.
Args:
project (APICreateProjectModel): The model containing the project details and parameters.
token (bool, optional): The authentication token. Defaults to Depends(api_auth).
Returns:
dict: A dictionary containing a success message, the job ID of the created project, and a success status.
Raises:
HTTPException: If there is an error during project creation.
Notes:
- The function determines the hardware type based on the project hardware attribute.
- It logs the provided parameters and column mapping.
- It sets the appropriate parameters based on the task type.
- It updates the parameters with the provided ones and creates an AppParams instance.
- The function then creates an AutoTrainProject instance and initiates the project creation process.
"""
provided_params = project.params.model_dump()
if project.hardware == "local":
Expand Down Expand Up @@ -609,18 +679,28 @@ async def api_create_project(project: APICreateProjectModel, token: bool = Depen
@api_router.get("/version", response_class=JSONResponse)
async def api_version():
    """
    Reports which version of the API is currently running.

    The value is read from the module-level ``__version__`` variable.

    Returns:
        dict: A single-entry dictionary whose "version" key holds ``__version__``.
    """
    version_payload = {"version": __version__}
    return version_payload


@api_router.get("/logs", response_class=JSONResponse)
async def api_logs(job_id: str, token: bool = Depends(api_auth)):
"""
This function is used to get the logs of a project
:param job_id: str
:return: JSONResponse
Fetch logs for a specific job.
Args:
job_id (str): The ID of the job for which logs are to be fetched.
token (bool, optional): Authentication token, defaults to the result of api_auth dependency.
Returns:
dict: A dictionary containing the logs, success status, and a message.
"""
# project = AutoTrainProject(job_id=job_id, token=token)
# logs = project.get_logs()
Expand All @@ -630,9 +710,22 @@ async def api_logs(job_id: str, token: bool = Depends(api_auth)):
@api_router.get("/stop_training", response_class=JSONResponse)
async def api_stop_training(job_id: str, token: bool = Depends(api_auth)):
"""
This function is used to stop the training of a project
:param job_id: str
:return: JSONResponse
Stops the training job with the given job ID.
This asynchronous function pauses the training job identified by the provided job ID.
It uses the Hugging Face API to pause the space associated with the job.
Args:
job_id (str): The ID of the job to stop.
token (bool, optional): The authentication token, provided by dependency injection.
Returns:
dict: A dictionary containing a message and a success flag. If the training job
was successfully stopped, the message indicates success and the success flag is True.
If there was an error, the message contains the error details and the success flag is False.
Raises:
Exception: If there is an error while attempting to stop the training job.
"""
hf_api = HfApi(token=token)
try:
Expand Down
10 changes: 10 additions & 0 deletions src/autotrain/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@

@app.get("/")
async def forward_to_ui(request: Request):
"""
Forwards the incoming request to the UI endpoint.
Args:
request (Request): The incoming HTTP request.
Returns:
RedirectResponse: A response object that redirects to the UI endpoint,
including any query parameters from the original request.
"""
query_params = request.query_params
url = "/ui/"
if query_params:
Expand Down
30 changes: 30 additions & 0 deletions src/autotrain/app/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,36 @@


class AutoTrainDB:
"""
A class to manage job records in a SQLite database.
Attributes:
-----------
db_path : str
The path to the SQLite database file.
conn : sqlite3.Connection
The SQLite database connection object.
c : sqlite3.Cursor
The SQLite database cursor object.
Methods:
--------
__init__(db_path):
Initializes the database connection and creates the jobs table if it does not exist.
create_jobs_table():
Creates the jobs table in the database if it does not exist.
add_job(pid):
Adds a new job with the given process ID (pid) to the jobs table.
get_running_jobs():
Retrieves a list of all running job process IDs (pids) from the jobs table.
delete_job(pid):
Deletes the job with the given process ID (pid) from the jobs table.
"""

def __init__(self, db_path):
self.db_path = db_path
self.conn = sqlite3.connect(db_path)
Expand Down
21 changes: 21 additions & 0 deletions src/autotrain/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,34 @@


def get_sorted_models(hub_models):
    """
    Filters and sorts a list of models based on their download count.

    Args:
        hub_models (iterable): Model objects. Each object must expose the
            attributes 'id', 'downloads', and 'private'.

    Returns:
        list: IDs of the models whose 'private' attribute is exactly False,
            ordered by download count, highest first. Models with the same
            count keep the relative order they had in the input.

    Notes:
        - A 'downloads' value of None is ranked as 0 so that a model without a
          reported download count cannot make the sort raise a TypeError.
        - Models whose 'private' attribute is anything other than False
          (including None) are left out.
    """
    public_models = [model for model in hub_models if model.private is False]
    # list.sort() is stable even with reverse=True, so ties stay in input order.
    public_models.sort(key=lambda model: model.downloads or 0, reverse=True)
    return [model.id for model in public_models]


def _fetch_text_classification_models():
"""
Fetches and sorts text classification models from the Hugging Face model hub.
This function retrieves models for the tasks "fill-mask" and "text-classification"
from the Hugging Face model hub, sorts them by the number of downloads, and combines
them into a single list. Additionally, it fetches trending models based on the number
of likes in the past 7 days, sorts them, and places them at the beginning of the list
if they are not already included.
Returns:
list: A sorted list of model identifiers from the Hugging Face model hub.
"""
hub_models1 = list(
list_models(
task="fill-mask",
Expand Down
81 changes: 66 additions & 15 deletions src/autotrain/app/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@


def attach_oauth(app: fastapi.FastAPI):
"""
Attaches OAuth functionality to a FastAPI application by adding OAuth routes and session middleware.
Args:
app (fastapi.FastAPI): The FastAPI application instance to which OAuth routes and middleware will be attached.
Notes:
- The session middleware requires a secret key to sign the cookies. A hash of the OAuth secret key is used to
make it unique to the Space and to ensure it is updated if the OAuth configuration changes.
- The session secret includes a version identifier ("-autotrain-v2") to allow for future changes in the session
cookie format. If the format changes, the version can be bumped to invalidate old cookies and prevent HTTP 500 errors.
"""
_add_oauth_routes(app)
# Session Middleware requires a secret key to sign the cookies. Let's use a hash
# of the OAuth secret key to make it unique to the Space + updated in case OAuth
Expand All @@ -38,6 +50,23 @@ def attach_oauth(app: fastapi.FastAPI):


def _add_oauth_routes(app: fastapi.FastAPI) -> None:
"""
Add OAuth routes to the FastAPI app (login, callback handler, and logout).
This function performs the following tasks:
1. Checks for required environment variables and raises a ValueError if any are missing.
2. Registers the OAuth server with the provided client ID, client secret, scopes, and OpenID provider URL.
3. Defines the following OAuth routes:
- `/login/huggingface`: Redirects to the Hugging Face OAuth page.
- `/auth`: Handles the OAuth callback and manages the OAuth state.
Args:
app (fastapi.FastAPI): The FastAPI application instance to which the OAuth routes will be added.
Raises:
ValueError: If any of the required environment variables (OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET,
OAUTH_SCOPES, OPENID_PROVIDER_URL) are not set.
"""
"""Add OAuth routes to the FastAPI app (login, callback handler and logout)."""
# Check environment variables
msg = (
Expand Down Expand Up @@ -66,6 +95,15 @@ def _add_oauth_routes(app: fastapi.FastAPI) -> None:
# Define OAuth routes
@app.get("/login/huggingface")
async def oauth_login(request: fastapi.Request):
"""
Handles the OAuth login process by redirecting to the Hugging Face OAuth page.
Args:
request (fastapi.Request): The incoming HTTP request.
Returns:
Response: A redirection response to the Hugging Face OAuth authorization page.
"""
"""Endpoint that redirects to HF OAuth page."""
redirect_uri = request.url_for("auth")
redirect_uri_as_str = str(redirect_uri)
Expand All @@ -75,6 +113,25 @@ async def oauth_login(request: fastapi.Request):

@app.get("/auth")
async def auth(request: fastapi.Request) -> RedirectResponse:
"""
Handles the OAuth callback for Hugging Face authentication.
Args:
request (fastapi.Request): The incoming request object.
Returns:
RedirectResponse: A response object that redirects the user to the appropriate page.
Raises:
MismatchingStateError: If there is a state mismatch, likely due to a corrupted cookie.
In this case, the user is redirected to the login page after clearing the relevant session keys.
Notes:
- If the state mismatch occurs, it is likely due to a bug in authlib that causes the token to grow indefinitely
if the user tries to login repeatedly. Since cookies cannot exceed 4kb, the token will be truncated at some point,
resulting in a lost state. The workaround is to delete the cookie and redirect the user to the login page again.
- See https://github.com/lepture/authlib/issues/622 for more details.
"""
"""Endpoint that handles the OAuth callback."""
try:
oauth_info = await oauth.huggingface.authorize_access_token(request) # type: ignore
Expand All @@ -99,23 +156,17 @@ async def auth(request: fastapi.Request) -> RedirectResponse:
return _redirect_to_target(request)


def _generate_redirect_uri(request: fastapi.Request) -> str:
    """
    Builds the URL of the "oauth_redirect_callback" route, carrying the page to return to.

    Args:
        request (fastapi.Request): The incoming HTTP request.

    Returns:
        str: The callback URL with a `_target_url` query parameter attached. When the
            URL's host ends with ".hf.space", "http://" is rewritten to "https://".
    """
    params = request.query_params
    if "_target_url" not in params:
        # No explicit target: send the user back to "/" with the current query string.
        target = "/?" + urllib.parse.urlencode(params)
    else:
        # An explicit `_target_url` is respected as-is.
        target = params["_target_url"]

    callback_url = request.url_for("oauth_redirect_callback").include_query_params(_target_url=target)
    callback_url_str = str(callback_url)
    if not callback_url.netloc.endswith(".hf.space"):
        return callback_url_str
    # In a Space, FastAPI builds the redirect as http but https is wanted.
    return callback_url_str.replace("http://", "https://")
def _redirect_to_target(request: fastapi.Request, default_target: str = "/") -> RedirectResponse:
    """
    Redirects the client to the URL named by the `_target_url` query parameter.

    Args:
        request (fastapi.Request): The incoming HTTP request.
        default_target (str, optional): Where to redirect when the request has no
            `_target_url` query parameter. Defaults to "/".

    Returns:
        RedirectResponse: A response that redirects the client to the chosen URL.
    """
    # NOTE(review): the destination comes straight from the query string with no
    # validation visible here - confirm that redirecting to arbitrary URLs is acceptable.
    return RedirectResponse(request.query_params.get("_target_url", default_target))
Loading

0 comments on commit e5a3b02

Please sign in to comment.