Skip to content

Commit

Permalink
update client
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Nov 8, 2024
1 parent 5e9195c commit 8045435
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 27 deletions.
31 changes: 27 additions & 4 deletions src/autotrain/app/api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
from autotrain.trainers.image_regression.params import ImageRegressionParams
from autotrain.trainers.object_detection.params import ObjectDetectionParams
from autotrain.trainers.sent_transformers.params import SentenceTransformersParams
from autotrain.trainers.seq2seq.params import Seq2SeqParams
from autotrain.trainers.tabular.params import TabularParams
Expand Down Expand Up @@ -112,6 +113,7 @@ def create_api_base_model(base_class, class_name):
# API-facing parameter models for the extractive question answering and
# object detection trainers. Each is produced by create_api_base_model
# (defined earlier in this file) from the trainer's params class and the
# name to give the generated class; what the helper changes relative to the
# base class is not visible in this hunk.
ExtractiveQuestionAnsweringParamsAPI = create_api_base_model(
    ExtractiveQuestionAnsweringParams, "ExtractiveQuestionAnsweringParamsAPI"
)
ObjectDetectionParamsAPI = create_api_base_model(ObjectDetectionParams, "ObjectDetectionParamsAPI")


class LLMSFTColumnMapping(BaseModel):
Expand Down Expand Up @@ -223,6 +225,11 @@ class ExtractiveQuestionAnsweringColumnMapping(BaseModel):
answer_column: str


class ObjectDetectionColumnMapping(BaseModel):
    # Column mapping used when the project task is "image-object-detection".
    # Both fields are declared without defaults, so both must be supplied.
    # NOTE(review): the values are presumably dataset column names (the
    # client defaults are "image" and "objects") -- confirm against the trainer.
    image_column: str
    objects_column: str


class APICreateProjectModel(BaseModel):
"""
APICreateProjectModel is a Pydantic model that defines the schema for creating a project.
Expand Down Expand Up @@ -274,6 +281,7 @@ class APICreateProjectModel(BaseModel):
"vlm:captioning",
"vlm:vqa",
"extractive-question-answering",
"image-object-detection",
]
base_model: str
hardware: Literal[
Expand Down Expand Up @@ -311,6 +319,7 @@ class APICreateProjectModel(BaseModel):
ImageRegressionParamsAPI,
VLMTrainingParamsAPI,
ExtractiveQuestionAnsweringParamsAPI,
ObjectDetectionParamsAPI,
]
username: str
column_mapping: Optional[
Expand All @@ -336,6 +345,7 @@ class APICreateProjectModel(BaseModel):
ImageRegressionColumnMapping,
VLMColumnMapping,
ExtractiveQuestionAnsweringColumnMapping,
ObjectDetectionColumnMapping,
]
] = None
hub_dataset: str
Expand Down Expand Up @@ -528,6 +538,14 @@ def validate_column_mapping(cls, values):
if not values.get("column_mapping").get("answer_column"):
raise ValueError("answer_column is required for extractive-question-answering")
values["column_mapping"] = ExtractiveQuestionAnsweringColumnMapping(**values["column_mapping"])
elif values.get("task") == "image-object-detection":
if not values.get("column_mapping"):
raise ValueError("column_mapping is required for image-object-detection")
if not values.get("column_mapping").get("image_column"):
raise ValueError("image_column is required for image-object-detection")
if not values.get("column_mapping").get("objects_column"):
raise ValueError("objects_column is required for image-object-detection")
values["column_mapping"] = ObjectDetectionColumnMapping(**values["column_mapping"])
return values

@model_validator(mode="before")
Expand Down Expand Up @@ -567,6 +585,8 @@ def validate_params(cls, values):
values["params"] = VLMTrainingParamsAPI(**values["params"])
elif values.get("task") == "extractive-question-answering":
values["params"] = ExtractiveQuestionAnsweringParamsAPI(**values["params"])
elif values.get("task") == "image-object-detection":
values["params"] = ObjectDetectionParamsAPI(**values["params"])
return values


Expand Down Expand Up @@ -745,7 +765,7 @@ async def api_logs(job: JobIDModel, token: bool = Depends(api_auth)):
"""
job_id = job.jid
jwt_url = f"{constants.ENDPOINT}/api/spaces/{job_id}/jwt"
response = get_session().get(jwt_url, headers=build_hf_headers())
response = get_session().get(jwt_url, headers=build_hf_headers(token=token))
hf_raise_for_status(response)
jwt_token = response.json()["token"] # works for 24h (see "exp" field)

Expand All @@ -754,7 +774,9 @@ async def api_logs(job: JobIDModel, token: bool = Depends(api_auth)):

_logs = []
try:
with get_session().get(logs_url, headers=build_hf_headers(token=jwt_token), stream=True) as response:
with get_session().get(
logs_url, headers=build_hf_headers(token=jwt_token), stream=True, timeout=3
) as response:
hf_raise_for_status(response)
for line in response.iter_lines():
if not line.startswith(b"data: "):
Expand All @@ -766,9 +788,10 @@ async def api_logs(job: JobIDModel, token: bool = Depends(api_auth)):
continue # ignore (for example, empty lines or `b': keep-alive'`)
_logs.append((event["timestamp"], event["data"]))

