diff --git a/configs/vlm/paligemma_vqa.yml b/configs/vlm/paligemma_vqa.yml new file mode 100644 index 0000000000..484888d2fd --- /dev/null +++ b/configs/vlm/paligemma_vqa.yml @@ -0,0 +1,30 @@ +task: vlm:vqa +base_model: google/paligemma-3b-pt-224 +project_name: autotrain-paligemma-finetuned-vqa +log: tensorboard +backend: local + +data: + path: abhishek/vqa_small + train_split: train + valid_split: validation + column_mapping: + image_column: image + text_column: multiple_choice_answer + prompt_text_column: question + +params: + epochs: 3 + batch_size: 2 + lr: 2e-5 + optimizer: adamw_torch + scheduler: linear + gradient_accumulation: 4 + mixed_precision: fp16 + peft: true + quantization: int4 + +hub: + username: ${HF_USERNAME} + token: ${HF_TOKEN} + push_to_hub: true \ No newline at end of file diff --git a/src/autotrain/app/api_routes.py b/src/autotrain/app/api_routes.py index c1db5b0a89..38e56ddecb 100644 --- a/src/autotrain/app/api_routes.py +++ b/src/autotrain/app/api_routes.py @@ -20,6 +20,7 @@ from autotrain.trainers.text_classification.params import TextClassificationParams from autotrain.trainers.text_regression.params import TextRegressionParams from autotrain.trainers.token_classification.params import TokenClassificationParams +from autotrain.trainers.vlm.params import VLMTrainingParams FIELDS_TO_EXCLUDE = HIDDEN_PARAMS + ["push_to_hub"] @@ -88,6 +89,7 @@ def create_api_base_model(base_class, class_name): TokenClassificationParamsAPI = create_api_base_model(TokenClassificationParams, "TokenClassificationParamsAPI") SentenceTransformersParamsAPI = create_api_base_model(SentenceTransformersParams, "SentenceTransformersParamsAPI") ImageRegressionParamsAPI = create_api_base_model(ImageRegressionParams, "ImageRegressionParamsAPI") +VLMTrainingParamsAPI = create_api_base_model(VLMTrainingParams, "VLMTrainingParamsAPI") class LLMSFTColumnMapping(BaseModel): @@ -187,6 +189,12 @@ class STQAColumnMapping(BaseModel): sentence2_column: str +class 
VLMColumnMapping(BaseModel): + image_column: str + text_column: str + prompt_text_column: str + + class APICreateProjectModel(BaseModel): project_name: str task: Literal[ @@ -209,6 +217,8 @@ class APICreateProjectModel(BaseModel): "tabular-classification", "tabular-regression", "image-regression", + "vlm:captioning", + "vlm:vqa", ] base_model: str hardware: Literal[ @@ -241,6 +251,7 @@ class APICreateProjectModel(BaseModel): TextRegressionParamsAPI, TokenClassificationParamsAPI, ImageRegressionParamsAPI, + VLMTrainingParamsAPI, ] username: str column_mapping: Optional[ @@ -264,6 +275,7 @@ class APICreateProjectModel(BaseModel): STTripletColumnMapping, STQAColumnMapping, ImageRegressionColumnMapping, + VLMColumnMapping, ] ] = None hub_dataset: str @@ -426,6 +438,26 @@ def validate_column_mapping(cls, values): if not values.get("column_mapping").get("target_column"): raise ValueError("target_column is required for image-regression") values["column_mapping"] = ImageRegressionColumnMapping(**values["column_mapping"]) + elif values.get("task") == "vlm:captioning": + if not values.get("column_mapping"): + raise ValueError("column_mapping is required for vlm:captioning") + if not values.get("column_mapping").get("image_column"): + raise ValueError("image_column is required for vlm:captioning") + if not values.get("column_mapping").get("text_column"): + raise ValueError("text_column is required for vlm:captioning") + if not values.get("column_mapping").get("prompt_text_column"): + raise ValueError("prompt_text_column is required for vlm:captioning") + values["column_mapping"] = VLMColumnMapping(**values["column_mapping"]) + elif values.get("task") == "vlm:vqa": + if not values.get("column_mapping"): + raise ValueError("column_mapping is required for vlm:vqa") + if not values.get("column_mapping").get("image_column"): + raise ValueError("image_column is required for vlm:vqa") + if not values.get("column_mapping").get("text_column"): + raise ValueError("text_column is 
required for vlm:vqa") + if not values.get("column_mapping").get("prompt_text_column"): + raise ValueError("prompt_text_column is required for vlm:vqa") + values["column_mapping"] = VLMColumnMapping(**values["column_mapping"]) return values @model_validator(mode="before") @@ -461,6 +493,8 @@ def validate_params(cls, values): values["params"] = SentenceTransformersParamsAPI(**values["params"]) elif values.get("task") == "image-regression": values["params"] = ImageRegressionParamsAPI(**values["params"]) + elif values.get("task").startswith("vlm:"): + values["params"] = VLMTrainingParamsAPI(**values["params"]) return values @@ -513,6 +547,10 @@ async def api_create_project(project: APICreateProjectModel, token: bool = Depen params = PARAMS["st"] trainer = task.split(":")[1] params.update({"trainer": trainer}) + elif task.startswith("vlm:"): + params = PARAMS["vlm"] + trainer = task.split(":")[1] + params.update({"trainer": trainer}) elif task.startswith("tabular"): params = PARAMS["tabular"] else: diff --git a/src/autotrain/app/models.py b/src/autotrain/app/models.py index 0ea991a9f4..8b9e5bc86d 100644 --- a/src/autotrain/app/models.py +++ b/src/autotrain/app/models.py @@ -311,6 +311,58 @@ def _fetch_st_models(): return hub_models +def _fetch_vlm_models(): + hub_models1 = list( + list_models( + task="image-text-to-text", + sort="downloads", + direction=-1, + limit=100, + full=False, + filter=["paligemma"], + ) + ) + hub_models2 = list( + list_models( + task="image-text-to-text", + sort="downloads", + direction=-1, + limit=100, + full=False, + filter=["florence2"], + ) + ) + hub_models = list(hub_models1) + list(hub_models2) + hub_models = get_sorted_models(hub_models) + + trending_models1 = list( + list_models( + task="image-text-to-text", + sort="likes7d", + direction=-1, + limit=30, + full=False, + filter=["paligemma"], + ) + ) + trending_models2 = list( + list_models( + task="image-text-to-text", + sort="likes7d", + direction=-1, + limit=30, + full=False, + 
filter=["florence2"], + ) + ) + trending_models = list(trending_models1) + list(trending_models2) + if len(trending_models) > 0: + trending_models = get_sorted_models(trending_models) + hub_models = [m for m in hub_models if m not in trending_models] + hub_models = trending_models + hub_models + return hub_models + + def fetch_models(): _mc = collections.defaultdict(list) _mc["text-classification"] = _fetch_text_classification_models() @@ -323,6 +375,7 @@ def fetch_models(): _mc["text-regression"] = _fetch_text_classification_models() _mc["image-object-detection"] = _fetch_image_object_detection_models() _mc["sentence-transformers"] = _fetch_st_models() + _mc["vlm"] = _fetch_vlm_models() # tabular-classification _mc["tabular-classification"] = [ diff --git a/src/autotrain/app/params.py b/src/autotrain/app/params.py index 08abcec7b3..5071006398 100644 --- a/src/autotrain/app/params.py +++ b/src/autotrain/app/params.py @@ -13,6 +13,7 @@ from autotrain.trainers.text_classification.params import TextClassificationParams from autotrain.trainers.text_regression.params import TextRegressionParams from autotrain.trainers.token_classification.params import TokenClassificationParams +from autotrain.trainers.vlm.params import VLMTrainingParams HIDDEN_PARAMS = [ @@ -131,6 +132,14 @@ mixed_precision="fp16", log="tensorboard", ).model_dump() +PARAMS["vlm"] = VLMTrainingParams( + mixed_precision="fp16", + target_modules="all-linear", + log="tensorboard", + quantization="int4", + peft=True, + epochs=3, +).model_dump() @dataclass @@ -175,6 +184,8 @@ def munge(self): return self._munge_params_sent_transformers() elif self.task == "image-regression": return self._munge_params_img_reg() + elif self.task.startswith("vlm"): + return self._munge_params_vlm() else: raise ValueError(f"Unknown task: {self.task}") @@ -244,6 +255,35 @@ def _munge_params_llm(self): return LLMTrainingParams(**_params) + def _munge_params_vlm(self): + _params = self._munge_common_params() + _params["model"] = 
self.base_model
+        if not self.using_hub_dataset:
+            _params["text_column"] = "autotrain_text"
+            _params["prompt_text_column"] = "autotrain_prompt"
+            _params["image_column"] = "autotrain_image"
+            _params["valid_split"] = "validation"
+        else:
+            _params["text_column"] = self.column_mapping.get("text" if not self.api else "text_column", "text")
+            _params["prompt_text_column"] = self.column_mapping.get(
+                "prompt" if not self.api else "prompt_text_column", "prompt"
+            )
+            _params["image_column"] = self.column_mapping.get(
+                "image" if not self.api else "image_column", "image"
+            )
+            _params["train_split"] = self.train_split
+            _params["valid_split"] = self.valid_split
+        _params["log"] = "tensorboard"
+
+        trainer = self.task.split(":")[1]
+        _params["trainer"] = trainer.lower()
+
+        if "quantization" in _params:
+            if _params["quantization"] in ("none", "no"):
+                _params["quantization"] = None
+
+        return VLMTrainingParams(**_params)
+
     def _munge_params_text_clf(self):
         _params = self._munge_common_params()
         _params["model"] = self.base_model
