diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 3e74d4086b..e51f2509cf 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -11,6 +11,8 @@ - sections: - local: quickstart_spaces title: Train on Spaces + - local: quickstart_py + title: Python SDK - local: quickstart title: Train Locally - local: config diff --git a/docs/source/faq.mdx b/docs/source/faq.mdx index e3f789a16e..e5130de439 100644 --- a/docs/source/faq.mdx +++ b/docs/source/faq.mdx @@ -16,6 +16,14 @@ You can safely remove the dataset from the Hub after training is complete. If uploaded, the dataset will be stored in your Hugging Face account as a private repository and will only be accessible by you and the training process. It is not used once the training is complete. +## My training space paused for no reason mid-training + +AutoTrain Training Spaces will pause themselves after training is done (or has failed). This is done to save resources and costs. +If your training failed, you can still see the space logs and find out what went wrong. Note: you won't be able to retrieve the logs if you restart the space. + +Another reason for the space to pause is the space's sleep time kicking in. If you have a long-running training job, you must set the sleep time to a much higher value. +The space will pause itself anyway after the training is done, thus saving you costs. + ## I get error `Your installed package nvidia-ml-py is corrupted. Skip patch functions` This error can be safely ignored. It is a warning from the `nvitop` library and does not affect the functionality of AutoTrain. diff --git a/docs/source/quickstart.mdx b/docs/source/quickstart.mdx index 51028cc54c..725ea67ff5 100644 --- a/docs/source/quickstart.mdx +++ b/docs/source/quickstart.mdx @@ -1,4 +1,4 @@ -# Quickstart +# Quickstart Guide for Local Training This quickstart is for local installation and usage. 
If you want to use AutoTrain on Hugging Face Spaces, please refer to the *AutoTrain on Hugging Face Spaces* section. diff --git a/docs/source/quickstart_py.mdx b/docs/source/quickstart_py.mdx new file mode 100644 index 0000000000..9bc2a5f1cd --- /dev/null +++ b/docs/source/quickstart_py.mdx @@ -0,0 +1,111 @@ +# Quickstart with Python + +AutoTrain is a library that allows you to train state of the art models on Hugging Face Spaces, or locally. +It provides a simple and easy-to-use interface to train models for various tasks like llm finetuning, text classification, +image classification, object detection, and more. + +In this quickstart guide, we will show you how to train a model using AutoTrain in Python. + +## Getting Started + +AutoTrain can be installed using pip: + +```bash +$ pip install autotrain-advanced +``` + +The example code below shows how to finetune an LLM model using AutoTrain in Python: + +```python +import os + +from autotrain.params import LLMTrainingParams +from autotrain.project import AutoTrainProject + + +params = LLMTrainingParams( + model="meta-llama/Llama-3.2-1B-Instruct", + data_path="HuggingFaceH4/no_robots", + chat_template="tokenizer", + text_column="messages", + train_split="train", + trainer="sft", + epochs=3, + batch_size=1, + lr=1e-5, + peft=True, + quantization="int4", + target_modules="all-linear", + padding="right", + optimizer="paged_adamw_8bit", + scheduler="cosine", + gradient_accumulation=8, + mixed_precision="bf16", + merge_adapter=True, + project_name="autotrain-llama32-1b-finetune", + log="tensorboard", + push_to_hub=True, + username=os.environ.get("HF_USERNAME"), + token=os.environ.get("HF_TOKEN"), +) + + +backend = "local" +project = AutoTrainProject(params=params, backend=backend, process=True) +project.create() +``` + +In this example, we are finetuning the `meta-llama/Llama-3.2-1B-Instruct` model on the `HuggingFaceH4/no_robots` dataset. 
+We are training the model for 3 epochs with a batch size of 1 and a learning rate of `1e-5`. +We are using the `paged_adamw_8bit` optimizer and the `cosine` scheduler. +We are also using mixed precision training with a gradient accumulation of 8. +The final model will be pushed to the Hugging Face Hub after training. + +To train the model, run the following command: + +```bash +$ export HF_USERNAME= +$ export HF_TOKEN= +$ python train.py +``` + +This will create a new project directory with the name `autotrain-llama32-1b-finetune` and start the training process. +Once the training is complete, the model will be pushed to the Hugging Face Hub. + +Your HF_TOKEN and HF_USERNAME are only required if you want to push the model or if you are accessing a gated model or dataset. + +## AutoTrainProject Class + +[[autodoc]] project.AutoTrainProject + +## Parameters + +### Text Tasks + +[[autodoc]] trainers.clm.params.LLMTrainingParams + +[[autodoc]] trainers.sent_transformers.params.SentenceTransformersParams + +[[autodoc]] trainers.seq2seq.params.Seq2SeqParams + +[[autodoc]] trainers.token_classification.params.TokenClassificationParams + +[[autodoc]] trainers.extractive_question_answering.params.ExtractiveQuestionAnsweringParams + +[[autodoc]] trainers.text_classification.params.TextClassificationParams + +[[autodoc]] trainers.text_regression.params.TextRegressionParams + +### Image Tasks + +[[autodoc]] trainers.image_classification.params.ImageClassificationParams + +[[autodoc]] trainers.image_regression.params.ImageRegressionParams + +[[autodoc]] trainers.object_detection.params.ObjectDetectionParams + +[[autodoc]] trainers.dreambooth.params.DreamBoothTrainingParams + +### Tabular Tasks + +[[autodoc]] trainers.tabular.params.TabularParams \ No newline at end of file diff --git a/docs/source/tasks/sentence_transformer.mdx b/docs/source/tasks/sentence_transformer.mdx index d007ea0835..ba7f68af01 100644 --- a/docs/source/tasks/sentence_transformer.mdx +++ 
b/docs/source/tasks/sentence_transformer.mdx @@ -68,3 +68,8 @@ For `qa` training, the data should be in the following format: | how are you | I am fine | | What is your name? | My name is Abhishek | | Which is the best programming language? | Python | + + +## Parameters + +[[autodoc]] trainers.sent_transformers.params.SentenceTransformersParams diff --git a/notebooks/python_example.py b/notebooks/python_example.py new file mode 100644 index 0000000000..59cde521f7 --- /dev/null +++ b/notebooks/python_example.py @@ -0,0 +1,36 @@ +import os + +from autotrain.params import LLMTrainingParams +from autotrain.project import AutoTrainProject + + +params = LLMTrainingParams( + model="meta-llama/Llama-3.2-1B-Instruct", + data_path="HuggingFaceH4/no_robots", + chat_template="tokenizer", + text_column="messages", + train_split="train", + trainer="sft", + epochs=3, + batch_size=1, + lr=1e-5, + peft=True, + quantization="int4", + target_modules="all-linear", + padding="right", + optimizer="paged_adamw_8bit", + scheduler="cosine", + gradient_accumulation=8, + mixed_precision="bf16", + merge_adapter=True, + project_name="autotrain-llama32-1b-finetune", + log="tensorboard", + push_to_hub=False, + username=os.environ.get("HF_USERNAME"), + token=os.environ.get("HF_TOKEN"), +) + + +backend = "local" +project = AutoTrainProject(params=params, backend=backend, process=True) +project.create() diff --git a/setup.cfg b/setup.cfg index 43a79d51da..b0887d9e12 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,6 +17,7 @@ max-line-length = 119 per-file-ignores = # imported but unused __init__.py: F401, E402 + src/autotrain/params.py: F401 exclude = .git, .venv, diff --git a/src/autotrain/cli/run_dreambooth.py b/src/autotrain/cli/run_dreambooth.py index 91a6b23531..31589d4b06 100644 --- a/src/autotrain/cli/run_dreambooth.py +++ b/src/autotrain/cli/run_dreambooth.py @@ -4,7 +4,7 @@ from autotrain import logger from autotrain.cli import BaseAutoTrainCommand -from autotrain.cli.utils import 
common_args, dreambooth_munge_data +from autotrain.cli.utils import common_args from autotrain.project import AutoTrainProject from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams from autotrain.trainers.dreambooth.utils import VALID_IMAGE_EXTENSIONS, XL_MODELS @@ -387,7 +387,6 @@ def __init__(self, args): def run(self): logger.info("Running DreamBooth Training") params = DreamBoothTrainingParams(**vars(self.args)) - params = dreambooth_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_extractive_qa.py b/src/autotrain/cli/run_extractive_qa.py index 1dbe2410d6..6062fbb345 100644 --- a/src/autotrain/cli/run_extractive_qa.py +++ b/src/autotrain/cli/run_extractive_qa.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from autotrain import logger -from autotrain.cli.utils import ext_qa_munge_data, get_field_info +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams @@ -100,7 +100,6 @@ def run(self): logger.info("Running Extractive Question Answering") if self.args.train: params = ExtractiveQuestionAnsweringParams(**vars(self.args)) - params = ext_qa_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_image_classification.py b/src/autotrain/cli/run_image_classification.py index 5852a4d1f9..64430f7a2a 100644 --- a/src/autotrain/cli/run_image_classification.py +++ 
b/src/autotrain/cli/run_image_classification.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from autotrain import logger -from autotrain.cli.utils import get_field_info, img_clf_munge_data +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.image_classification.params import ImageClassificationParams @@ -108,7 +108,6 @@ def run(self): logger.info("Running Image Classification") if self.args.train: params = ImageClassificationParams(**vars(self.args)) - params = img_clf_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_image_regression.py b/src/autotrain/cli/run_image_regression.py index 0f039dee3d..713a5b2c59 100644 --- a/src/autotrain/cli/run_image_regression.py +++ b/src/autotrain/cli/run_image_regression.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from autotrain import logger -from autotrain.cli.utils import get_field_info, img_reg_munge_data +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.image_regression.params import ImageRegressionParams @@ -108,7 +108,6 @@ def run(self): logger.info("Running Image Regression") if self.args.train: params = ImageRegressionParams(**vars(self.args)) - params = img_reg_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_llm.py b/src/autotrain/cli/run_llm.py index 6e1ba426f3..c2d3236cb7 100644 --- a/src/autotrain/cli/run_llm.py +++ 
b/src/autotrain/cli/run_llm.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from autotrain import logger -from autotrain.cli.utils import get_field_info, llm_munge_data +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.clm.params import LLMTrainingParams @@ -136,7 +136,6 @@ def run(self): logger.info("Running LLM") if self.args.train: params = LLMTrainingParams(**vars(self.args)) - params = llm_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_object_detection.py b/src/autotrain/cli/run_object_detection.py index 3415fba64d..3fd63fbc4a 100644 --- a/src/autotrain/cli/run_object_detection.py +++ b/src/autotrain/cli/run_object_detection.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from autotrain import logger -from autotrain.cli.utils import get_field_info, img_obj_detect_munge_data +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.object_detection.params import ObjectDetectionParams @@ -108,7 +108,6 @@ def run(self): logger.info("Running Object Detection") if self.args.train: params = ObjectDetectionParams(**vars(self.args)) - params = img_obj_detect_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_sent_tranformers.py b/src/autotrain/cli/run_sent_tranformers.py index 2bceda6e85..a6858dee2f 100644 --- a/src/autotrain/cli/run_sent_tranformers.py +++ b/src/autotrain/cli/run_sent_tranformers.py @@ 
-1,7 +1,7 @@ from argparse import ArgumentParser from autotrain import logger -from autotrain.cli.utils import get_field_info, sent_transformers_munge_data +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.sent_transformers.params import SentenceTransformersParams @@ -108,7 +108,6 @@ def run(self): logger.info("Running Sentence Transformers...") if self.args.train: params = SentenceTransformersParams(**vars(self.args)) - params = sent_transformers_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_seq2seq.py b/src/autotrain/cli/run_seq2seq.py index 13120ceda2..0a7aaef0d5 100644 --- a/src/autotrain/cli/run_seq2seq.py +++ b/src/autotrain/cli/run_seq2seq.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from autotrain import logger -from autotrain.cli.utils import get_field_info, seq2seq_munge_data +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.seq2seq.params import Seq2SeqParams @@ -92,7 +92,6 @@ def run(self): logger.info("Running Seq2Seq Classification") if self.args.train: params = Seq2SeqParams(**vars(self.args)) - params = seq2seq_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_tabular.py b/src/autotrain/cli/run_tabular.py index 06f0ef9c1e..8b1b72ee8e 100644 --- a/src/autotrain/cli/run_tabular.py +++ b/src/autotrain/cli/run_tabular.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from autotrain 
import logger -from autotrain.cli.utils import get_field_info, tabular_munge_data +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.tabular.params import TabularParams @@ -101,7 +101,6 @@ def run(self): logger.info("Running Tabular Training") if self.args.train: params = TabularParams(**vars(self.args)) - params = tabular_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_text_classification.py b/src/autotrain/cli/run_text_classification.py index 9ae83ef6c7..79a1a6f4af 100644 --- a/src/autotrain/cli/run_text_classification.py +++ b/src/autotrain/cli/run_text_classification.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from autotrain import logger -from autotrain.cli.utils import get_field_info, text_clf_munge_data +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.text_classification.params import TextClassificationParams @@ -101,7 +101,6 @@ def run(self): logger.info("Running Text Classification") if self.args.train: params = TextClassificationParams(**vars(self.args)) - params = text_clf_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_text_regression.py b/src/autotrain/cli/run_text_regression.py index 0d0151e4ba..a49c5ec070 100644 --- a/src/autotrain/cli/run_text_regression.py +++ b/src/autotrain/cli/run_text_regression.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from autotrain import 
logger -from autotrain.cli.utils import get_field_info, text_reg_munge_data +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.text_regression.params import TextRegressionParams @@ -101,7 +101,6 @@ def run(self): logger.info("Running Text Regression") if self.args.train: params = TextRegressionParams(**vars(self.args)) - params = text_reg_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_token_classification.py b/src/autotrain/cli/run_token_classification.py index 5663f48eaa..15f5cb2438 100644 --- a/src/autotrain/cli/run_token_classification.py +++ b/src/autotrain/cli/run_token_classification.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from autotrain import logger -from autotrain.cli.utils import get_field_info, token_clf_munge_data +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.token_classification.params import TokenClassificationParams @@ -101,7 +101,6 @@ def run(self): logger.info("Running Token Classification") if self.args.train: params = TokenClassificationParams(**vars(self.args)) - params = token_clf_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/run_vlm.py b/src/autotrain/cli/run_vlm.py index 39a7f5f7f5..5f7a93e28d 100644 --- a/src/autotrain/cli/run_vlm.py +++ b/src/autotrain/cli/run_vlm.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from autotrain import logger -from 
autotrain.cli.utils import get_field_info, vlm_munge_data +from autotrain.cli.utils import get_field_info from autotrain.project import AutoTrainProject from autotrain.trainers.vlm.params import VLMTrainingParams @@ -106,7 +106,6 @@ def run(self): logger.info("Running Image Regression") if self.args.train: params = VLMTrainingParams(**vars(self.args)) - params = vlm_munge_data(params, local=self.args.backend.startswith("local")) - project = AutoTrainProject(params=params, backend=self.args.backend) + project = AutoTrainProject(params=params, backend=self.args.backend, process=True) job_id = project.create() logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/cli/utils.py b/src/autotrain/cli/utils.py index 8d51b7f65c..d95cbb8860 100644 --- a/src/autotrain/cli/utils.py +++ b/src/autotrain/cli/utils.py @@ -1,15 +1,6 @@ -import os from typing import Any, Type from autotrain.backends.base import AVAILABLE_HARDWARE -from autotrain.dataset import ( - AutoTrainDataset, - AutoTrainDreamboothDataset, - AutoTrainImageClassificationDataset, - AutoTrainImageRegressionDataset, - AutoTrainObjectDetectionDataset, - AutoTrainVLMDataset, -) def common_args(): @@ -185,426 +176,3 @@ def get_field_info(params_class): field_info.append(temp_info) return field_info - - -def tabular_munge_data(params, local): - if isinstance(params.target_columns, str): - col_map_label = [params.target_columns] - else: - col_map_label = params.target_columns - task = params.task - if task == "classification" and len(col_map_label) > 1: - task = "tabular_multi_label_classification" - elif task == "classification" and len(col_map_label) == 1: - task = "tabular_multi_class_classification" - elif task == "regression" and len(col_map_label) > 1: - task = "tabular_multi_column_regression" - elif task == "regression" and len(col_map_label) == 1: - task = "tabular_single_column_regression" - else: - raise Exception("Please select a valid task.") - - exts = ["csv", "jsonl"] - ext_to_use = None - for ext in 
exts: - path = f"{params.data_path}/{params.train_split}.{ext}" - if os.path.exists(path): - ext_to_use = ext - break - - train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" - if params.valid_split is not None: - valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" - else: - valid_data_path = None - if os.path.exists(train_data_path): - dset = AutoTrainDataset( - train_data=[train_data_path], - task=task, - token=params.token, - project_name=params.project_name, - username=params.username, - column_mapping={"id": params.id_column, "label": col_map_label}, - valid_data=[valid_data_path] if valid_data_path is not None else None, - percent_valid=None, # TODO: add to UI - local=local, - ext=ext_to_use, - ) - params.data_path = dset.prepare() - params.valid_split = "validation" - params.id_column = "autotrain_id" - if len(col_map_label) == 1: - params.target_columns = ["autotrain_label"] - else: - params.target_columns = [f"autotrain_label_{i}" for i in range(len(col_map_label))] - return params - - -def llm_munge_data(params, local): - exts = ["csv", "jsonl"] - ext_to_use = None - for ext in exts: - path = f"{params.data_path}/{params.train_split}.{ext}" - if os.path.exists(path): - ext_to_use = ext - break - - train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" - if params.valid_split is not None: - valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" - else: - valid_data_path = None - if os.path.exists(train_data_path): - col_map = {"text": params.text_column} - if params.rejected_text_column is not None: - col_map["rejected_text"] = params.rejected_text_column - if params.prompt_text_column is not None: - col_map["prompt"] = params.prompt_text_column - dset = AutoTrainDataset( - train_data=[train_data_path], - task="lm_training", - token=params.token, - project_name=params.project_name, - username=params.username, - column_mapping=col_map, - valid_data=[valid_data_path] if 
valid_data_path is not None else None, - percent_valid=None, # TODO: add to UI - local=local, - ext=ext_to_use, - ) - params.data_path = dset.prepare() - params.valid_split = None - params.text_column = "autotrain_text" - params.rejected_text_column = "autotrain_rejected_text" - params.prompt_text_column = "autotrain_prompt" - return params - - -def seq2seq_munge_data(params, local): - exts = ["csv", "jsonl"] - ext_to_use = None - for ext in exts: - path = f"{params.data_path}/{params.train_split}.{ext}" - if os.path.exists(path): - ext_to_use = ext - break - - train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" - if params.valid_split is not None: - valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" - else: - valid_data_path = None - if os.path.exists(train_data_path): - dset = AutoTrainDataset( - train_data=[train_data_path], - task="seq2seq", - token=params.token, - project_name=params.project_name, - username=params.username, - column_mapping={"text": params.text_column, "label": params.target_column}, - valid_data=[valid_data_path] if valid_data_path is not None else None, - percent_valid=None, # TODO: add to UI - local=local, - ext=ext_to_use, - ) - params.data_path = dset.prepare() - params.valid_split = "validation" - params.text_column = "autotrain_text" - params.target_column = "autotrain_label" - return params - - -def text_clf_munge_data(params, local): - exts = ["csv", "jsonl"] - ext_to_use = None - for ext in exts: - path = f"{params.data_path}/{params.train_split}.{ext}" - if os.path.exists(path): - ext_to_use = ext - break - - train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" - if params.valid_split is not None: - valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" - else: - valid_data_path = None - if os.path.exists(train_data_path): - dset = AutoTrainDataset( - train_data=[train_data_path], - valid_data=[valid_data_path] if valid_data_path is not None else 
None, - task="text_multi_class_classification", - token=params.token, - project_name=params.project_name, - username=params.username, - column_mapping={"text": params.text_column, "label": params.target_column}, - percent_valid=None, # TODO: add to UI - local=local, - convert_to_class_label=True, - ext=ext_to_use, - ) - params.data_path = dset.prepare() - params.valid_split = "validation" - params.text_column = "autotrain_text" - params.target_column = "autotrain_label" - return params - - -def text_reg_munge_data(params, local): - exts = ["csv", "jsonl"] - ext_to_use = None - for ext in exts: - path = f"{params.data_path}/{params.train_split}.{ext}" - if os.path.exists(path): - ext_to_use = ext - break - - train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" - if params.valid_split is not None: - valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" - else: - valid_data_path = None - if os.path.exists(train_data_path): - dset = AutoTrainDataset( - train_data=[train_data_path], - valid_data=[valid_data_path] if valid_data_path is not None else None, - task="text_single_column_regression", - token=params.token, - project_name=params.project_name, - username=params.username, - column_mapping={"text": params.text_column, "label": params.target_column}, - percent_valid=None, # TODO: add to UI - local=local, - convert_to_class_label=False, - ext=ext_to_use, - ) - params.data_path = dset.prepare() - params.valid_split = "validation" - params.text_column = "autotrain_text" - params.target_column = "autotrain_label" - return params - - -def token_clf_munge_data(params, local): - exts = ["csv", "jsonl"] - ext_to_use = None - for ext in exts: - path = f"{params.data_path}/{params.train_split}.{ext}" - if os.path.exists(path): - ext_to_use = ext - break - - train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" - if params.valid_split is not None: - valid_data_path = 
f"{params.data_path}/{params.valid_split}.{ext_to_use}" - else: - valid_data_path = None - if os.path.exists(train_data_path): - dset = AutoTrainDataset( - train_data=[train_data_path], - valid_data=[valid_data_path] if valid_data_path is not None else None, - task="text_token_classification", - token=params.token, - project_name=params.project_name, - username=params.username, - column_mapping={"text": params.tokens_column, "label": params.tags_column}, - percent_valid=None, # TODO: add to UI - local=local, - convert_to_class_label=True, - ext=ext_to_use, - ) - params.data_path = dset.prepare() - params.valid_split = "validation" - params.text_column = "autotrain_text" - params.target_column = "autotrain_label" - return params - - -def img_clf_munge_data(params, local): - train_data_path = f"{params.data_path}/{params.train_split}" - if params.valid_split is not None: - valid_data_path = f"{params.data_path}/{params.valid_split}" - else: - valid_data_path = None - if os.path.isdir(train_data_path): - dset = AutoTrainImageClassificationDataset( - train_data=train_data_path, - valid_data=valid_data_path, - token=params.token, - project_name=params.project_name, - username=params.username, - local=local, - ) - params.data_path = dset.prepare() - params.valid_split = "validation" - params.image_column = "autotrain_image" - params.target_column = "autotrain_label" - return params - - -def dreambooth_munge_data(params, local): - # check if params.image_path is a directory - if os.path.isdir(params.image_path): - training_data = [os.path.join(params.image_path, f) for f in os.listdir(params.image_path)] - dset = AutoTrainDreamboothDataset( - concept_images=training_data, - concept_name=params.prompt, - token=params.token, - project_name=params.project_name, - username=params.username, - local=local, - ) - params.image_path = dset.prepare() - return params - - -def img_obj_detect_munge_data(params, local): - train_data_path = f"{params.data_path}/{params.train_split}" - 
if params.valid_split is not None: - valid_data_path = f"{params.data_path}/{params.valid_split}" - else: - valid_data_path = None - if os.path.isdir(train_data_path): - dset = AutoTrainObjectDetectionDataset( - train_data=train_data_path, - valid_data=valid_data_path, - token=params.token, - project_name=params.project_name, - username=params.username, - local=local, - ) - params.data_path = dset.prepare() - params.valid_split = "validation" - params.image_column = "autotrain_image" - params.objects_column = "autotrain_objects" - return params - - -def sent_transformers_munge_data(params, local): - exts = ["csv", "jsonl"] - ext_to_use = None - for ext in exts: - path = f"{params.data_path}/{params.train_split}.{ext}" - if os.path.exists(path): - ext_to_use = ext - break - - train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" - if params.valid_split is not None: - valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" - else: - valid_data_path = None - if os.path.exists(train_data_path): - dset = AutoTrainDataset( - train_data=[train_data_path], - valid_data=[valid_data_path] if valid_data_path is not None else None, - task="sentence_transformers", - token=params.token, - project_name=params.project_name, - username=params.username, - column_mapping={ - "sentence1": params.sentence1_column, - "sentence2": params.sentence2_column, - "sentence3": params.sentence3_column, - "target": params.target_column, - }, - percent_valid=None, # TODO: add to UI - local=local, - convert_to_class_label=True if params.trainer == "pair_class" else False, - ext=ext_to_use, - ) - params.data_path = dset.prepare() - params.valid_split = "validation" - params.sentence1_column = "autotrain_sentence1" - params.sentence2_column = "autotrain_sentence2" - params.sentence3_column = "autotrain_sentence3" - params.target_column = "autotrain_target" - return params - - -def img_reg_munge_data(params, local): - train_data_path = 
f"{params.data_path}/{params.train_split}" - if params.valid_split is not None: - valid_data_path = f"{params.data_path}/{params.valid_split}" - else: - valid_data_path = None - if os.path.isdir(train_data_path): - dset = AutoTrainImageRegressionDataset( - train_data=train_data_path, - valid_data=valid_data_path, - token=params.token, - project_name=params.project_name, - username=params.username, - local=local, - ) - params.data_path = dset.prepare() - params.valid_split = "validation" - params.image_column = "autotrain_image" - params.target_column = "autotrain_label" - return params - - -def vlm_munge_data(params, local): - train_data_path = f"{params.data_path}/{params.train_split}" - if params.valid_split is not None: - valid_data_path = f"{params.data_path}/{params.valid_split}" - else: - valid_data_path = None - if os.path.exists(train_data_path): - col_map = {"text": params.text_column} - if params.prompt_text_column is not None: - col_map["prompt"] = params.prompt_text_column - dset = AutoTrainVLMDataset( - train_data=train_data_path, - token=params.token, - project_name=params.project_name, - username=params.username, - column_mapping=col_map, - valid_data=valid_data_path if valid_data_path is not None else None, - percent_valid=None, # TODO: add to UI - local=local, - ) - params.data_path = dset.prepare() - params.text_column = "autotrain_text" - params.image_column = "autotrain_image" - params.prompt_text_column = "autotrain_prompt" - return params - - -def ext_qa_munge_data(params, local): - exts = ["csv", "jsonl"] - ext_to_use = None - for ext in exts: - path = f"{params.data_path}/{params.train_split}.{ext}" - if os.path.exists(path): - ext_to_use = ext - break - - train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" - if params.valid_split is not None: - valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" - else: - valid_data_path = None - if os.path.exists(train_data_path): - dset = AutoTrainDataset( - 
train_data=[train_data_path], - valid_data=[valid_data_path] if valid_data_path is not None else None, - task="text_extractive_question_answering", - token=params.token, - project_name=params.project_name, - username=params.username, - column_mapping={ - "text": params.text_column, - "question": params.question_column, - "answer": params.answer_column, - }, - percent_valid=None, # TODO: add to UI - local=local, - convert_to_class_label=True, - ext=ext_to_use, - ) - params.data_path = dset.prepare() - params.valid_split = "validation" - params.text_column = "autotrain_text" - params.question_column = "autotrain_question" - params.answer_column = "autotrain_answer" - return params diff --git a/src/autotrain/params.py b/src/autotrain/params.py new file mode 100644 index 0000000000..adf49bb1de --- /dev/null +++ b/src/autotrain/params.py @@ -0,0 +1,13 @@ +from autotrain.trainers.clm.params import LLMTrainingParams +from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams +from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams +from autotrain.trainers.image_classification.params import ImageClassificationParams +from autotrain.trainers.image_regression.params import ImageRegressionParams +from autotrain.trainers.object_detection.params import ObjectDetectionParams +from autotrain.trainers.sent_transformers.params import SentenceTransformersParams +from autotrain.trainers.seq2seq.params import Seq2SeqParams +from autotrain.trainers.tabular.params import TabularParams +from autotrain.trainers.text_classification.params import TextClassificationParams +from autotrain.trainers.text_regression.params import TextRegressionParams +from autotrain.trainers.token_classification.params import TokenClassificationParams +from autotrain.trainers.vlm.params import VLMTrainingParams diff --git a/src/autotrain/parser.py b/src/autotrain/parser.py index f5e7a56486..f1f7e256d6 100644 --- a/src/autotrain/parser.py +++ 
b/src/autotrain/parser.py @@ -5,7 +5,8 @@ import yaml from autotrain import logger -from autotrain.cli.utils import ( +from autotrain.project import ( + AutoTrainProject, dreambooth_munge_data, ext_qa_munge_data, img_clf_munge_data, @@ -20,7 +21,6 @@ token_clf_munge_data, vlm_munge_data, ) -from autotrain.project import AutoTrainProject from autotrain.tasks import TASKS from autotrain.trainers.clm.params import LLMTrainingParams from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams diff --git a/src/autotrain/project.py b/src/autotrain/project.py index de65f153e1..121c8fd1e2 100644 --- a/src/autotrain/project.py +++ b/src/autotrain/project.py @@ -2,8 +2,9 @@ Copyright 2023 The HuggingFace Team """ +import os from dataclasses import dataclass -from typing import List, Union +from typing import Union from autotrain.backends.base import AVAILABLE_HARDWARE from autotrain.backends.endpoints import EndpointsRunner @@ -11,8 +12,17 @@ from autotrain.backends.ngc import NGCRunner from autotrain.backends.nvcf import NVCFRunner from autotrain.backends.spaces import SpaceRunner +from autotrain.dataset import ( + AutoTrainDataset, + AutoTrainDreamboothDataset, + AutoTrainImageClassificationDataset, + AutoTrainImageRegressionDataset, + AutoTrainObjectDetectionDataset, + AutoTrainVLMDataset, +) from autotrain.trainers.clm.params import LLMTrainingParams from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams +from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams from autotrain.trainers.image_classification.params import ImageClassificationParams from autotrain.trainers.image_regression.params import ImageRegressionParams from autotrain.trainers.object_detection.params import ObjectDetectionParams @@ -22,19 +32,474 @@ from autotrain.trainers.text_classification.params import TextClassificationParams from autotrain.trainers.text_regression.params import TextRegressionParams from 
autotrain.trainers.token_classification.params import TokenClassificationParams +from autotrain.trainers.vlm.params import VLMTrainingParams + + +def tabular_munge_data(params, local): + if isinstance(params.target_columns, str): + col_map_label = [params.target_columns] + else: + col_map_label = params.target_columns + task = params.task + if task == "classification" and len(col_map_label) > 1: + task = "tabular_multi_label_classification" + elif task == "classification" and len(col_map_label) == 1: + task = "tabular_multi_class_classification" + elif task == "regression" and len(col_map_label) > 1: + task = "tabular_multi_column_regression" + elif task == "regression" and len(col_map_label) == 1: + task = "tabular_single_column_regression" + else: + raise Exception("Please select a valid task.") + + exts = ["csv", "jsonl"] + ext_to_use = None + for ext in exts: + path = f"{params.data_path}/{params.train_split}.{ext}" + if os.path.exists(path): + ext_to_use = ext + break + + train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" + else: + valid_data_path = None + if os.path.exists(train_data_path): + dset = AutoTrainDataset( + train_data=[train_data_path], + task=task, + token=params.token, + project_name=params.project_name, + username=params.username, + column_mapping={"id": params.id_column, "label": col_map_label}, + valid_data=[valid_data_path] if valid_data_path is not None else None, + percent_valid=None, # TODO: add to UI + local=local, + ext=ext_to_use, + ) + params.data_path = dset.prepare() + params.valid_split = "validation" + params.id_column = "autotrain_id" + if len(col_map_label) == 1: + params.target_columns = ["autotrain_label"] + else: + params.target_columns = [f"autotrain_label_{i}" for i in range(len(col_map_label))] + return params + + +def llm_munge_data(params, local): + exts = ["csv", "jsonl"] + ext_to_use = 
None + for ext in exts: + path = f"{params.data_path}/{params.train_split}.{ext}" + if os.path.exists(path): + ext_to_use = ext + break + + train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" + else: + valid_data_path = None + if os.path.exists(train_data_path): + col_map = {"text": params.text_column} + if params.rejected_text_column is not None: + col_map["rejected_text"] = params.rejected_text_column + if params.prompt_text_column is not None: + col_map["prompt"] = params.prompt_text_column + dset = AutoTrainDataset( + train_data=[train_data_path], + task="lm_training", + token=params.token, + project_name=params.project_name, + username=params.username, + column_mapping=col_map, + valid_data=[valid_data_path] if valid_data_path is not None else None, + percent_valid=None, # TODO: add to UI + local=local, + ext=ext_to_use, + ) + params.data_path = dset.prepare() + params.valid_split = None + params.text_column = "autotrain_text" + params.rejected_text_column = "autotrain_rejected_text" + params.prompt_text_column = "autotrain_prompt" + return params + + +def seq2seq_munge_data(params, local): + exts = ["csv", "jsonl"] + ext_to_use = None + for ext in exts: + path = f"{params.data_path}/{params.train_split}.{ext}" + if os.path.exists(path): + ext_to_use = ext + break + + train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" + else: + valid_data_path = None + if os.path.exists(train_data_path): + dset = AutoTrainDataset( + train_data=[train_data_path], + task="seq2seq", + token=params.token, + project_name=params.project_name, + username=params.username, + column_mapping={"text": params.text_column, "label": params.target_column}, + valid_data=[valid_data_path] if valid_data_path is not None else 
None, + percent_valid=None, # TODO: add to UI + local=local, + ext=ext_to_use, + ) + params.data_path = dset.prepare() + params.valid_split = "validation" + params.text_column = "autotrain_text" + params.target_column = "autotrain_label" + return params + + +def text_clf_munge_data(params, local): + exts = ["csv", "jsonl"] + ext_to_use = None + for ext in exts: + path = f"{params.data_path}/{params.train_split}.{ext}" + if os.path.exists(path): + ext_to_use = ext + break + + train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" + else: + valid_data_path = None + if os.path.exists(train_data_path): + dset = AutoTrainDataset( + train_data=[train_data_path], + valid_data=[valid_data_path] if valid_data_path is not None else None, + task="text_multi_class_classification", + token=params.token, + project_name=params.project_name, + username=params.username, + column_mapping={"text": params.text_column, "label": params.target_column}, + percent_valid=None, # TODO: add to UI + local=local, + convert_to_class_label=True, + ext=ext_to_use, + ) + params.data_path = dset.prepare() + params.valid_split = "validation" + params.text_column = "autotrain_text" + params.target_column = "autotrain_label" + return params + + +def text_reg_munge_data(params, local): + exts = ["csv", "jsonl"] + ext_to_use = None + for ext in exts: + path = f"{params.data_path}/{params.train_split}.{ext}" + if os.path.exists(path): + ext_to_use = ext + break + + train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" + else: + valid_data_path = None + if os.path.exists(train_data_path): + dset = AutoTrainDataset( + train_data=[train_data_path], + valid_data=[valid_data_path] if valid_data_path is not None else None, + 
task="text_single_column_regression", + token=params.token, + project_name=params.project_name, + username=params.username, + column_mapping={"text": params.text_column, "label": params.target_column}, + percent_valid=None, # TODO: add to UI + local=local, + convert_to_class_label=False, + ext=ext_to_use, + ) + params.data_path = dset.prepare() + params.valid_split = "validation" + params.text_column = "autotrain_text" + params.target_column = "autotrain_label" + return params + + +def token_clf_munge_data(params, local): + exts = ["csv", "jsonl"] + ext_to_use = None + for ext in exts: + path = f"{params.data_path}/{params.train_split}.{ext}" + if os.path.exists(path): + ext_to_use = ext + break + + train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" + else: + valid_data_path = None + if os.path.exists(train_data_path): + dset = AutoTrainDataset( + train_data=[train_data_path], + valid_data=[valid_data_path] if valid_data_path is not None else None, + task="text_token_classification", + token=params.token, + project_name=params.project_name, + username=params.username, + column_mapping={"text": params.tokens_column, "label": params.tags_column}, + percent_valid=None, # TODO: add to UI + local=local, + convert_to_class_label=True, + ext=ext_to_use, + ) + params.data_path = dset.prepare() + params.valid_split = "validation" + params.text_column = "autotrain_text" + params.target_column = "autotrain_label" + return params + + +def img_clf_munge_data(params, local): + train_data_path = f"{params.data_path}/{params.train_split}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}" + else: + valid_data_path = None + if os.path.isdir(train_data_path): + dset = AutoTrainImageClassificationDataset( + train_data=train_data_path, + valid_data=valid_data_path, + token=params.token, + 
project_name=params.project_name, + username=params.username, + local=local, + ) + params.data_path = dset.prepare() + params.valid_split = "validation" + params.image_column = "autotrain_image" + params.target_column = "autotrain_label" + return params + + +def dreambooth_munge_data(params, local): + # check if params.image_path is a directory + if os.path.isdir(params.image_path): + training_data = [os.path.join(params.image_path, f) for f in os.listdir(params.image_path)] + dset = AutoTrainDreamboothDataset( + concept_images=training_data, + concept_name=params.prompt, + token=params.token, + project_name=params.project_name, + username=params.username, + local=local, + ) + params.image_path = dset.prepare() + return params + + +def img_obj_detect_munge_data(params, local): + train_data_path = f"{params.data_path}/{params.train_split}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}" + else: + valid_data_path = None + if os.path.isdir(train_data_path): + dset = AutoTrainObjectDetectionDataset( + train_data=train_data_path, + valid_data=valid_data_path, + token=params.token, + project_name=params.project_name, + username=params.username, + local=local, + ) + params.data_path = dset.prepare() + params.valid_split = "validation" + params.image_column = "autotrain_image" + params.objects_column = "autotrain_objects" + return params + + +def sent_transformers_munge_data(params, local): + exts = ["csv", "jsonl"] + ext_to_use = None + for ext in exts: + path = f"{params.data_path}/{params.train_split}.{ext}" + if os.path.exists(path): + ext_to_use = ext + break + + train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" + else: + valid_data_path = None + if os.path.exists(train_data_path): + dset = AutoTrainDataset( + train_data=[train_data_path], + valid_data=[valid_data_path] if 
valid_data_path is not None else None, + task="sentence_transformers", + token=params.token, + project_name=params.project_name, + username=params.username, + column_mapping={ + "sentence1": params.sentence1_column, + "sentence2": params.sentence2_column, + "sentence3": params.sentence3_column, + "target": params.target_column, + }, + percent_valid=None, # TODO: add to UI + local=local, + convert_to_class_label=True if params.trainer == "pair_class" else False, + ext=ext_to_use, + ) + params.data_path = dset.prepare() + params.valid_split = "validation" + params.sentence1_column = "autotrain_sentence1" + params.sentence2_column = "autotrain_sentence2" + params.sentence3_column = "autotrain_sentence3" + params.target_column = "autotrain_target" + return params + + +def img_reg_munge_data(params, local): + train_data_path = f"{params.data_path}/{params.train_split}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}" + else: + valid_data_path = None + if os.path.isdir(train_data_path): + dset = AutoTrainImageRegressionDataset( + train_data=train_data_path, + valid_data=valid_data_path, + token=params.token, + project_name=params.project_name, + username=params.username, + local=local, + ) + params.data_path = dset.prepare() + params.valid_split = "validation" + params.image_column = "autotrain_image" + params.target_column = "autotrain_label" + return params + + +def vlm_munge_data(params, local): + train_data_path = f"{params.data_path}/{params.train_split}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}" + else: + valid_data_path = None + if os.path.exists(train_data_path): + col_map = {"text": params.text_column} + if params.prompt_text_column is not None: + col_map["prompt"] = params.prompt_text_column + dset = AutoTrainVLMDataset( + train_data=train_data_path, + token=params.token, + project_name=params.project_name, + username=params.username, + 
column_mapping=col_map, + valid_data=valid_data_path if valid_data_path is not None else None, + percent_valid=None, # TODO: add to UI + local=local, + ) + params.data_path = dset.prepare() + params.text_column = "autotrain_text" + params.image_column = "autotrain_image" + params.prompt_text_column = "autotrain_prompt" + return params + + +def ext_qa_munge_data(params, local): + exts = ["csv", "jsonl"] + ext_to_use = None + for ext in exts: + path = f"{params.data_path}/{params.train_split}.{ext}" + if os.path.exists(path): + ext_to_use = ext + break + + train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" + else: + valid_data_path = None + if os.path.exists(train_data_path): + dset = AutoTrainDataset( + train_data=[train_data_path], + valid_data=[valid_data_path] if valid_data_path is not None else None, + task="text_extractive_question_answering", + token=params.token, + project_name=params.project_name, + username=params.username, + column_mapping={ + "text": params.text_column, + "question": params.question_column, + "answer": params.answer_column, + }, + percent_valid=None, # TODO: add to UI + local=local, + convert_to_class_label=True, + ext=ext_to_use, + ) + params.data_path = dset.prepare() + params.valid_split = "validation" + params.text_column = "autotrain_text" + params.question_column = "autotrain_question" + params.answer_column = "autotrain_answer" + return params @dataclass class AutoTrainProject: """ - A class to represent an AutoTrain project + A class to train an AutoTrain project Attributes ---------- - params : Union[List[Union[LLMTrainingParams, TextClassificationParams, TabularParams, DreamBoothTrainingParams, Seq2SeqParams, ImageClassificationParams, TextRegressionParams, ObjectDetectionParams, TokenClassificationParams, SentenceTransformersParams, ImageRegressionParams]], LLMTrainingParams, 
TextClassificationParams, TabularParams, DreamBoothTrainingParams, Seq2SeqParams, ImageClassificationParams, TextRegressionParams, ObjectDetectionParams, TokenClassificationParams, SentenceTransformersParams, ImageRegressionParams] + params : Union[ + LLMTrainingParams, + TextClassificationParams, + TabularParams, + DreamBoothTrainingParams, + Seq2SeqParams, + ImageClassificationParams, + TextRegressionParams, + ObjectDetectionParams, + TokenClassificationParams, + SentenceTransformersParams, + ImageRegressionParams, + ExtractiveQuestionAnsweringParams, + VLMTrainingParams, + ] The parameters for the AutoTrain project. backend : str - The backend to be used for the AutoTrain project. + The backend to be used for the AutoTrain project. It should be one of the following: + - local + - spaces-a10g-large + - spaces-a10g-small + - spaces-a100-large + - spaces-t4-medium + - spaces-t4-small + - spaces-cpu-upgrade + - spaces-cpu-basic + - spaces-l4x1 + - spaces-l4x4 + - spaces-l40sx1 + - spaces-l40sx4 + - spaces-l40sx8 + - spaces-a10g-largex2 + - spaces-a10g-largex4 + process : bool + Flag to indicate if the params and dataset should be processed. If your data format is not AutoTrain-readable, set it to True. Set it to True when in doubt. Defaults to False. 
Methods ------- @@ -45,21 +510,6 @@ class AutoTrainProject: """ params: Union[ - List[ - Union[ - LLMTrainingParams, - TextClassificationParams, - TabularParams, - DreamBoothTrainingParams, - Seq2SeqParams, - ImageClassificationParams, - TextRegressionParams, - ObjectDetectionParams, - TokenClassificationParams, - SentenceTransformersParams, - ImageRegressionParams, - ] - ], LLMTrainingParams, TextClassificationParams, TabularParams, @@ -71,14 +521,51 @@ class AutoTrainProject: TokenClassificationParams, SentenceTransformersParams, ImageRegressionParams, + ExtractiveQuestionAnsweringParams, + VLMTrainingParams, ] backend: str + process: bool = False def __post_init__(self): + self.local = self.backend.startswith("local") if self.backend not in AVAILABLE_HARDWARE: raise ValueError(f"Invalid backend: {self.backend}") + def _process_params_data(self): + if isinstance(self.params, LLMTrainingParams): + return llm_munge_data(self.params, self.local) + elif isinstance(self.params, DreamBoothTrainingParams): + return dreambooth_munge_data(self.params, self.local) + elif isinstance(self.params, ExtractiveQuestionAnsweringParams): + return ext_qa_munge_data(self.params, self.local) + elif isinstance(self.params, ImageClassificationParams): + return img_clf_munge_data(self.params, self.local) + elif isinstance(self.params, ImageRegressionParams): + return img_reg_munge_data(self.params, self.local) + elif isinstance(self.params, ObjectDetectionParams): + return img_obj_detect_munge_data(self.params, self.local) + elif isinstance(self.params, SentenceTransformersParams): + return sent_transformers_munge_data(self.params, self.local) + elif isinstance(self.params, Seq2SeqParams): + return seq2seq_munge_data(self.params, self.local) + elif isinstance(self.params, TabularParams): + return tabular_munge_data(self.params, self.local) + elif isinstance(self.params, TextClassificationParams): + return text_clf_munge_data(self.params, self.local) + elif isinstance(self.params, 
TextRegressionParams): + return text_reg_munge_data(self.params, self.local) + elif isinstance(self.params, TokenClassificationParams): + return token_clf_munge_data(self.params, self.local) + elif isinstance(self.params, VLMTrainingParams): + return vlm_munge_data(self.params, self.local) + else: + raise Exception("Invalid params class") + def create(self): + if self.process: + self.params = self._process_params_data() + if self.backend.startswith("local"): runner = LocalRunner(params=self.params, backend=self.backend) return runner.create()