# convert logs to a string
_logs = "\n".join([f"{timestamp}: {data}" for timestamp, data in _logs])

return {"logs": _logs, "success": True, "message": "Logs fetched successfully"}
except Exception as e:
if "Read timed out" in str(e):
_logs = "\n".join([f"{timestamp}: {data}" for timestamp, data in _logs])
return {"logs": _logs, "success": True, "message": "Logs fetched successfully"}
return {"logs": str(e), "success": False, "message": "Failed to fetch logs"}
76 changes: 53 additions & 23 deletions src/autotrain/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
"max_completion_length": 128,
"distributed_backend": "ddp",
"scheduler": "linear",
"merge_adapter": True,
"trainer": "sft",
}

PARAMS["text-classification"] = {
Expand Down Expand Up @@ -121,30 +123,58 @@
}

# Default column mappings, keyed by task name. Each inner dict is keyed by the
# field names used by the API's column-mapping models (for example
# "image_column" / "objects_column" for object detection). The values are
# presumably the default dataset column names -- confirm against the code that
# consumes this table.
#
# Insertion order matters: VALID_TASKS below is derived from it.
DEFAULT_COLUMN_MAPPING = {}
DEFAULT_COLUMN_MAPPING["llm:sft"] = {"text_column": "text"}
DEFAULT_COLUMN_MAPPING["llm:generic"] = {"text_column": "text"}
DEFAULT_COLUMN_MAPPING["llm:default"] = {"text_column": "text"}
DEFAULT_COLUMN_MAPPING["llm:dpo"] = {
    "prompt_column": "prompt",
    "text_column": "chosen",
    "rejected_text_column": "rejected",
}
DEFAULT_COLUMN_MAPPING["llm:orpo"] = {
    "prompt_column": "prompt",
    "text_column": "chosen",
    "rejected_text_column": "rejected",
}
DEFAULT_COLUMN_MAPPING["llm:reward"] = {"text_column": "chosen", "rejected_text_column": "rejected"}
DEFAULT_COLUMN_MAPPING["vlm:captioning"] = {"image_column": "image", "text_column": "caption"}
DEFAULT_COLUMN_MAPPING["vlm:vqa"] = {
    "image_column": "image",
    "prompt_text_column": "question",
    "text_column": "answer",
}
# Uses the same "*_column" key names as every other sentence-transformers
# entry below, rather than the bare "sentence1" / "sentence2" keys.
DEFAULT_COLUMN_MAPPING["st:pair"] = {"sentence1_column": "anchor", "sentence2_column": "positive"}
DEFAULT_COLUMN_MAPPING["st:pair_class"] = {
    "sentence1_column": "premise",
    "sentence2_column": "hypothesis",
    "target_column": "label",
}
DEFAULT_COLUMN_MAPPING["st:pair_score"] = {
    "sentence1_column": "sentence1",
    "sentence2_column": "sentence2",
    "target_column": "score",
}
DEFAULT_COLUMN_MAPPING["st:triplet"] = {
    "sentence1_column": "anchor",
    "sentence2_column": "positive",
    "sentence3_column": "negative",
}
DEFAULT_COLUMN_MAPPING["st:qa"] = {"sentence1_column": "query", "sentence2_column": "answer"}
DEFAULT_COLUMN_MAPPING["text-classification"] = {"text_column": "text", "target_column": "target"}
DEFAULT_COLUMN_MAPPING["seq2seq"] = {"text_column": "text", "target_column": "target"}
DEFAULT_COLUMN_MAPPING["text-regression"] = {"text_column": "text", "target_column": "target"}
DEFAULT_COLUMN_MAPPING["token-classification"] = {"text_column": "tokens", "target_column": "tags"}
DEFAULT_COLUMN_MAPPING["dreambooth"] = {"default": "default"}
DEFAULT_COLUMN_MAPPING["image-classification"] = {"image_column": "image", "target_column": "label"}
DEFAULT_COLUMN_MAPPING["image-regression"] = {"image_column": "image", "target_column": "target"}
DEFAULT_COLUMN_MAPPING["image-object-detection"] = {"image_column": "image", "objects_column": "objects"}
# Tabular tasks take a list of target columns; the key is "target_columns"
# (single underscore) for both classification and regression.
DEFAULT_COLUMN_MAPPING["tabular:classification"] = {"id_column": "id", "target_columns": ["target"]}
DEFAULT_COLUMN_MAPPING["tabular:regression"] = {"id_column": "id", "target_columns": ["target"]}
DEFAULT_COLUMN_MAPPING["extractive-qa"] = {
    "text_column": "context",
    "question_column": "question",
    "answer_column": "answers",
}

# Task names accepted by the client, in the order they were registered above.
VALID_TASKS = list(DEFAULT_COLUMN_MAPPING)

Expand Down

0 comments on commit 8045435

Please sign in to comment.