@@ -409,6 +449,10 @@ def get_task_params(task, param_type):
         trainer = task.split(":")[1].lower()
         task = task.split(":")[0].lower()
 
+    if task.startswith("vlm:"):
+        trainer = task.split(":")[1].lower()
+        task = task.split(":")[0].lower()
+
     if task.startswith("tabular"):
         task = "tabular"
 
@@ -506,6 +550,24 @@ def get_task_params(task, param_type):
             "early_stopping_threshold",
         ]
         task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}
+    if task == "vlm" and param_type == "basic":
+        more_hidden_params = [
+            "warmup_ratio",
+            "weight_decay",
+            "max_grad_norm",
+            "seed",
+            "logging_steps",
+            "auto_find_batch_size",
+            "save_total_limit",
+            "eval_strategy",
+            "early_stopping_patience",
+            "early_stopping_threshold",
+            "quantization",
+            "lora_r",
+            "lora_alpha",
+            "lora_dropout",
+        ]
+        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}
     if task == "text-regression" and param_type ==
"basic": more_hidden_params = [ "warmup_ratio", diff --git a/src/autotrain/app/templates/index.html b/src/autotrain/app/templates/index.html index 9e40ff4357..75df30d244 100644 --- a/src/autotrain/app/templates/index.html +++ b/src/autotrain/app/templates/index.html @@ -38,6 +38,14 @@ fields = ['text', 'rejected_text']; fieldNames = ['chosen', 'rejected']; break; + case 'vlm:captioning': + fields = ['image', 'text']; + fieldNames = ['image', 'caption']; + break; + case 'vlm:vqa': + fields = ['image', 'prompt', 'text']; + fieldNames = ['image', 'question', 'answer']; + break; case 'st:pair': fields = ['sentence1', 'sentence2']; fieldNames = ['anchor', 'positive']; @@ -188,6 +196,10 @@ + + + + diff --git a/src/autotrain/app/ui_routes.py b/src/autotrain/app/ui_routes.py index 39c424960b..c3b26f8e8e 100644 --- a/src/autotrain/app/ui_routes.py +++ b/src/autotrain/app/ui_routes.py @@ -23,6 +23,7 @@ AutoTrainImageClassificationDataset, AutoTrainImageRegressionDataset, AutoTrainObjectDetectionDataset, + AutoTrainVLMDataset, ) from autotrain.help import get_app_help from autotrain.project import AutoTrainProject @@ -440,6 +441,8 @@ async def fetch_model_choices( hub_models = MODEL_CHOICE["image-object-detection"] elif task == "image-regression": hub_models = MODEL_CHOICE["image-regression"] + elif task.startswith("vlm:"): + hub_models = MODEL_CHOICE["vlm"] else: raise NotImplementedError @@ -571,7 +574,17 @@ async def handle_form( username=autotrain_user, local=hardware.lower() == "local-ui", ) - + elif task.startswith("vlm:"): + dset = AutoTrainVLMDataset( + train_data=training_files[0], + token=token, + project_name=project_name, + username=autotrain_user, + column_mapping=column_mapping, + valid_data=validation_files[0] if validation_files else None, + percent_valid=None, # TODO: add to UI + local=hardware.lower() == "local-ui", + ) else: if task.startswith("llm"): dset_task = "lm_training" diff --git a/src/autotrain/backends/base.py b/src/autotrain/backends/base.py 
index fb4475a863..73c98cb2d3 100644 --- a/src/autotrain/backends/base.py +++ b/src/autotrain/backends/base.py @@ -14,6 +14,7 @@ from autotrain.trainers.text_classification.params import TextClassificationParams from autotrain.trainers.text_regression.params import TextRegressionParams from autotrain.trainers.token_classification.params import TokenClassificationParams +from autotrain.trainers.vlm.params import VLMTrainingParams AVAILABLE_HARDWARE = { @@ -70,6 +71,7 @@ class BaseBackend: ObjectDetectionParams, SentenceTransformersParams, ImageRegressionParams, + VLMTrainingParams, ] backend: str @@ -114,6 +116,8 @@ def __post_init__(self): self.task_id = 30 elif isinstance(self.params, ImageRegressionParams): self.task_id = 24 + elif isinstance(self.params, VLMTrainingParams): + self.task_id = 31 else: raise NotImplementedError diff --git a/src/autotrain/cli/run_vlm.py b/src/autotrain/cli/run_vlm.py new file mode 100644 index 0000000000..960ce0d0ac --- /dev/null +++ b/src/autotrain/cli/run_vlm.py @@ -0,0 +1,102 @@ +from argparse import ArgumentParser + +from autotrain import logger +from autotrain.cli.utils import get_field_info, vlm_munge_data +from autotrain.project import AutoTrainProject +from autotrain.trainers.vlm.params import VLMTrainingParams + +from . 
import BaseAutoTrainCommand + + +def run_vlm_command_factory(args): + return RunAutoTrainVLMCommand(args) + + +class RunAutoTrainVLMCommand(BaseAutoTrainCommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + arg_list = get_field_info(VLMTrainingParams) + arg_list = [ + { + "arg": "--train", + "help": "Command to train the model", + "required": False, + "action": "store_true", + }, + { + "arg": "--deploy", + "help": "Command to deploy the model (limited availability)", + "required": False, + "action": "store_true", + }, + { + "arg": "--inference", + "help": "Command to run inference (limited availability)", + "required": False, + "action": "store_true", + }, + ] + arg_list + run_image_regression_parser = parser.add_parser("vlm", description="✨ Run AutoTrain VLM") + for arg in arg_list: + if "action" in arg: + run_image_regression_parser.add_argument( + arg["arg"], + help=arg["help"], + required=arg.get("required", False), + action=arg.get("action"), + default=arg.get("default"), + ) + else: + run_image_regression_parser.add_argument( + arg["arg"], + help=arg["help"], + required=arg.get("required", False), + type=arg.get("type"), + default=arg.get("default"), + choices=arg.get("choices"), + ) + run_image_regression_parser.set_defaults(func=run_vlm_command_factory) + + def __init__(self, args): + self.args = args + + store_true_arg_names = [ + "train", + "deploy", + "inference", + "auto_find_batch_size", + "push_to_hub", + ] + for arg_name in store_true_arg_names: + if getattr(self.args, arg_name) is None: + setattr(self.args, arg_name, False) + + if self.args.train: + if self.args.project_name is None: + raise ValueError("Project name must be specified") + if self.args.data_path is None: + raise ValueError("Data path must be specified") + if self.args.model is None: + raise ValueError("Model must be specified") + if self.args.push_to_hub: + if self.args.username is None: + raise ValueError("Username must be specified for push to hub") + else: 
+            raise ValueError("Must specify --train, --deploy or --inference")
+
+        if self.args.backend.startswith("spaces") or self.args.backend.startswith("ep-"):
+            if not self.args.push_to_hub:
+                raise ValueError("Push to hub must be specified for spaces backend")
+            if self.args.username is None:
+                raise ValueError("Username must be specified for spaces backend")
+            if self.args.token is None:
+                raise ValueError("Token must be specified for spaces backend")
+
+    def run(self):
+        logger.info("Running VLM")
+        if self.args.train:
+            params = VLMTrainingParams(**vars(self.args))
+            params = vlm_munge_data(params, local=self.args.backend.startswith("local"))
+            project = AutoTrainProject(params=params, backend=self.args.backend)
+            job_id = project.create()
+            logger.info(f"Job ID: {job_id}")
diff --git a/src/autotrain/cli/utils.py b/src/autotrain/cli/utils.py
index e706654eb8..59d8a8b46e 100644
--- a/src/autotrain/cli/utils.py
+++ b/src/autotrain/cli/utils.py
@@ -8,6 +8,7 @@
     AutoTrainImageClassificationDataset,
     AutoTrainImageRegressionDataset,
     AutoTrainObjectDetectionDataset,
+    AutoTrainVLMDataset,
 )
@@ -548,3 +549,30 @@ def img_reg_munge_data(params, local):
     params.image_column = "autotrain_image"
     params.target_column = "autotrain_label"
     return params
+
+
+def vlm_munge_data(params, local):
+    train_data_path = f"{params.data_path}/{params.train_split}"
+    if params.valid_split is not None:
+        valid_data_path = f"{params.data_path}/{params.valid_split}"
+    else:
+        valid_data_path = None
+    if os.path.exists(train_data_path):
+        col_map = {"text": params.text_column}
+        if params.prompt_text_column is not None:
+            col_map["prompt"] = params.prompt_text_column
+        dset = AutoTrainVLMDataset(
+            train_data=train_data_path,
+            token=params.token,
+            project_name=params.project_name,
+            username=params.username,
+            column_mapping=col_map,
+            valid_data=valid_data_path if valid_data_path is not None else None,
+            percent_valid=None,  # TODO: add to UI
+            local=local,
+        )
+        params.data_path =
dset.prepare() + params.text_column = "autotrain_text" + params.image_column = "autotrain_image" + params.prompt_text_column = "autotrain_prompt" + return params diff --git a/src/autotrain/commands.py b/src/autotrain/commands.py index 84893a50f2..1af70eb577 100644 --- a/src/autotrain/commands.py +++ b/src/autotrain/commands.py @@ -16,6 +16,7 @@ from autotrain.trainers.text_classification.params import TextClassificationParams from autotrain.trainers.text_regression.params import TextRegressionParams from autotrain.trainers.token_classification.params import TokenClassificationParams +from autotrain.trainers.vlm.params import VLMTrainingParams def launch_command(params): @@ -394,6 +395,83 @@ def launch_command(params): ] ) + elif isinstance(params, VLMTrainingParams): + if num_gpus == 0: + logger.warning("No GPU found. Forcing training on CPU. This will be super slow!") + cmd = [ + "accelerate", + "launch", + "--cpu", + ] + elif num_gpus == 1: + cmd = [ + "accelerate", + "launch", + "--num_machines", + "1", + "--num_processes", + "1", + ] + elif num_gpus == 2: + cmd = [ + "accelerate", + "launch", + "--multi_gpu", + "--num_machines", + "1", + "--num_processes", + "2", + ] + else: + if params.quantization in ("int8", "int4") and params.peft and params.mixed_precision == "bf16": + cmd = [ + "accelerate", + "launch", + "--multi_gpu", + "--num_machines", + "1", + "--num_processes", + str(num_gpus), + ] + else: + cmd = [ + "accelerate", + "launch", + "--use_deepspeed", + "--zero_stage", + "3", + "--offload_optimizer_device", + "none", + "--offload_param_device", + "none", + "--zero3_save_16bit_model", + "true", + "--zero3_init_flag", + "true", + "--deepspeed_multinode_launcher", + "standard", + "--gradient_accumulation_steps", + str(params.gradient_accumulation), + ] + + if num_gpus > 0: + cmd.append("--mixed_precision") + if params.mixed_precision == "fp16": + cmd.append("fp16") + elif params.mixed_precision == "bf16": + cmd.append("bf16") + else: + cmd.append("no") + + 
cmd.extend( + [ + "-m", + "autotrain.trainers.vlm", + "--training_config", + os.path.join(params.project_name, "training_params.json"), + ] + ) + else: raise ValueError("Unsupported params type") diff --git a/src/autotrain/dataset.py b/src/autotrain/dataset.py index 661b971382..7e3ccf5c0b 100644 --- a/src/autotrain/dataset.py +++ b/src/autotrain/dataset.py @@ -29,6 +29,7 @@ ImageRegressionPreprocessor, ObjectDetectionPreprocessor, ) +from autotrain.preprocessor.vlm import VLMPreprocessor def remove_non_image_files(folder): @@ -236,6 +237,84 @@ def prepare(self): return preprocessor.prepare() +@dataclass +class AutoTrainVLMDataset: + train_data: str + token: str + project_name: str + username: str + column_mapping: Dict[str, str] + valid_data: Optional[str] = None + percent_valid: Optional[float] = None + local: bool = False + + def __str__(self) -> str: + info = f"Dataset: {self.project_name} ({self.task})\n" + info += f"Train data: {self.train_data}\n" + info += f"Valid data: {self.valid_data}\n" + return info + + def __post_init__(self): + self.task = "vlm" + if not self.valid_data and self.percent_valid is None: + self.percent_valid = 0.2 + elif self.valid_data and self.percent_valid is not None: + raise ValueError("You can only specify one of valid_data or percent_valid") + elif self.valid_data: + self.percent_valid = 0.0 + + def prepare(self): + valid_dir = None + if not isinstance(self.train_data, str): + cache_dir = os.environ.get("HF_HOME") + if not cache_dir: + cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface") + + random_uuid = uuid.uuid4() + train_dir = os.path.join(cache_dir, "autotrain", str(random_uuid)) + os.makedirs(train_dir, exist_ok=True) + self.train_data.seek(0) + content = self.train_data.read() + bytes_io = io.BytesIO(content) + + zip_ref = zipfile.ZipFile(bytes_io, "r") + zip_ref.extractall(train_dir) + # remove the __MACOSX directory + macosx_dir = os.path.join(train_dir, "__MACOSX") + if os.path.exists(macosx_dir): 
+ os.system(f"rm -rf {macosx_dir}") + remove_non_image_files(train_dir) + if self.valid_data: + random_uuid = uuid.uuid4() + valid_dir = os.path.join(cache_dir, "autotrain", str(random_uuid)) + os.makedirs(valid_dir, exist_ok=True) + self.valid_data.seek(0) + content = self.valid_data.read() + bytes_io = io.BytesIO(content) + zip_ref = zipfile.ZipFile(bytes_io, "r") + zip_ref.extractall(valid_dir) + # remove the __MACOSX directory + macosx_dir = os.path.join(valid_dir, "__MACOSX") + if os.path.exists(macosx_dir): + os.system(f"rm -rf {macosx_dir}") + remove_non_image_files(valid_dir) + else: + train_dir = self.train_data + if self.valid_data: + valid_dir = self.valid_data + + preprocessor = VLMPreprocessor( + train_data=train_dir, + valid_data=valid_dir, + token=self.token, + project_name=self.project_name, + username=self.username, + local=self.local, + column_mapping=self.column_mapping, + ) + return preprocessor.prepare() + + @dataclass class AutoTrainImageRegressionDataset: train_data: str diff --git a/src/autotrain/parser.py b/src/autotrain/parser.py index b1f13861a5..d674525833 100644 --- a/src/autotrain/parser.py +++ b/src/autotrain/parser.py @@ -17,6 +17,7 @@ text_clf_munge_data, text_reg_munge_data, token_clf_munge_data, + vlm_munge_data, ) from autotrain.project import AutoTrainProject from autotrain.tasks import TASKS @@ -31,6 +32,7 @@ from autotrain.trainers.text_classification.params import TextClassificationParams from autotrain.trainers.text_regression.params import TextRegressionParams from autotrain.trainers.token_classification.params import TokenClassificationParams +from autotrain.trainers.vlm.params import VLMTrainingParams @dataclass @@ -62,6 +64,7 @@ def __post_init__(self): "text_token_classification": TokenClassificationParams, "sentence_transformers": SentenceTransformersParams, "image_single_column_regression": ImageRegressionParams, + "vlm": VLMTrainingParams, } self.munge_data_map = { "lm_training": llm_munge_data, @@ -75,6 +78,7 @@ def 
__post_init__(self): "text_single_column_regression": text_reg_munge_data, "sentence_transformers": sent_transformers_munge_data, "image_single_column_regression": img_reg_munge_data, + "vlm": vlm_munge_data, } self.task_aliases = { "llm": "lm_training", @@ -119,6 +123,8 @@ def __post_init__(self): "image_regression": "image_single_column_regression", "image-regression": "image_single_column_regression", "image-scoring": "image_single_column_regression", + "vlm:captioning": "vlm", + "vlm:vqa": "vlm", } task = self.config.get("task") self.task = self.task_aliases.get(task, task) @@ -159,6 +165,9 @@ def _parse_config(self): if self.task == "sentence_transformers": params["trainer"] = self.config["task"].split(":")[1] + if self.task == "vlm": + params["trainer"] = self.config["task"].split(":")[1] + if self.task != "dreambooth": for k, v in self.config["data"]["column_mapping"].items(): params[k] = v diff --git a/src/autotrain/preprocessor/vlm.py b/src/autotrain/preprocessor/vlm.py new file mode 100644 index 0000000000..d08bca0b15 --- /dev/null +++ b/src/autotrain/preprocessor/vlm.py @@ -0,0 +1,192 @@ +import os +import shutil +import uuid +from dataclasses import dataclass +from typing import Optional + +import pandas as pd +from datasets import Features, Image, Value, load_dataset +from sklearn.model_selection import train_test_split + + +ALLOWED_EXTENSIONS = ("jpeg", "png", "jpg", "JPG", "JPEG", "PNG") + + +@dataclass +class VLMPreprocessor: + train_data: str + username: str + project_name: str + token: str + column_mapping: dict + valid_data: Optional[str] = None + test_size: Optional[float] = 0.2 + seed: Optional[int] = 42 + local: Optional[bool] = False + + def _process_metadata(self, data_path): + metadata = pd.read_json(os.path.join(data_path, "metadata.jsonl"), lines=True) + # make sure that the metadata.jsonl file contains the required columns: file_name, objects + if "file_name" not in metadata.columns: + raise ValueError(f"{data_path}/metadata.jsonl should 
contain 'file_name' column.") + + col_names = list(self.column_mapping.values()) + + for col in col_names: + if col not in metadata.columns: + raise ValueError(f"{data_path}/metadata.jsonl should contain '{col}' column.") + + return metadata + + def __post_init__(self): + # Check if train data path exists + if not os.path.exists(self.train_data): + raise ValueError(f"{self.train_data} does not exist.") + + # check if self.train_data contains at least 5 image files in jpeg, png or jpg format only + train_image_files = [f for f in os.listdir(self.train_data) if f.endswith(ALLOWED_EXTENSIONS)] + if len(train_image_files) < 5: + raise ValueError(f"{self.train_data} should contain at least 5 jpeg, png or jpg files.") + + # check if self.train_data contains a metadata.jsonl file + if "metadata.jsonl" not in os.listdir(self.train_data): + raise ValueError(f"{self.train_data} should contain a metadata.jsonl file.") + + # Check if valid data path exists + if self.valid_data: + if not os.path.exists(self.valid_data): + raise ValueError(f"{self.valid_data} does not exist.") + + # check if self.valid_data contains at least 5 image files in jpeg, png or jpg format only + valid_image_files = [f for f in os.listdir(self.valid_data) if f.endswith(ALLOWED_EXTENSIONS)] + if len(valid_image_files) < 5: + raise ValueError(f"{self.valid_data} should contain at least 5 jpeg, png or jpg files.") + + # check if self.valid_data contains a metadata.jsonl file + if "metadata.jsonl" not in os.listdir(self.valid_data): + raise ValueError(f"{self.valid_data} should contain a metadata.jsonl file.") + + def split(self, df): + train_df, valid_df = train_test_split( + df, + test_size=self.test_size, + random_state=self.seed, + ) + train_df = train_df.reset_index(drop=True) + valid_df = valid_df.reset_index(drop=True) + return train_df, valid_df + + def prepare(self): + random_uuid = uuid.uuid4() + cache_dir = os.environ.get("HF_HOME") + if not cache_dir: + cache_dir = 
os.path.join(os.path.expanduser("~"), ".cache", "huggingface") + data_dir = os.path.join(cache_dir, "autotrain", str(random_uuid)) + + if self.valid_data: + shutil.copytree(self.train_data, os.path.join(data_dir, "train")) + shutil.copytree(self.valid_data, os.path.join(data_dir, "validation")) + + train_metadata = self._process_metadata(os.path.join(data_dir, "train")) + valid_metadata = self._process_metadata(os.path.join(data_dir, "validation")) + + train_metadata.to_json(os.path.join(data_dir, "train", "metadata.jsonl"), orient="records", lines=True) + valid_metadata.to_json( + os.path.join(data_dir, "validation", "metadata.jsonl"), orient="records", lines=True + ) + + features = Features( + { + "image": Image(), + } + ) + for _, col_map in self.column_mapping.items(): + features[col_map] = Value(dtype="string") + + dataset = load_dataset("imagefolder", data_dir=data_dir, features=features) + + rename_dict = { + "image": "autotrain_image", + } + for col, col_map in self.column_mapping.items(): + if col == "text_column": + rename_dict[col_map] = "autotrain_text" + elif col == "prompt_text_column": + rename_dict[col_map] = "autotrain_prompt" + + dataset = dataset.rename_columns(rename_dict) + + if self.local: + dataset.save_to_disk(f"{self.project_name}/autotrain-data") + else: + dataset.push_to_hub( + f"{self.username}/autotrain-data-{self.project_name}", + private=True, + token=self.token, + ) + else: + metadata = pd.read_json(os.path.join(self.train_data, "metadata.jsonl"), lines=True) + train_df, valid_df = self.split(metadata) + + # create train and validation folders + os.makedirs(os.path.join(data_dir, "train"), exist_ok=True) + os.makedirs(os.path.join(data_dir, "validation"), exist_ok=True) + + # move images to train and validation folders + for row in train_df.iterrows(): + shutil.copy( + os.path.join(self.train_data, row[1]["file_name"]), + os.path.join(data_dir, "train", row[1]["file_name"]), + ) + + for row in valid_df.iterrows(): + shutil.copy( + 
os.path.join(self.train_data, row[1]["file_name"]), + os.path.join(data_dir, "validation", row[1]["file_name"]), + ) + + # save metadata.jsonl file to train and validation folders + train_df.to_json(os.path.join(data_dir, "train", "metadata.jsonl"), orient="records", lines=True) + valid_df.to_json(os.path.join(data_dir, "validation", "metadata.jsonl"), orient="records", lines=True) + + train_metadata = self._process_metadata(os.path.join(data_dir, "train")) + valid_metadata = self._process_metadata(os.path.join(data_dir, "validation")) + + train_metadata.to_json(os.path.join(data_dir, "train", "metadata.jsonl"), orient="records", lines=True) + valid_metadata.to_json( + os.path.join(data_dir, "validation", "metadata.jsonl"), orient="records", lines=True + ) + + features = Features( + { + "image": Image(), + } + ) + for _, col_map in self.column_mapping.items(): + features[col_map] = Value(dtype="string") + + dataset = load_dataset("imagefolder", data_dir=data_dir, features=features) + + rename_dict = { + "image": "autotrain_image", + } + for col, col_map in self.column_mapping.items(): + if col == "text_column": + rename_dict[col_map] = "autotrain_text" + elif col == "prompt_text_column": + rename_dict[col_map] = "autotrain_prompt" + + dataset = dataset.rename_columns(rename_dict) + + if self.local: + dataset.save_to_disk(f"{self.project_name}/autotrain-data") + else: + dataset.push_to_hub( + f"{self.username}/autotrain-data-{self.project_name}", + private=True, + token=self.token, + ) + + if self.local: + return f"{self.project_name}/autotrain-data" + return f"{self.username}/autotrain-data-{self.project_name}" diff --git a/src/autotrain/tasks.py b/src/autotrain/tasks.py index 090b421990..5a9bc0d049 100644 --- a/src/autotrain/tasks.py +++ b/src/autotrain/tasks.py @@ -10,6 +10,7 @@ "lm_training": 9, "seq2seq": 28, # 27 is reserved for generic training "sentence_transformers": 30, + "vlm": 31, } VISION_TASKS = { diff --git a/src/autotrain/trainers/vlm/__init__.py 
def parse_args():
    """Parse CLI args: --training_config points at a JSON file of VLMTrainingParams."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--training_config", type=str, required=True)
    return parser.parse_args()


@monitor
def train(config):
    """Entry point: validate the model/trainer combination and dispatch training.

    Args:
        config: a VLMTrainingParams instance or a plain dict of its fields.

    Raises:
        ValueError: if the hub model's architecture is unsupported, or the
            trainer type is not one of "vqa" / "captioning".
    """
    if isinstance(config, dict):
        config = VLMTrainingParams(**config)

    if not utils.check_model_support(config):
        raise ValueError(f"model `{config.model}` not supported")

    if config.trainer in ("vqa", "captioning"):
        # imported lazily so the heavy training stack loads only when dispatched
        from autotrain.trainers.vlm.train_vlm_generic import train as train_generic

        train_generic(config)
    else:
        raise ValueError(f"trainer `{config.trainer}` not supported")


class VLMTrainingParams(AutoTrainParams):
    """Training parameters for vision-language-model (PaliGemma) fine-tuning."""

    model: str = Field("google/paligemma-3b-pt-224", title="Model name")
    project_name: str = Field("project-name", title="Output directory")

    # data params
    data_path: str = Field("data", title="Data path")
    train_split: str = Field("train", title="Train data config")
    valid_split: Optional[str] = Field(None, title="Validation data config")

    # trainer params
    trainer: str = Field("vqa", title="Trainer type")  # captioning, vqa, segmentation, detection
    log: str = Field("none", title="Logging using experiment tracking")
    disable_gradient_checkpointing: bool = Field(False, title="Gradient checkpointing")
    logging_steps: int = Field(-1, title="Logging steps")  # -1 = auto-derive (see utils.configure_logging_steps)
    eval_strategy: str = Field("epoch", title="Evaluation strategy")
    save_total_limit: int = Field(1, title="Save total limit")
    auto_find_batch_size: bool = Field(False, title="Auto find batch size")
    mixed_precision: Optional[str] = Field(None, title="fp16, bf16, or None")
    lr: float = Field(3e-5, title="Learning rate")
    epochs: int = Field(1, title="Number of training epochs")
    batch_size: int = Field(2, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation: int = Field(4, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")

    # peft
    quantization: Optional[str] = Field("int4", title="int4, int8, or None")
    target_modules: Optional[str] = Field("all-linear", title="Target modules")
    merge_adapter: bool = Field(False, title="Merge adapter")
    peft: bool = Field(False, title="Use PEFT")
    lora_r: int = Field(16, title="Lora r")
    lora_alpha: int = Field(32, title="Lora alpha")
    lora_dropout: float = Field(0.05, title="Lora dropout")

    # column mappings
    image_column: Optional[str] = Field("image", title="Image column")
    text_column: str = Field("text", title="Text (answer) column")
    prompt_text_column: Optional[str] = Field("prompt", title="Prompt (prefix) column")

    # push to hub
    push_to_hub: bool = Field(False, title="Push to hub")
    username: Optional[str] = Field(None, title="Hugging Face Username")
    token: Optional[str] = Field(None, title="Huggingface token")


if __name__ == "__main__":
    _args = parse_args()
    # context manager fixes the leaked file handle from json.load(open(...))
    with open(_args.training_config, encoding="utf-8") as f:
        training_config = json.load(f)
    _config = VLMTrainingParams(**training_config)
    train(_config)
def collate_fn(examples, config, processor):
    """Collate a batch of examples into processor-ready model inputs.

    Each prompt is prefixed with "answer " (task-style prefix — presumably the
    PaliGemma convention; confirm against the processor docs) and the target
    text is passed as `suffix` so the processor can construct labels.

    Args:
        examples: list of dataset rows; each row holds an image object (with a
            PIL-style .convert method) plus prompt/answer text columns named by
            `config`.
        config: training params providing prompt_text_column, text_column,
            image_column.
        processor: a PaliGemma-style processor callable.

    Returns:
        The processor output (batched tensors when return_tensors="pt").
    """
    prompts = ["answer " + example[config.prompt_text_column] for example in examples]
    labels = [example[config.text_column] for example in examples]
    images = [example[config.image_column].convert("RGB") for example in examples]
    tokens = processor(
        text=prompts,
        images=images,
        suffix=labels,
        return_tensors="pt",
        padding="longest",
        tokenize_newline_separately=False,
    )
    return tokens


def _load_split(config, split_name):
    """Load one dataset split.

    Handles the three source layouts the trainer supports:
    - a local AutoTrain data dir ("<project>/autotrain-data") saved to disk,
    - a hub dataset with an explicit config name ("config:split"),
    - a plain hub dataset split.
    """
    if config.data_path == f"{config.project_name}/autotrain-data":
        return load_from_disk(config.data_path)[split_name]
    if ":" in split_name:
        dataset_config_name, split = split_name.split(":")
        return load_dataset(
            config.data_path,
            name=dataset_config_name,
            split=split,
            token=config.token,
        )
    return load_dataset(
        config.data_path,
        split=split_name,
        token=config.token,
    )


def train(config):
    """Generic captioning/VQA fine-tuning loop for PaliGemma-style models."""
    train_data = _load_split(config, config.train_split)
    valid_data = _load_split(config, config.valid_split) if config.valid_split is not None else None

    logger.info(f"Train data: {train_data}")
    logger.info(f"Valid data: {valid_data}")

    # captioning datasets always expose the target prompt as "caption"
    if config.trainer == "captioning":
        config.prompt_text_column = "caption"

    processor = AutoProcessor.from_pretrained(config.model, token=config.token, trust_remote_code=ALLOW_REMOTE_CODE)

    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)

    args = TrainingArguments(**training_args)
    model = utils.get_model(config)

    logger.info("creating trainer")
    callbacks = utils.get_callbacks(config)

    col_fn = partial(collate_fn, config=config, processor=processor)

    trainer = Trainer(
        args=args,
        model=model,
        callbacks=callbacks,
        train_dataset=train_data,
        eval_dataset=valid_data,
        data_collator=col_fn,
    )
    # the default PrinterCallback duplicates our own loss logging
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)
# model -> target-module override table; empty means fall back to config values
TARGET_MODULES = {}

# architectures the VLM trainer knows how to load and fine-tune
SUPPORTED_MODELS = [
    "PaliGemmaForConditionalGeneration",
    # "Florence2ForConditionalGeneration", support later
]

MODEL_CARD = """
---
tags:
- autotrain
- text-generation-inference
- image-text-to-text
- text-generation{peft}
library_name: transformers{base_model}
license: other{dataset_tag}
---

