Skip to content

Commit

Permalink
add validators
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed May 10, 2024
1 parent 3c39cae commit de6f953
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/autotrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@


logger = Logger().get_logger()
__version__ = "0.7.86.dev0"
__version__ = "0.7.88.dev0"
232 changes: 221 additions & 11 deletions src/autotrain/app/api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import JSONResponse
from huggingface_hub import HfApi
from pydantic import BaseModel, create_model
from pydantic import BaseModel, create_model, model_validator

from autotrain import __version__, logger
from autotrain.app.params import HIDDEN_PARAMS, PARAMS, AppParams
Expand Down Expand Up @@ -86,6 +86,70 @@ def create_api_base_model(base_class, class_name):
TokenClassificationParamsAPI = create_api_base_model(TokenClassificationParams, "TokenClassificationParamsAPI")


class LLMSFTColumnMapping(BaseModel):
    """Shape of ``column_mapping`` accepted for the ``llm:sft`` task."""

    # Required; must be a string.
    text_column: str


class LLMDPOColumnMapping(BaseModel):
    """
    Shape of ``column_mapping`` accepted for the ``llm:dpo`` task.

    All three entries are required strings.
    """

    text_column: str
    rejected_text_column: str
    prompt_text_column: str


class LLMORPOColumnMapping(BaseModel):
    """
    Shape of ``column_mapping`` accepted for the ``llm:orpo`` task.

    All three entries are required strings.
    """

    text_column: str
    rejected_text_column: str
    prompt_text_column: str


class LLMGenericColumnMapping(BaseModel):
    """Shape of ``column_mapping`` accepted for the ``llm:generic`` task."""

    # Required; must be a string.
    text_column: str


class LLMRewardColumnMapping(BaseModel):
    """
    Shape of ``column_mapping`` accepted for the ``llm:reward`` task.

    Both entries are required strings.
    """

    text_column: str
    rejected_text_column: str


class DreamBoothColumnMapping(BaseModel):
    """
    Column-mapping model for the ``dreambooth`` task.

    It carries a single optional ``default`` entry, so an empty mapping is valid.
    """

    # Optional; left as None when not supplied.
    default: Optional[str] = None


class ImageClassificationColumnMapping(BaseModel):
    """
    Shape of ``column_mapping`` accepted for the ``image-classification`` task.

    Both entries are required strings.
    """

    image_column: str
    target_column: str


class Seq2SeqColumnMapping(BaseModel):
    """
    Shape of ``column_mapping`` accepted for the ``seq2seq`` task.

    Both entries are required strings.
    """

    text_column: str
    target_column: str


class TabularClassificationColumnMapping(BaseModel):
    """
    Shape of ``column_mapping`` accepted for the ``tabular-classification`` task.

    ``id_column`` is a single string; ``target_columns`` is a list of strings.
    """

    id_column: str
    # Unlike the other mappings, the target is a list of column names.
    target_columns: List[str]


class TabularRegressionColumnMapping(BaseModel):
    """
    Shape of ``column_mapping`` accepted for the ``tabular-regression`` task.

    ``id_column`` is a single string; ``target_columns`` is a list of strings.
    """

    id_column: str
    # Unlike the other mappings, the target is a list of column names.
    target_columns: List[str]


class TextClassificationColumnMapping(BaseModel):
    """
    Shape of ``column_mapping`` accepted for the ``text-classification`` task.

    Both entries are required strings.
    """

    text_column: str
    target_column: str


class TextRegressionColumnMapping(BaseModel):
    """
    Shape of ``column_mapping`` accepted for the ``text-regression`` task.

    Both entries are required strings.
    """

    text_column: str
    target_column: str


class TokenClassificationColumnMapping(BaseModel):
    """
    Shape of ``column_mapping`` accepted for the ``token-classification`` task.

    Both entries are required strings.
    """

    tokens_column: str
    tags_column: str


class APICreateProjectModel(BaseModel):
project_name: str
task: Literal[
Expand Down Expand Up @@ -134,11 +198,162 @@ class APICreateProjectModel(BaseModel):
TokenClassificationParamsAPI,
]
username: str
column_mapping: Optional[Dict[str, Union[List[str], str]]] = None
column_mapping: Optional[
Union[
LLMSFTColumnMapping,
LLMDPOColumnMapping,
LLMORPOColumnMapping,
LLMGenericColumnMapping,
LLMRewardColumnMapping,
DreamBoothColumnMapping,
ImageClassificationColumnMapping,
Seq2SeqColumnMapping,
TabularClassificationColumnMapping,
TabularRegressionColumnMapping,
TextClassificationColumnMapping,
TextRegressionColumnMapping,
TokenClassificationColumnMapping,
]
] = None
hub_dataset: str
train_split: str
valid_split: Optional[str] = None

