Skip to content

Commit

Permalink
VLM: PaliGemma Finetuning (#712)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur authored Jul 25, 2024
1 parent 72870ab commit 752ad56
Show file tree
Hide file tree
Showing 21 changed files with 1,223 additions and 1 deletion.
30 changes: 30 additions & 0 deletions configs/vlm/paligemma_vqa.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
task: vlm:vqa
base_model: google/paligemma-3b-pt-224
project_name: autotrain-paligemma-finetuned-vqa
log: tensorboard
backend: local

data:
path: abhishek/vqa_small
train_split: train
valid_split: validation
column_mapping:
image_column: image
text_column: multiple_choice_answer
prompt_text_column: question

params:
epochs: 3
batch_size: 2
lr: 2e-5
optimizer: adamw_torch
scheduler: linear
gradient_accumulation: 4
mixed_precision: fp16
peft: true
quantization: int4

hub:
username: ${HF_USERNAME}
token: ${HF_TOKEN}
push_to_hub: true
38 changes: 38 additions & 0 deletions src/autotrain/app/api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from autotrain.trainers.text_classification.params import TextClassificationParams
from autotrain.trainers.text_regression.params import TextRegressionParams
from autotrain.trainers.token_classification.params import TokenClassificationParams
from autotrain.trainers.vlm.params import VLMTrainingParams


FIELDS_TO_EXCLUDE = HIDDEN_PARAMS + ["push_to_hub"]
Expand Down Expand Up @@ -88,6 +89,7 @@ def create_api_base_model(base_class, class_name):
TokenClassificationParamsAPI = create_api_base_model(TokenClassificationParams, "TokenClassificationParamsAPI")
SentenceTransformersParamsAPI = create_api_base_model(SentenceTransformersParams, "SentenceTransformersParamsAPI")
ImageRegressionParamsAPI = create_api_base_model(ImageRegressionParams, "ImageRegressionParamsAPI")
VLMTrainingParamsAPI = create_api_base_model(VLMTrainingParams, "VLMTrainingParamsAPI")


class LLMSFTColumnMapping(BaseModel):
Expand Down Expand Up @@ -187,6 +189,12 @@ class STQAColumnMapping(BaseModel):
sentence2_column: str


class VLMColumnMapping(BaseModel):
image_column: str
text_column: str
prompt_text_column: str


class APICreateProjectModel(BaseModel):
project_name: str
task: Literal[
Expand All @@ -209,6 +217,8 @@ class APICreateProjectModel(BaseModel):
"tabular-classification",
"tabular-regression",
"image-regression",
"vlm:captioning",
"vlm:vqa",
]
base_model: str
hardware: Literal[
Expand Down Expand Up @@ -241,6 +251,7 @@ class APICreateProjectModel(BaseModel):
TextRegressionParamsAPI,
TokenClassificationParamsAPI,
ImageRegressionParamsAPI,
VLMTrainingParamsAPI,
]
username: str
column_mapping: Optional[
Expand All @@ -264,6 +275,7 @@ class APICreateProjectModel(BaseModel):
STTripletColumnMapping,
STQAColumnMapping,
ImageRegressionColumnMapping,
VLMColumnMapping,
]
] = None
hub_dataset: str
Expand Down Expand Up @@ -426,6 +438,26 @@ def validate_column_mapping(cls, values):
if not values.get("column_mapping").get("target_column"):
raise ValueError("target_column is required for image-regression")
values["column_mapping"] = ImageRegressionColumnMapping(**values["column_mapping"])
elif values.get("task") == "vlm:captioning":
if not values.get("column_mapping"):
raise ValueError("column_mapping is required for vlm:captioning")
if not values.get("column_mapping").get("image_column"):
raise ValueError("image_column is required for vlm:captioning")
if not values.get("column_mapping").get("text_column"):
raise ValueError("text_column is required for vlm:captioning")
if not values.get("column_mapping").get("prompt_text_column"):
raise ValueError("prompt_text_column is required for vlm:captioning")
values["column_mapping"] = VLMColumnMapping(**values["column_mapping"])
elif values.get("task") == "vlm:vqa":
if not values.get("column_mapping"):
raise ValueError("column_mapping is required for vlm:vqa")
if not values.get("column_mapping").get("image_column"):
raise ValueError("image_column is required for vlm:vqa")
if not values.get("column_mapping").get("text_column"):
raise ValueError("text_column is required for vlm:vqa")
if not values.get("column_mapping").get("prompt_text_column"):
raise ValueError("prompt_text_column is required for vlm:vqa")
values["column_mapping"] = VLMColumnMapping(**values["column_mapping"])
return values