# Model Trained Using AutoTrain

This model was trained using AutoTrain. For more information, please visit [AutoTrain](https://hf.co/docs/autotrain).

# Usage

```python
# you will need to adjust code if you didnt use peft

from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import requests
from peft import PeftModel

base_model_id = BASE_MODEL_ID
peft_model_id = THIS_MODEL_ID
max_new_tokens = 100
text = "Whats on the flower?"
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/bee.JPG?download=true"
image = Image.open(requests.get(img_url, stream=True).raw)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_id)
processor = PaliGemmaProcessor.from_pretrained(base_model_id)

model = PeftModel.from_pretrained(base_model, peft_model_id)
model.merge_and_unload()

model = model.eval().to(device)

inputs = processor(text=text, images=image, return_tensors="pt").to(device)
with torch.inference_mode():
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
result = processor.batch_decode(generated_ids, skip_special_tokens=True)
print(result)
```
"""


def get_target_modules(config):
    """Resolve the LoRA target modules for the configured model.

    Returns the per-model override from TARGET_MODULES when config gives
    nothing usable, the literal "all-linear" sentinel, or a list parsed from a
    comma-separated string.
    """
    if config.target_modules is None:
        return TARGET_MODULES.get(config.model)
    if config.target_modules.strip() == "":
        return TARGET_MODULES.get(config.model)
    if config.target_modules.strip().lower() == "all-linear":
        return "all-linear"
    return config.target_modules.split(",")


def create_model_card(config):
    """Render the README model card, tagging peft/dataset/base-model when known.

    Local paths (dataset dirs, local model dirs) are omitted from the card so
    no machine-specific paths leak into the hub.
    """
    if config.peft:
        peft = "\n- peft"
    else:
        peft = ""

    if config.data_path == f"{config.project_name}/autotrain-data" or os.path.isdir(config.data_path):
        dataset_tag = ""
    else:
        dataset_tag = f"\ndatasets:\n- {config.data_path}"

    if os.path.isdir(config.model):
        base_model = ""
    else:
        base_model = f"\nbase_model: {config.model}"

    model_card = MODEL_CARD.format(
        dataset_tag=dataset_tag,
        peft=peft,
        base_model=base_model,
    )
    return model_card.strip()


def check_model_support(config):
    """Return True if the hub model's architecture is one we can train.

    Note: ModelInfo.config can be None for models without a config on the hub;
    guard against that instead of raising AttributeError.
    """
    api = HfApi(token=config.token)
    model_info = api.model_info(config.model)
    architectures = (model_info.config or {}).get("architectures") or []
    return any(arch in SUPPORTED_MODELS for arch in architectures)


def configure_logging_steps(config, train_data, valid_data):
    """Derive the logging interval when config.logging_steps == -1.

    Targets ~20% of an epoch's batches (validation split when present),
    clamped to [1, 25]; the derived value is written back onto config.
    """
    logger.info("configuring logging steps")
    if config.logging_steps == -1:
        if config.valid_split is not None:
            logging_steps = int(0.2 * len(valid_data) / config.batch_size)
        else:
            logging_steps = int(0.2 * len(train_data) / config.batch_size)
        if logging_steps == 0:
            logging_steps = 1
        if logging_steps > 25:
            logging_steps = 25
        config.logging_steps = logging_steps
    else:
        logging_steps = config.logging_steps
    logger.info(f"Logging steps: {logging_steps}")
    return logging_steps


def configure_training_args(config, logging_steps):
    """Build the kwargs dict for transformers.TrainingArguments."""
    logger.info("configuring training args")
    training_args = dict(
        output_dir=config.project_name,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        learning_rate=config.lr,
        num_train_epochs=config.epochs,
        # evaluation/saving only make sense when a validation split exists
        eval_strategy=config.eval_strategy if config.valid_split is not None else "no",
        logging_steps=logging_steps,
        save_total_limit=config.save_total_limit,
        save_strategy=config.eval_strategy if config.valid_split is not None else "no",
        gradient_accumulation_steps=config.gradient_accumulation,
        report_to=config.log,
        auto_find_batch_size=config.auto_find_batch_size,
        lr_scheduler_type=config.scheduler,
        optim=config.optimizer,
        warmup_ratio=config.warmup_ratio,
        weight_decay=config.weight_decay,
        max_grad_norm=config.max_grad_norm,
        push_to_hub=False,
        load_best_model_at_end=True if config.valid_split is not None else False,
        ddp_find_unused_parameters=False,
        gradient_checkpointing=not config.disable_gradient_checkpointing,
        remove_unused_columns=False,
    )

    if not config.disable_gradient_checkpointing:
        # quantized peft models need reentrant checkpointing; others do not
        if config.peft and config.quantization in ("int4", "int8"):
            training_args["gradient_checkpointing_kwargs"] = {"use_reentrant": True}
        else:
            training_args["gradient_checkpointing_kwargs"] = {"use_reentrant": False}

    if config.mixed_precision == "fp16":
        training_args["fp16"] = True
    if config.mixed_precision == "bf16":
        training_args["bf16"] = True

    return training_args


def get_callbacks(config):
    """Return the standard AutoTrain callback set (log upload, loss logging, start marker)."""
    callbacks = [UploadLogs(config=config), LossLoggingCallback(), TrainStartCallback()]
    return callbacks


def get_model(config):
    """Load the PaliGemma model, optionally quantized and wrapped with LoRA.

    The vision tower and multimodal projector are always frozen — only the
    language-model side is trained.
    """
    logger.info("loading model config...")
    model_config = AutoConfig.from_pretrained(
        config.model,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
        # cache is incompatible with gradient checkpointing, so enable it only
        # when checkpointing is disabled
        use_cache=config.disable_gradient_checkpointing,
    )

    logger.info("loading model...")
    if config.peft:
        if config.quantization == "int4":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=False,
            )
        elif config.quantization == "int8":
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            bnb_config = None

        model = PaliGemmaForConditionalGeneration.from_pretrained(
            config.model,
            config=model_config,
            token=config.token,
            quantization_config=bnb_config,
            trust_remote_code=ALLOW_REMOTE_CODE,
        )
    else:
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            config.model,
            config=model_config,
            token=config.token,
            trust_remote_code=ALLOW_REMOTE_CODE,
        )

    logger.info(f"model dtype: {model.dtype}")

    if config.peft:
        logger.info("preparing peft model...")
        if config.quantization is not None:
            gradient_checkpointing_kwargs = {}
            if not config.disable_gradient_checkpointing:
                if config.quantization in ("int4", "int8"):
                    gradient_checkpointing_kwargs = {"use_reentrant": True}
                else:
                    gradient_checkpointing_kwargs = {"use_reentrant": False}
            model = prepare_model_for_kbit_training(
                model,
                use_gradient_checkpointing=not config.disable_gradient_checkpointing,
                gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            )
        else:
            # needed so LoRA inputs keep gradients without k-bit preparation
            model.enable_input_require_grads()

        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=get_target_modules(config),
        )
        model = get_peft_model(model, peft_config)

    for param in model.vision_tower.parameters():
        param.requires_grad = False

    for param in model.multi_modal_projector.parameters():
        param.requires_grad = False

    return model


def merge_adapter(base_model_path, target_model_path, adapter_path):
    """Merge LoRA adapter weights into the base model and save the result."""
    logger.info("Loading adapter...")
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=ALLOW_REMOTE_CODE,
    )

    model = PeftModel.from_pretrained(model, adapter_path)
    model = model.merge_and_unload()

    logger.info("Saving target model...")
    model.save_pretrained(target_model_path)


def post_training_steps(config, trainer):
    """Save the trained model, write a model card, optionally merge the
    adapter, and push everything to the hub (rank 0 only)."""
    logger.info("Finished training, saving model...")
    trainer.model.config.use_cache = True
    trainer.save_model(config.project_name)

    model_card = create_model_card(config)

    # save model card to output directory as README.md
    with open(f"{config.project_name}/README.md", "w", encoding="utf-8") as f:
        f.write(model_card)

    if config.peft and config.merge_adapter:
        logger.info("Merging adapter weights...")
        try:
            # free the trainer's references before reloading the base model
            del trainer
            torch.cuda.empty_cache()
            merge_adapter(
                base_model_path=config.model,
                target_model_path=config.project_name,
                adapter_path=config.project_name,
            )
            # remove adapter weights: adapter_*
            for file in os.listdir(config.project_name):
                if file.startswith("adapter_"):
                    os.remove(f"{config.project_name}/{file}")
        except Exception as e:
            # best-effort: fall back to shipping the adapter weights alone
            logger.warning(f"Failed to merge adapter weights: {e}")
            logger.warning("Skipping adapter merge. Only adapter weights will be saved.")

    if config.push_to_hub:
        if PartialState().process_index == 0:
            # remove data folder
            remove_autotrain_data(config)
            logger.info("Pushing model to hub...")
            save_training_params(config)
            api = HfApi(token=config.token)
            api.create_repo(
                repo_id=f"{config.username}/{config.project_name}", repo_type="model", private=True, exist_ok=True
            )
            api.upload_folder(
                folder_path=config.project_name,
                repo_id=f"{config.username}/{config.project_name}",
                repo_type="model",
            )

    if PartialState().process_index == 0:
        pause_space(config)
31: + params = VLMTrainingParams(**params) else: raise NotImplementedError