diff --git a/configs/vlm/paligemma_vqa.yml b/configs/vlm/paligemma_vqa.yml
new file mode 100644
index 0000000000..484888d2fd
--- /dev/null
+++ b/configs/vlm/paligemma_vqa.yml
@@ -0,0 +1,30 @@
+task: vlm:vqa
+base_model: google/paligemma-3b-pt-224
+project_name: autotrain-paligemma-finetuned-vqa
+log: tensorboard
+backend: local
+
+data:
+ path: abhishek/vqa_small
+ train_split: train
+ valid_split: validation
+ column_mapping:
+ image_column: image
+ text_column: multiple_choice_answer
+ prompt_text_column: question
+
+params:
+ epochs: 3
+ batch_size: 2
+ lr: 2e-5
+ optimizer: adamw_torch
+ scheduler: linear
+ gradient_accumulation: 4
+ mixed_precision: fp16
+ peft: true
+ quantization: int4
+
+hub:
+ username: ${HF_USERNAME}
+ token: ${HF_TOKEN}
+ push_to_hub: true
\ No newline at end of file
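Note: as a rough sketch of how this config is consumed, and assuming the YAML keys (`batch_size`, `lr`, etc.) map one-to-one onto fields of `VLMTrainingParams` and that the config is passed through the usual `autotrain --config configs/vlm/paligemma_vqa.yml` entry point, the equivalent in-code construction would look roughly like this (field names below are taken from this diff; anything else should be checked against `autotrain.trainers.vlm.params`):

```python
# Hypothetical sketch only: building the same VLM VQA parameters in Python.
from autotrain.trainers.vlm.params import VLMTrainingParams

params = VLMTrainingParams(
    model="google/paligemma-3b-pt-224",
    trainer="vqa",                      # derived from the "vlm:vqa" task string
    epochs=3,
    batch_size=2,
    lr=2e-5,
    peft=True,
    quantization="int4",
    mixed_precision="fp16",
    image_column="image",
    text_column="multiple_choice_answer",
    prompt_text_column="question",
    train_split="train",
    valid_split="validation",
    log="tensorboard",
)
```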
diff --git a/src/autotrain/app/api_routes.py b/src/autotrain/app/api_routes.py
index c1db5b0a89..38e56ddecb 100644
--- a/src/autotrain/app/api_routes.py
+++ b/src/autotrain/app/api_routes.py
@@ -20,6 +20,7 @@
from autotrain.trainers.text_classification.params import TextClassificationParams
from autotrain.trainers.text_regression.params import TextRegressionParams
from autotrain.trainers.token_classification.params import TokenClassificationParams
+from autotrain.trainers.vlm.params import VLMTrainingParams
FIELDS_TO_EXCLUDE = HIDDEN_PARAMS + ["push_to_hub"]
@@ -88,6 +89,7 @@ def create_api_base_model(base_class, class_name):
TokenClassificationParamsAPI = create_api_base_model(TokenClassificationParams, "TokenClassificationParamsAPI")
SentenceTransformersParamsAPI = create_api_base_model(SentenceTransformersParams, "SentenceTransformersParamsAPI")
ImageRegressionParamsAPI = create_api_base_model(ImageRegressionParams, "ImageRegressionParamsAPI")
+VLMTrainingParamsAPI = create_api_base_model(VLMTrainingParams, "VLMTrainingParamsAPI")
class LLMSFTColumnMapping(BaseModel):
@@ -187,6 +189,12 @@ class STQAColumnMapping(BaseModel):
sentence2_column: str
+class VLMColumnMapping(BaseModel):
+ image_column: str
+ text_column: str
+ prompt_text_column: str
+
+
class APICreateProjectModel(BaseModel):
project_name: str
task: Literal[
@@ -209,6 +217,8 @@ class APICreateProjectModel(BaseModel):
"tabular-classification",
"tabular-regression",
"image-regression",
+ "vlm:captioning",
+ "vlm:vqa",
]
base_model: str
hardware: Literal[
@@ -241,6 +251,7 @@ class APICreateProjectModel(BaseModel):
TextRegressionParamsAPI,
TokenClassificationParamsAPI,
ImageRegressionParamsAPI,
+ VLMTrainingParamsAPI,
]
username: str
column_mapping: Optional[
@@ -264,6 +275,7 @@ class APICreateProjectModel(BaseModel):
STTripletColumnMapping,
STQAColumnMapping,
ImageRegressionColumnMapping,
+ VLMColumnMapping,
]
] = None
hub_dataset: str
@@ -426,6 +438,26 @@ def validate_column_mapping(cls, values):
if not values.get("column_mapping").get("target_column"):
raise ValueError("target_column is required for image-regression")
values["column_mapping"] = ImageRegressionColumnMapping(**values["column_mapping"])
+ elif values.get("task") == "vlm:captioning":
+ if not values.get("column_mapping"):
+ raise ValueError("column_mapping is required for vlm:captioning")
+ if not values.get("column_mapping").get("image_column"):
+ raise ValueError("image_column is required for vlm:captioning")
+ if not values.get("column_mapping").get("text_column"):
+ raise ValueError("text_column is required for vlm:captioning")
+ if not values.get("column_mapping").get("prompt_text_column"):
+ raise ValueError("prompt_text_column is required for vlm:captioning")
+ values["column_mapping"] = VLMColumnMapping(**values["column_mapping"])
+ elif values.get("task") == "vlm:vqa":
+ if not values.get("column_mapping"):
+ raise ValueError("column_mapping is required for vlm:vqa")
+ if not values.get("column_mapping").get("image_column"):
+ raise ValueError("image_column is required for vlm:vqa")
+ if not values.get("column_mapping").get("text_column"):
+ raise ValueError("text_column is required for vlm:vqa")
+ if not values.get("column_mapping").get("prompt_text_column"):
+ raise ValueError("prompt_text_column is required for vlm:vqa")
+ values["column_mapping"] = VLMColumnMapping(**values["column_mapping"])
return values
@model_validator(mode="before")
@@ -461,6 +493,8 @@ def validate_params(cls, values):
values["params"] = SentenceTransformersParamsAPI(**values["params"])
elif values.get("task") == "image-regression":
values["params"] = ImageRegressionParamsAPI(**values["params"])
+ elif values.get("task").startswith("vlm:"):
+ values["params"] = VLMTrainingParamsAPI(**values["params"])
return values
@@ -513,6 +547,10 @@ async def api_create_project(project: APICreateProjectModel, token: bool = Depen
params = PARAMS["st"]
trainer = task.split(":")[1]
params.update({"trainer": trainer})
+ elif task.startswith("vlm:"):
+ params = PARAMS["vlm"]
+ trainer = task.split(":")[1]
+ params.update({"trainer": trainer})
elif task.startswith("tabular"):
params = PARAMS["tabular"]
else:
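For reference, a request body that would satisfy the new `vlm:vqa` validators might look like the sketch below. Only `task`, `column_mapping`, and `params` are grounded in this diff; the `hardware` value, the endpoint path, the port, and the auth header are assumptions and may differ in the actual deployment:

```python
# Hypothetical payload for the project-creation API with the new vlm:vqa task.
import requests

payload = {
    "project_name": "paligemma-vqa-demo",
    "task": "vlm:vqa",
    "base_model": "google/paligemma-3b-pt-224",
    "hardware": "local-ui",              # assumed value, not shown in this diff
    "username": "my-hf-username",
    "hub_dataset": "abhishek/vqa_small",
    "train_split": "train",
    "valid_split": "validation",
    "column_mapping": {                  # validated as VLMColumnMapping
        "image_column": "image",
        "text_column": "multiple_choice_answer",
        "prompt_text_column": "question",
    },
    "params": {                          # validated as VLMTrainingParamsAPI
        "epochs": 3,
        "peft": True,
        "quantization": "int4",
        "mixed_precision": "fp16",
    },
}

resp = requests.post(
    "http://localhost:7860/api/create_project",   # assumed local endpoint
    json=payload,
    headers={"Authorization": "Bearer <HF_TOKEN>"},
)
print(resp.json())
```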
diff --git a/src/autotrain/app/models.py b/src/autotrain/app/models.py
index 0ea991a9f4..8b9e5bc86d 100644
--- a/src/autotrain/app/models.py
+++ b/src/autotrain/app/models.py
@@ -311,6 +311,58 @@ def _fetch_st_models():
return hub_models
+def _fetch_vlm_models():
+ hub_models1 = list(
+ list_models(
+ task="image-text-to-text",
+ sort="downloads",
+ direction=-1,
+ limit=100,
+ full=False,
+ filter=["paligemma"],
+ )
+ )
+ hub_models2 = list(
+ list_models(
+ task="image-text-to-text",
+ sort="downloads",
+ direction=-1,
+ limit=100,
+ full=False,
+ filter=["florence2"],
+ )
+ )
+ hub_models = hub_models1 + hub_models2
+ hub_models = get_sorted_models(hub_models)
+
+ trending_models1 = list(
+ list_models(
+ task="image-text-to-text",
+ sort="likes7d",
+ direction=-1,
+ limit=30,
+ full=False,
+ filter=["paligemma"],
+ )
+ )
+ trending_models2 = list(
+ list_models(
+ task="image-text-to-text",
+ sort="likes7d",
+ direction=-1,
+ limit=30,
+ full=False,
+ filter=["florence2"],
+ )
+ )
+ trending_models = trending_models1 + trending_models2
+ if len(trending_models) > 0:
+ trending_models = get_sorted_models(trending_models)
+ hub_models = [m for m in hub_models if m not in trending_models]
+ hub_models = trending_models + hub_models
+ return hub_models
+
+
def fetch_models():
_mc = collections.defaultdict(list)
_mc["text-classification"] = _fetch_text_classification_models()
@@ -323,6 +375,7 @@ def fetch_models():
_mc["text-regression"] = _fetch_text_classification_models()
_mc["image-object-detection"] = _fetch_image_object_detection_models()
_mc["sentence-transformers"] = _fetch_st_models()
+ _mc["vlm"] = _fetch_vlm_models()
# tabular-classification
_mc["tabular-classification"] = [
diff --git a/src/autotrain/app/params.py b/src/autotrain/app/params.py
index 08abcec7b3..5071006398 100644
--- a/src/autotrain/app/params.py
+++ b/src/autotrain/app/params.py
@@ -13,6 +13,7 @@
from autotrain.trainers.text_classification.params import TextClassificationParams
from autotrain.trainers.text_regression.params import TextRegressionParams
from autotrain.trainers.token_classification.params import TokenClassificationParams
+from autotrain.trainers.vlm.params import VLMTrainingParams
HIDDEN_PARAMS = [
@@ -131,6 +132,14 @@
mixed_precision="fp16",
log="tensorboard",
).model_dump()
+PARAMS["vlm"] = VLMTrainingParams(
+ mixed_precision="fp16",
+ target_modules="all-linear",
+ log="tensorboard",
+ quantization="int4",
+ peft=True,
+ epochs=3,
+).model_dump()
@dataclass
@@ -175,6 +184,8 @@ def munge(self):
return self._munge_params_sent_transformers()
elif self.task == "image-regression":
return self._munge_params_img_reg()
+ elif self.task.startswith("vlm"):
+ return self._munge_params_vlm()
else:
raise ValueError(f"Unknown task: {self.task}")
@@ -244,6 +255,35 @@ def _munge_params_llm(self):
return LLMTrainingParams(**_params)
+ def _munge_params_vlm(self):
+ _params = self._munge_common_params()
+ _params["model"] = self.base_model
+ if not self.using_hub_dataset:
+ _params["text_column"] = "autotrain_text"
+ _params["prompt_text_column"] = "autotrain_prompt"
+ _params["image_column"] = "autotrain_image"
+ _params["valid_split"] = "validation"
+ else:
+ _params["text_column"] = self.column_mapping.get("text" if not self.api else "text_column", "text")
+ _params["prompt_text_column"] = self.column_mapping.get(
+ "prompt" if not self.api else "prompt_text_column", "prompt"
+ )
+ _params["image_column"] = self.column_mapping.get(
+ "image" if not self.api else "rejected_text_column", "image"
+ )
+ _params["train_split"] = self.train_split
+ _params["valid_split"] = self.valid_split
+ _params["log"] = "tensorboard"
+
+ trainer = self.task.split(":")[1]
+ _params["trainer"] = trainer.lower()
+
+ if "quantization" in _params:
+ if _params["quantization"] in ("none", "no"):
+ _params["quantization"] = None
+
+ return VLMTrainingParams(**_params)
+
def _munge_params_text_clf(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
@@ -409,6 +449,10 @@ def get_task_params(task, param_type):
trainer = task.split(":")[1].lower()
task = task.split(":")[0].lower()
+ if task.startswith("vlm:"):
+ trainer = task.split(":")[1].lower()
+ task = task.split(":")[0].lower()
+
if task.startswith("tabular"):
task = "tabular"
@@ -506,6 +550,24 @@ def get_task_params(task, param_type):
"early_stopping_threshold",
]
task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}
+ if task == "vlm" and param_type == "basic":
+ more_hidden_params = [
+ "warmup_ratio",
+ "weight_decay",
+ "max_grad_norm",
+ "seed",
+ "logging_steps",
+ "auto_find_batch_size",
+ "save_total_limit",
+ "eval_strategy",
+ "early_stopping_patience",
+ "early_stopping_threshold",
+ "quantization",
+ "lora_r",
+ "lora_alpha",
+ "lora_dropout",
+ ]
+ task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}
if task == "text-regression" and param_type == "basic":
more_hidden_params = [
"warmup_ratio",
diff --git a/src/autotrain/app/templates/index.html b/src/autotrain/app/templates/index.html
index 9e40ff4357..75df30d244 100644
--- a/src/autotrain/app/templates/index.html
+++ b/src/autotrain/app/templates/index.html
@@ -38,6 +38,14 @@
fields = ['text', 'rejected_text'];
fieldNames = ['chosen', 'rejected'];
break;
+ case 'vlm:captioning':
+ fields = ['image', 'text'];
+ fieldNames = ['image', 'caption'];
+ break;
+ case 'vlm:vqa':
+ fields = ['image', 'prompt', 'text'];
+ fieldNames = ['image', 'question', 'answer'];
+ break;
case 'st:pair':
fields = ['sentence1', 'sentence2'];
fieldNames = ['anchor', 'positive'];
@@ -188,6 +196,10 @@
+
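For context, the UI field names added above ('image', 'prompt', 'text') are the non-API keys that `_munge_params_vlm` in params.py resolves to the trainer's column parameters; API clients use the `*_column` keys directly. A rough sketch of the implied mapping:

```python
# Sketch of the UI-key -> trainer-parameter mapping implied by _munge_params_vlm.
UI_TO_PARAM = {
    "image": "image_column",
    "prompt": "prompt_text_column",
    "text": "text_column",
}
```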