@model_validator(mode="before")
Expand Down Expand Up @@ -461,6 +493,8 @@ def validate_params(cls, values):
values["params"] = SentenceTransformersParamsAPI(**values["params"])
elif values.get("task") == "image-regression":
values["params"] = ImageRegressionParamsAPI(**values["params"])
elif values.get("task").startswith("vlm:"):
values["params"] = VLMTrainingParamsAPI(**values["params"])
return values


Expand Down Expand Up @@ -513,6 +547,10 @@ async def api_create_project(project: APICreateProjectModel, token: bool = Depen
params = PARAMS["st"]
trainer = task.split(":")[1]
params.update({"trainer": trainer})
elif task.startswith("vlm:"):
params = PARAMS["vlm"]
trainer = task.split(":")[1]
params.update({"trainer": trainer})
elif task.startswith("tabular"):
params = PARAMS["tabular"]
else:
Expand Down
53 changes: 53 additions & 0 deletions src/autotrain/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,58 @@ def _fetch_st_models():
return hub_models


def _fetch_vlm_models():
hub_models1 = list(
list_models(
task="image-text-to-text",
sort="downloads",
direction=-1,
limit=100,
full=False,
filter=["paligemma"],
)
)
hub_models2 = list(
list_models(
task="image-text-to-text",
sort="downloads",
direction=-1,
limit=100,
full=False,
filter=["florence2"],
)
)
hub_models = list(hub_models1) + list(hub_models2)
hub_models = get_sorted_models(hub_models)

trending_models1 = list(
list_models(
task="image-text-to-text",
sort="likes7d",
direction=-1,
limit=30,
full=False,
filter=["paligemma"],
)
)
trending_models2 = list(
list_models(
task="image-text-to-text",
sort="likes7d",
direction=-1,
limit=30,
full=False,
filter=["florence2"],
)
)
trending_models = list(trending_models1) + list(trending_models2)
if len(trending_models) > 0:
trending_models = get_sorted_models(trending_models)
hub_models = [m for m in hub_models if m not in trending_models]
hub_models = trending_models + hub_models
return hub_models


def fetch_models():
_mc = collections.defaultdict(list)
_mc["text-classification"] = _fetch_text_classification_models()
Expand All @@ -323,6 +375,7 @@ def fetch_models():
_mc["text-regression"] = _fetch_text_classification_models()
_mc["image-object-detection"] = _fetch_image_object_detection_models()
_mc["sentence-transformers"] = _fetch_st_models()
_mc["vlm"] = _fetch_vlm_models()

# tabular-classification
_mc["tabular-classification"] = [
Expand Down
62 changes: 62 additions & 0 deletions src/autotrain/app/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from autotrain.trainers.text_classification.params import TextClassificationParams
from autotrain.trainers.text_regression.params import TextRegressionParams
from autotrain.trainers.token_classification.params import TokenClassificationParams
from autotrain.trainers.vlm.params import VLMTrainingParams


HIDDEN_PARAMS = [
Expand Down Expand Up @@ -131,6 +132,14 @@
mixed_precision="fp16",
log="tensorboard",
).model_dump()
PARAMS["vlm"] = VLMTrainingParams(
mixed_precision="fp16",
target_modules="all-linear",
log="tensorboard",
quantization="int4",
peft=True,
epochs=3,
).model_dump()


@dataclass
Expand Down Expand Up @@ -175,6 +184,8 @@ def munge(self):
return self._munge_params_sent_transformers()
elif self.task == "image-regression":
return self._munge_params_img_reg()
elif self.task.startswith("vlm"):
return self._munge_params_vlm()
else:
raise ValueError(f"Unknown task: {self.task}")

Expand Down Expand Up @@ -244,6 +255,35 @@ def _munge_params_llm(self):

return LLMTrainingParams(**_params)

def _munge_params_vlm(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
if not self.using_hub_dataset:
_params["text_column"] = "autotrain_text"
_params["prompt_text_column"] = "autotrain_prompt"
_params["image_column"] = "autotrain_image"
_params["valid_split"] = "validation"
else:
_params["text_column"] = self.column_mapping.get("text" if not self.api else "text_column", "text")
_params["prompt_text_column"] = self.column_mapping.get(
"prompt" if not self.api else "prompt_text_column", "prompt"
)
_params["image_column"] = self.column_mapping.get(
"image" if not self.api else "rejected_text_column", "image"
)
_params["train_split"] = self.train_split
_params["valid_split"] = self.valid_split
_params["log"] = "tensorboard"

trainer = self.task.split(":")[1]
_params["trainer"] = trainer.lower()

if "quantization" in _params:
if _params["quantization"] in ("none", "no"):
_params["quantization"] = None

return VLMTrainingParams(**_params)

def _munge_params_text_clf(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
Expand Down Expand Up @@ -409,6 +449,10 @@ def get_task_params(task, param_type):
trainer = task.split(":")[1].lower()
task = task.split(":")[0].lower()

if task.startswith("vlm:"):
trainer = task.split(":")[1].lower()
task = task.split(":")[0].lower()

if task.startswith("tabular"):
task = "tabular"

Expand Down Expand Up @@ -506,6 +550,24 @@ def get_task_params(task, param_type):
"early_stopping_threshold",
]
task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}
if task == "vlm" and param_type == "basic":
more_hidden_params = [
"warmup_ratio",
"weight_decay",
"max_grad_norm",
"seed",
"logging_steps",
"auto_find_batch_size",
"save_total_limit",
"eval_strategy",
"early_stopping_patience",
"early_stopping_threshold",
"quantization",
"lora_r",
"lora_alpha",
"lora_dropout",
]
task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}
if task == "text-regression" and param_type == "basic":
more_hidden_params = [
"warmup_ratio",
Expand Down
12 changes: 12 additions & 0 deletions src/autotrain/app/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@
fields = ['text', 'rejected_text'];
fieldNames = ['chosen', 'rejected'];
break;
case 'vlm:captioning':
fields = ['image', 'text'];
fieldNames = ['image', 'caption'];
break;
case 'vlm:vqa':
fields = ['image', 'prompt', 'text'];
fieldNames = ['image', 'question', 'answer'];
break;
case 'st:pair':
fields = ['sentence1', 'sentence2'];
fieldNames = ['anchor', 'positive'];
Expand Down Expand Up @@ -188,6 +196,10 @@
<option value="llm:dpo">LLM DPO</option>
<option value="llm:reward">LLM Reward</option>
</optgroup>
<optgroup label="VLM Finetuning">
<option value="vlm:captioning">VLM Captioning</option>
<option value="vlm:vqa">VLM VQA</option>
</optgroup>
<optgroup label="Sentence Transformers">
<option value="st:pair">ST Pair</option>
<option value="st:pair_class">ST Pair Classification</option>
Expand Down
15 changes: 14 additions & 1 deletion src/autotrain/app/ui_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
AutoTrainImageClassificationDataset,
AutoTrainImageRegressionDataset,
AutoTrainObjectDetectionDataset,
AutoTrainVLMDataset,
)
from autotrain.help import get_app_help
from autotrain.project import AutoTrainProject
Expand Down Expand Up @@ -440,6 +441,8 @@ async def fetch_model_choices(
hub_models = MODEL_CHOICE["image-object-detection"]
elif task == "image-regression":
hub_models = MODEL_CHOICE["image-regression"]
elif task.startswith("vlm:"):
hub_models = MODEL_CHOICE["vlm"]
else:
raise NotImplementedError

Expand Down Expand Up @@ -571,7 +574,17 @@ async def handle_form(
username=autotrain_user,
local=hardware.lower() == "local-ui",
)

elif task.startswith("vlm:"):
dset = AutoTrainVLMDataset(
train_data=training_files[0],
token=token,
project_name=project_name,
username=autotrain_user,
column_mapping=column_mapping,
valid_data=validation_files[0] if validation_files else None,
percent_valid=None, # TODO: add to UI
local=hardware.lower() == "local-ui",
)
else:
if task.startswith("llm"):
dset_task = "lm_training"
Expand Down
4 changes: 4 additions & 0 deletions src/autotrain/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from autotrain.trainers.text_classification.params import TextClassificationParams
from autotrain.trainers.text_regression.params import TextRegressionParams
from autotrain.trainers.token_classification.params import TokenClassificationParams
from autotrain.trainers.vlm.params import VLMTrainingParams


AVAILABLE_HARDWARE = {
Expand Down Expand Up @@ -70,6 +71,7 @@ class BaseBackend:
ObjectDetectionParams,
SentenceTransformersParams,
ImageRegressionParams,
VLMTrainingParams,
]
backend: str

Expand Down Expand Up @@ -114,6 +116,8 @@ def __post_init__(self):
self.task_id = 30
elif isinstance(self.params, ImageRegressionParams):
self.task_id = 24
elif isinstance(self.params, VLMTrainingParams):
self.task_id = 31
else:
raise NotImplementedError

Expand Down
Loading

0 comments on commit 752ad56

Please sign in to comment.