From 80454354c204d681bc11e85c7a16a3ef903aaf0f Mon Sep 17 00:00:00 2001 From: abhishekkrthakur Date: Fri, 8 Nov 2024 10:23:12 +0100 Subject: [PATCH] update client --- src/autotrain/app/api_routes.py | 31 ++++++++++++-- src/autotrain/client.py | 76 +++++++++++++++++++++++---------- 2 files changed, 80 insertions(+), 27 deletions(-) diff --git a/src/autotrain/app/api_routes.py b/src/autotrain/app/api_routes.py index bc311d60c6..5b6031692d 100644 --- a/src/autotrain/app/api_routes.py +++ b/src/autotrain/app/api_routes.py @@ -16,6 +16,7 @@ from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams from autotrain.trainers.image_classification.params import ImageClassificationParams from autotrain.trainers.image_regression.params import ImageRegressionParams +from autotrain.trainers.object_detection.params import ObjectDetectionParams from autotrain.trainers.sent_transformers.params import SentenceTransformersParams from autotrain.trainers.seq2seq.params import Seq2SeqParams from autotrain.trainers.tabular.params import TabularParams @@ -112,6 +113,7 @@ def create_api_base_model(base_class, class_name): ExtractiveQuestionAnsweringParamsAPI = create_api_base_model( ExtractiveQuestionAnsweringParams, "ExtractiveQuestionAnsweringParamsAPI" ) +ObjectDetectionParamsAPI = create_api_base_model(ObjectDetectionParams, "ObjectDetectionParamsAPI") class LLMSFTColumnMapping(BaseModel): @@ -223,6 +225,11 @@ class ExtractiveQuestionAnsweringColumnMapping(BaseModel): answer_column: str +class ObjectDetectionColumnMapping(BaseModel): + image_column: str + objects_column: str + + class APICreateProjectModel(BaseModel): """ APICreateProjectModel is a Pydantic model that defines the schema for creating a project. @@ -274,6 +281,7 @@ class APICreateProjectModel(BaseModel): "vlm:captioning", "vlm:vqa", "extractive-question-answering", + "image-object-detection", ] base_model: str hardware: Literal[ @@ -311,6 +319,7 @@ class APICreateProjectModel(BaseModel): ImageRegressionParamsAPI, VLMTrainingParamsAPI, ExtractiveQuestionAnsweringParamsAPI, + ObjectDetectionParamsAPI, ] username: str column_mapping: Optional[ @@ -336,6 +345,7 @@ class APICreateProjectModel(BaseModel): ImageRegressionColumnMapping, VLMColumnMapping, ExtractiveQuestionAnsweringColumnMapping, + ObjectDetectionColumnMapping, ] ] = None hub_dataset: str @@ -528,6 +538,14 @@ def validate_column_mapping(cls, values): if not values.get("column_mapping").get("answer_column"): raise ValueError("answer_column is required for extractive-question-answering") values["column_mapping"] = ExtractiveQuestionAnsweringColumnMapping(**values["column_mapping"]) + elif values.get("task") == "image-object-detection": + if not values.get("column_mapping"): + raise ValueError("column_mapping is required for image-object-detection") + if not values.get("column_mapping").get("image_column"): + raise ValueError("image_column is required for image-object-detection") + if not values.get("column_mapping").get("objects_column"): + raise ValueError("objects_column is required for image-object-detection") + values["column_mapping"] = ObjectDetectionColumnMapping(**values["column_mapping"]) return values @model_validator(mode="before") @@ -567,6 +585,8 @@ def validate_params(cls, values): values["params"] = VLMTrainingParamsAPI(**values["params"]) elif values.get("task") == "extractive-question-answering": values["params"] = ExtractiveQuestionAnsweringParamsAPI(**values["params"]) + elif values.get("task") == "image-object-detection": + values["params"] = ObjectDetectionParamsAPI(**values["params"]) return values @@ -745,7 +765,7 @@ async def api_logs(job: JobIDModel, token: bool = Depends(api_auth)): """ job_id = job.jid jwt_url = f"{constants.ENDPOINT}/api/spaces/{job_id}/jwt" - response = get_session().get(jwt_url, headers=build_hf_headers()) + response = get_session().get(jwt_url, headers=build_hf_headers(token=token)) hf_raise_for_status(response) jwt_token = response.json()["token"] # works for 24h (see "exp" field) @@ -754,7 +774,9 @@ async def api_logs(job: JobIDModel, token: bool = Depends(api_auth)): _logs = [] try: - with get_session().get(logs_url, headers=build_hf_headers(token=jwt_token), stream=True) as response: + with get_session().get( + logs_url, headers=build_hf_headers(token=jwt_token), stream=True, timeout=3 + ) as response: hf_raise_for_status(response) for line in response.iter_lines(): if not line.startswith(b"data: "): @@ -766,9 +788,10 @@ async def api_logs(job: JobIDModel, token: bool = Depends(api_auth)): continue # ignore (for example, empty lines or `b': keep-alive'`) _logs.append((event["timestamp"], event["data"])) - # convert logs to a string _logs = "\n".join([f"{timestamp}: {data}" for timestamp, data in _logs]) - return {"logs": _logs, "success": True, "message": "Logs fetched successfully"} except Exception as e: + if "Read timed out" in str(e): + _logs = "\n".join([f"{timestamp}: {data}" for timestamp, data in _logs]) + return {"logs": _logs, "success": True, "message": "Logs fetched successfully"} return {"logs": str(e), "success": False, "message": "Failed to fetch logs"} diff --git a/src/autotrain/client.py b/src/autotrain/client.py index e4613735b6..9646462c27 100644 --- a/src/autotrain/client.py +++ b/src/autotrain/client.py @@ -41,6 +41,8 @@ "max_completion_length": 128, "distributed_backend": "ddp", "scheduler": "linear", + "merge_adapter": True, + "trainer": "sft", } PARAMS["text-classification"] = { @@ -121,30 +123,58 @@ } DEFAULT_COLUMN_MAPPING = {} -DEFAULT_COLUMN_MAPPING["llm:sft"] = {"text": "text"} -DEFAULT_COLUMN_MAPPING["llm:generic"] = {"text": "text"} -DEFAULT_COLUMN_MAPPING["llm:default"] = {"text": "text"} -DEFAULT_COLUMN_MAPPING["llm:dpo"] = {"prompt": "prompt", "text": "chosen", "rejected_text": "rejected"} -DEFAULT_COLUMN_MAPPING["llm:orpo"] = {"prompt": "prompt", "text": "chosen", "rejected_text": "rejected"} -DEFAULT_COLUMN_MAPPING["llm:reward"] = {"text": "chosen", "rejected_text": "rejected"} -DEFAULT_COLUMN_MAPPING["vlm:captioning"] = {"image": "image", "text": "caption"} -DEFAULT_COLUMN_MAPPING["vlm:vqa"] = {"image": "image", "prompt": "question", "text": "answer"} +DEFAULT_COLUMN_MAPPING["llm:sft"] = {"text_column": "text"} +DEFAULT_COLUMN_MAPPING["llm:generic"] = {"text_column": "text"} +DEFAULT_COLUMN_MAPPING["llm:default"] = {"text_column": "text"} +DEFAULT_COLUMN_MAPPING["llm:dpo"] = { + "prompt_column": "prompt", + "text_column": "chosen", + "rejected_text_column": "rejected", +} +DEFAULT_COLUMN_MAPPING["llm:orpo"] = { + "prompt_column": "prompt", + "text_column": "chosen", + "rejected_text_column": "rejected", +} +DEFAULT_COLUMN_MAPPING["llm:reward"] = {"text_column": "chosen", "rejected_text_column": "rejected"} +DEFAULT_COLUMN_MAPPING["vlm:captioning"] = {"image_column": "image", "text_column": "caption"} +DEFAULT_COLUMN_MAPPING["vlm:vqa"] = { + "image_column": "image", + "prompt_text_column": "question", + "text_column": "answer", +} DEFAULT_COLUMN_MAPPING["st:pair"] = {"sentence1": "anchor", "sentence2": "positive"} -DEFAULT_COLUMN_MAPPING["st:pair_class"] = {"sentence1": "premise", "sentence2": "hypothesis", "target": "label"} -DEFAULT_COLUMN_MAPPING["st:pair_score"] = {"sentence1": "sentence1", "sentence2": "sentence2", "target": "score"} -DEFAULT_COLUMN_MAPPING["st:triplet"] = {"sentence1": "anchor", "sentence2": "positive", "sentence3": "negative"} -DEFAULT_COLUMN_MAPPING["st:qa"] = {"sentence1": "query", "sentence2": "answer"} -DEFAULT_COLUMN_MAPPING["text-classification"] = {"text": "text", "label": "target"} -DEFAULT_COLUMN_MAPPING["seq2seq"] = {"text": "text", "label": "target"} -DEFAULT_COLUMN_MAPPING["text-regression"] = {"text": "text", "label": "target"} -DEFAULT_COLUMN_MAPPING["token-classification"] = {"text": "tokens", "label": "tags"} -DEFAULT_COLUMN_MAPPING["dreambooth"] = {"image": "image"} -DEFAULT_COLUMN_MAPPING["image-classification"] = {"image": "image", "label": "label"} -DEFAULT_COLUMN_MAPPING["image-regression"] = {"image": "image", "label": "target"} -DEFAULT_COLUMN_MAPPING["image-object-detection"] = {"image": "image", "objects": "objects"} -DEFAULT_COLUMN_MAPPING["tabular:classification"] = {"id": "id", "label": "target"} -DEFAULT_COLUMN_MAPPING["tabular:regression"] = {"id": "id", "label": "target"} -DEFAULT_COLUMN_MAPPING["extractive-qa"] = {"text": "context", "question": "question", "answer": "answers"} +DEFAULT_COLUMN_MAPPING["st:pair_class"] = { + "sentence1_column": "premise", + "sentence2_column": "hypothesis", + "target_column": "label", +} +DEFAULT_COLUMN_MAPPING["st:pair_score"] = { + "sentence1_column": "sentence1", + "sentence2_column": "sentence2", + "target_column": "score", +} +DEFAULT_COLUMN_MAPPING["st:triplet"] = { + "sentence1_column": "anchor", + "sentence2_column": "positive", + "sentence3_column": "negative", +} +DEFAULT_COLUMN_MAPPING["st:qa"] = {"sentence1_column": "query", "sentence2_column": "answer"} +DEFAULT_COLUMN_MAPPING["text-classification"] = {"text_column": "text", "target_column": "target"} +DEFAULT_COLUMN_MAPPING["seq2seq"] = {"text_column": "text", "target_column": "target"} +DEFAULT_COLUMN_MAPPING["text-regression"] = {"text_column": "text", "target_column": "target"} +DEFAULT_COLUMN_MAPPING["token-classification"] = {"text_column": "tokens", "target_column": "tags"} +DEFAULT_COLUMN_MAPPING["dreambooth"] = {"default": "default"} +DEFAULT_COLUMN_MAPPING["image-classification"] = {"image_column": "image", "target_column": "label"} +DEFAULT_COLUMN_MAPPING["image-regression"] = {"image_column": "image", "target_column": "target"} +DEFAULT_COLUMN_MAPPING["image-object-detection"] = {"image_column": "image", "objects_column": "objects"} +DEFAULT_COLUMN_MAPPING["tabular:classification"] = {"id_column": "id", "target__columns": ["target"]} +DEFAULT_COLUMN_MAPPING["tabular:regression"] = {"id_column": "id", "target_columns": ["target"]} +DEFAULT_COLUMN_MAPPING["extractive-qa"] = { + "text_column": "context", + "question_column": "question", + "answer_column": "answers", +} VALID_TASKS = [k for k in DEFAULT_COLUMN_MAPPING.keys()]