@model_validator(mode="before")
@classmethod
def validate_column_mapping(cls, values):
    """
    Check ``column_mapping`` against the selected ``task`` and coerce it to
    the column-mapping model that belongs to that task.

    For every known task except ``dreambooth`` the mapping must be present,
    must be a dictionary, and each key required by the task must hold a
    truthy value; the mapping is then replaced by an instance of the
    task's model. For ``dreambooth`` a non-empty mapping is rejected.
    Unknown or missing tasks are left untouched.

    :param values: raw input handed to the model before field validation
    :return: ``values``, with ``column_mapping`` replaced by a model instance
        when the task is recognised
    :raises ValueError: when the mapping is missing, is not a dictionary,
        lacks a required key, or is supplied for ``dreambooth``
    """
    # A "before" validator sees the raw input, which need not be a dict.
    # Leave anything else to pydantic's regular validation.
    if not isinstance(values, dict):
        return values

    task = values.get("task")
    mapping = values.get("column_mapping")

    # dreambooth is the one task that must NOT come with a column mapping.
    if task == "dreambooth":
        if mapping:
            raise ValueError("column_mapping is not required for dreambooth")
        return values

    # task -> (model to build, keys that must be present and truthy).
    # Key order matters: it decides which missing key is reported first.
    rules = {
        "llm:sft": (LLMSFTColumnMapping, ("text_column",)),
        "llm:dpo": (
            LLMDPOColumnMapping,
            ("text_column", "rejected_text_column", "prompt_text_column"),
        ),
        "llm:orpo": (
            LLMORPOColumnMapping,
            ("text_column", "rejected_text_column", "prompt_text_column"),
        ),
        "llm:generic": (LLMGenericColumnMapping, ("text_column",)),
        "llm:reward": (LLMRewardColumnMapping, ("text_column", "rejected_text_column")),
        "seq2seq": (Seq2SeqColumnMapping, ("text_column", "target_column")),
        "image-classification": (ImageClassificationColumnMapping, ("image_column", "target_column")),
        "tabular-classification": (TabularClassificationColumnMapping, ("id_column", "target_columns")),
        "tabular-regression": (TabularRegressionColumnMapping, ("id_column", "target_columns")),
        "text-classification": (TextClassificationColumnMapping, ("text_column", "target_column")),
        "text-regression": (TextRegressionColumnMapping, ("text_column", "target_column")),
        "token-classification": (TokenClassificationColumnMapping, ("tokens_column", "tags_column")),
    }

    # isinstance guard keeps an unhashable ``task`` (e.g. a list) from
    # raising TypeError on the dict lookup; the Literal field rejects it later.
    if not isinstance(task, str) or task not in rules:
        return values

    mapping_model, required_keys = rules[task]

    if not mapping:
        raise ValueError(f"column_mapping is required for {task}")
    if not isinstance(mapping, dict):
        # Without this check ``mapping.get`` raises AttributeError, which
        # pydantic does not turn into a validation error.
        raise ValueError(f"column_mapping must be a dictionary for {task}")

    for key in required_keys:
        if not mapping.get(key):
            raise ValueError(f"{key} is required for {task}")

    values["column_mapping"] = mapping_model(**mapping)
    return values

@model_validator(mode="before")
@classmethod
def validate_params(cls, values):
    """
    Coerce ``params`` to the parameter model that belongs to ``task``.

    When the task is recognised and ``params`` is a dictionary, it is
    replaced by an instance of the task's ``*ParamsAPI`` model, so the
    parameters are validated against the selected task. In every other case
    ``values`` is returned unchanged and the regular field validation
    reports the problem.

    :param values: raw input handed to the model before field validation
    :return: ``values``, with ``params`` replaced by a model instance when
        the task is recognised and ``params`` is a dictionary
    """
    # A "before" validator sees the raw input, which need not be a dict.
    if not isinstance(values, dict):
        return values

    params_models = {
        "llm:sft": LLMSFTTrainingParamsAPI,
        "llm:dpo": LLMDPOTrainingParamsAPI,
        "llm:orpo": LLMORPOTrainingParamsAPI,
        "llm:generic": LLMGenericTrainingParamsAPI,
        "llm:reward": LLMRewardTrainingParamsAPI,
        "dreambooth": DreamBoothTrainingParamsAPI,
        "seq2seq": Seq2SeqParamsAPI,
        "image-classification": ImageClassificationParamsAPI,
        "tabular-classification": TabularClassificationParamsAPI,
        "tabular-regression": TabularRegressionParamsAPI,
        "text-classification": TextClassificationParamsAPI,
        "text-regression": TextRegressionParamsAPI,
        "token-classification": TokenClassificationParamsAPI,
    }

    task = values.get("task")
    # isinstance guard keeps an unhashable ``task`` from raising TypeError
    # on the dict lookup.
    if not isinstance(task, str) or task not in params_models:
        return values

    raw_params = values.get("params")
    # Missing or non-dict ``params`` used to raise KeyError / TypeError
    # here, which pydantic does not turn into a validation error. Hand the
    # value to the regular field validation instead.
    if not isinstance(raw_params, dict):
        return values

    values["params"] = params_models[task](**raw_params)
    return values


api_router = APIRouter()

Expand Down Expand Up @@ -171,19 +386,14 @@ async def api_create_project(project: APICreateProjectModel, token: bool = Depen
:param project: APICreateProjectModel
:return: JSONResponse
"""
provided_params = project.params.dict()
provided_params = project.params.model_dump()
if project.hardware == "local":
hardware = "local-ui" # local-ui has wait=False
else:
hardware = project.hardware

if project.column_mapping is not None:
for key, value in project.column_mapping.items():
provided_params[key] = value

provided_params.update({"data_path": project.hub_dataset})
provided_params.update({"train_split": project.train_split})
provided_params.update({"valid_split": project.valid_split})
logger.info(provided_params)
logger.info(project.column_mapping)

task = project.task
if task.startswith("llm"):
Expand All @@ -205,7 +415,7 @@ async def api_create_project(project: APICreateProjectModel, token: bool = Depen
task=task,
data_path=project.hub_dataset,
base_model=project.base_model,
column_mapping=project.column_mapping,
column_mapping=project.column_mapping.model_dump() if project.column_mapping else None,
using_hub_dataset=True,
train_split=project.train_split,
valid_split=project.valid_split,
Expand Down

0 comments on commit de6f953

Please sign in to comment.