From 15711ae4cb5394ffdb4e25952ebc79c50276282d Mon Sep 17 00:00:00 2001 From: abhishek thakur <1183441+abhishekkrthakur@users.noreply.github.com> Date: Mon, 29 Apr 2024 14:27:27 +0200 Subject: [PATCH] Text regression (#605) --- docs/source/_toctree.yml | 2 + docs/source/text_regression.mdx | 147 +++++++++++++ src/autotrain/app.py | 21 ++ src/autotrain/app_params.py | 18 ++ src/autotrain/app_utils.py | 3 + src/autotrain/backend.py | 8 + src/autotrain/cli/autotrain.py | 2 + src/autotrain/cli/run_text_regression.py | 170 +++++++++++++++ src/autotrain/cli/utils.py | 35 +++ src/autotrain/commands.py | 29 ++- src/autotrain/models.py | 1 + src/autotrain/preprocessor/text.py | 33 +++ src/autotrain/project.py | 3 + src/autotrain/tasks.py | 40 ---- src/autotrain/templates/index.html | 6 + .../trainers/text_regression/__init__.py | 0 .../trainers/text_regression/__main__.py | 199 ++++++++++++++++++ .../trainers/text_regression/dataset.py | 45 ++++ .../trainers/text_regression/params.py | 35 +++ .../trainers/text_regression/utils.py | 80 +++++++ 20 files changed, 828 insertions(+), 49 deletions(-) create mode 100644 docs/source/text_regression.mdx create mode 100644 src/autotrain/cli/run_text_regression.py create mode 100644 src/autotrain/trainers/text_regression/__init__.py create mode 100644 src/autotrain/trainers/text_regression/__main__.py create mode 100644 src/autotrain/trainers/text_regression/dataset.py create mode 100644 src/autotrain/trainers/text_regression/params.py create mode 100644 src/autotrain/trainers/text_regression/utils.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 68f9a6c073..726647003c 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -17,6 +17,8 @@ - sections: - local: text_classification title: Text Classification + - local: text_regression + title: Text Regression - local: llm_finetuning title: LLM Finetuning - local: image_classification diff --git a/docs/source/text_regression.mdx b/docs/source/text_regression.mdx new file mode 100644 index 0000000000..cbaa57b274 --- /dev/null +++ b/docs/source/text_regression.mdx @@ -0,0 +1,147 @@ +# Text Regression + +Training a text regression model with AutoTrain is super-easy! Get your data ready in +proper format and then with just a few clicks, your state-of-the-art model will be ready to +be used in production. + +## Data Format + +Let's train a model for scoring a movie review on a scale of 1-5. The data should be +in the following CSV format: + +```csv +text,target +"this movie is great",5 +"this movie is bad",1 +. +. +. +``` + +As you can see, we have two columns in the CSV file. One column is the text and the other +is the label. The label can be any float or int. + +If your CSV is huge, you can divide it into multiple CSV files and upload them separately. +Please make sure that the column names are the same in all CSV files. + +One way to divide the CSV file using pandas is as follows: + +```python +import pandas as pd + +# Set the chunk size +chunk_size = 1000 +i = 1 + +# Open the CSV file and read it in chunks +for chunk in pd.read_csv('example.csv', chunksize=chunk_size): + # Save each chunk to a new file + chunk.to_csv(f'chunk_{i}.csv', index=False) + i += 1 +``` + +Instead of CSV you can also use JSONL format. The JSONL format should be as follows: + +```json +{"text": "this movie is great", "target": 5} +{"text": "this movie is bad", "target": 1} +. +. +. +``` + +## Columns + +Your CSV dataset must have two columns: `text` and `target`. 
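A quick way to verify that a file matches this schema before uploading is a small pandas check. This is a minimal sketch (the filename `example.csv` is a placeholder):

```python
import pandas as pd

df = pd.read_csv("example.csv")

# both required columns must be present
assert {"text", "target"} <= set(df.columns), "CSV must have 'text' and 'target' columns"

# the target must be numeric (float or int), not a string
assert pd.api.types.is_numeric_dtype(df["target"]), "'target' column must be float or int"
```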
+ + +### Params + +``` +❯ autotrain text-regression --help +usage: autotrain [] text-regression [-h] [--train] [--deploy] [--inference] [--username USERNAME] + [--backend {local-cli,spaces-a10gl,spaces-a10gs,spaces-a100,spaces-t4m,spaces-t4s,spaces-cpu,spaces-cpuf}] + [--token TOKEN] [--push-to-hub] --model MODEL --project-name PROJECT_NAME + [--data-path DATA_PATH] [--train-split TRAIN_SPLIT] [--valid-split VALID_SPLIT] + [--batch-size BATCH_SIZE] [--seed SEED] [--epochs EPOCHS] + [--gradient_accumulation GRADIENT_ACCUMULATION] [--disable_gradient_checkpointing] [--lr LR] + [--log {none,wandb,tensorboard}] [--text-column TEXT_COLUMN] [--target-column TARGET_COLUMN] + [--max-seq-length MAX_SEQ_LENGTH] [--warmup-ratio WARMUP_RATIO] [--optimizer OPTIMIZER] + [--scheduler SCHEDULER] [--weight-decay WEIGHT_DECAY] [--max-grad-norm MAX_GRAD_NORM] + [--logging-steps LOGGING_STEPS] [--evaluation-strategy {steps,epoch,no}] + [--save-total-limit SAVE_TOTAL_LIMIT] + [--auto-find-batch-size] [--mixed-precision {fp16,bf16,None}] + +✨ Run AutoTrain Text Regression + +options: + -h, --help show this help message and exit + --train Command to train the model + --deploy Command to deploy the model (limited availability) + --inference Command to run inference (limited availability) + --username USERNAME Hugging Face Hub Username + --backend {local-cli,spaces-a10gl,spaces-a10gs,spaces-a100,spaces-t4m,spaces-t4s,spaces-cpu,spaces-cpuf} + Backend to use: default or spaces. Spaces backend requires push_to_hub & username. Advanced users only. + --token TOKEN Your Hugging Face API token. Token must have write access to the model hub. + --push-to-hub Push to hub after training will push the trained model to the Hugging Face model hub. + --model MODEL Base model to use for training + --project-name PROJECT_NAME + Output directory / repo id for trained model (must be unique on hub) + --data-path DATA_PATH + Train dataset to use. When using cli, this should be a directory path containing training and validation data in appropriate + formats + --train-split TRAIN_SPLIT + Train dataset split to use + --valid-split VALID_SPLIT + Validation dataset split to use + --batch-size BATCH_SIZE + Training batch size to use + --seed SEED Random seed for reproducibility + --epochs EPOCHS Number of training epochs + --gradient_accumulation GRADIENT_ACCUMULATION + Gradient accumulation steps + --disable_gradient_checkpointing + Disable gradient checkpointing + --lr LR Learning rate + --log {none,wandb,tensorboard} + Use experiment tracking + --text-column TEXT_COLUMN + Specify the column name in the dataset that contains the text data. Useful for distinguishing between multiple text fields. + Default is 'text'. + --target-column TARGET_COLUMN + Specify the column name that holds the target or label data for training. Helps in distinguishing different potential + outputs. Default is 'target'. + --max-seq-length MAX_SEQ_LENGTH + Set the maximum sequence length (number of tokens) that the model should handle in a single input. Longer sequences are + truncated. Affects both memory usage and computational requirements. Default is 128 tokens. + --warmup-ratio WARMUP_RATIO + Define the proportion of training to be dedicated to a linear warmup where learning rate gradually increases. This can help + in stabilizing the training process early on. Default ratio is 0.1. + --optimizer OPTIMIZER + Choose the optimizer algorithm for training the model. Different optimizers can affect the training speed and model + performance. 
'adamw_torch' is used by default. + --scheduler SCHEDULER + Select the learning rate scheduler to adjust the learning rate based on the number of epochs. 'linear' decreases the + learning rate linearly from the initial lr set. Default is 'linear'. Try 'cosine' for a cosine annealing schedule. + --weight-decay WEIGHT_DECAY + Set the weight decay rate to apply for regularization. Helps in preventing the model from overfitting by penalizing large + weights. Default is 0.0, meaning no weight decay is applied. + --max-grad-norm MAX_GRAD_NORM + Specify the maximum norm of the gradients for gradient clipping. Gradient clipping is used to prevent the exploding gradient + problem in deep neural networks. Default is 1.0. + --logging-steps LOGGING_STEPS + Determine how often to log training progress. Set this to the number of steps between each log output. -1 determines logging + steps automatically. Default is -1. + --evaluation-strategy {steps,epoch,no} + Specify how often to evaluate the model performance. Options include 'no', 'steps', 'epoch'. 'epoch' evaluates at the end of + each training epoch by default. + --save-total-limit SAVE_TOTAL_LIMIT + Limit the total number of model checkpoints to save. Helps manage disk space by retaining only the most recent checkpoints. + Default is to save only the latest one. + --auto-find-batch-size + Enable automatic batch size determination based on your hardware capabilities. When set, it tries to find the largest batch + size that fits in memory. + --mixed-precision {fp16,bf16,None} + Choose the precision mode for training to optimize performance and memory usage. Options are 'fp16', 'bf16', or None for + default precision. Default is None. +``` \ No newline at end of file diff --git a/src/autotrain/app.py b/src/autotrain/app.py index 168e53cfd5..07576d9dea 100644 --- a/src/autotrain/app.py +++ b/src/autotrain/app.py @@ -25,6 +25,7 @@ from autotrain.trainers.seq2seq.params import Seq2SeqParams from autotrain.trainers.tabular.params import TabularParams from autotrain.trainers.text_classification.params import TextClassificationParams +from autotrain.trainers.text_regression.params import TextRegressionParams from autotrain.trainers.token_classification.params import TokenClassificationParams @@ -135,6 +136,10 @@ mixed_precision="fp16", log="tensorboard", ).model_dump() +PARAMS["text-regression"] = TextRegressionParams( + mixed_precision="fp16", + log="tensorboard", +).model_dump() MODEL_CHOICE = fetch_models() @@ -281,6 +286,18 @@ async def fetch_params(task: str, param_type: str): "evaluation_strategy", ] task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params} + if task == "text-regression" and param_type == "basic": + more_hidden_params = [ + "warmup_ratio", + "weight_decay", + "max_grad_norm", + "seed", + "logging_steps", + "auto_find_batch_size", + "save_total_limit", + "evaluation_strategy", + ] + task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params} if task == "image-classification" and param_type == "basic": more_hidden_params = [ "warmup_ratio", @@ -394,6 +411,8 @@ async def fetch_model_choices(task: str, custom_models: str = Query(None)): hub_models = MODEL_CHOICE["tabular-regression"] elif task == "token-classification": hub_models = MODEL_CHOICE["token-classification"] + elif task == "text-regression": + hub_models = MODEL_CHOICE["text-regression"] else: raise NotImplementedError @@ -514,6 +533,8 @@ async def handle_form( dset_task = "lm_training" elif task == "text-classification": 
dset_task = "text_multi_class_classification" + elif task == "text-regression": + dset_task = "text_single_column_regression" elif task == "seq2seq": dset_task = "seq2seq" elif task.startswith("tabular"): diff --git a/src/autotrain/app_params.py b/src/autotrain/app_params.py index 53cfd85b27..7514dabc32 100644 --- a/src/autotrain/app_params.py +++ b/src/autotrain/app_params.py @@ -8,6 +8,7 @@ from autotrain.trainers.seq2seq.params import Seq2SeqParams from autotrain.trainers.tabular.params import TabularParams from autotrain.trainers.text_classification.params import TextClassificationParams +from autotrain.trainers.text_regression.params import TextRegressionParams from autotrain.trainers.token_classification.params import TokenClassificationParams @@ -44,6 +45,8 @@ def munge(self): return self._munge_params_llm() elif self.task == "token-classification": return self._munge_params_token_clf() + elif self.task == "text-regression": + return self._munge_params_text_reg() else: raise ValueError(f"Unknown task: {self.task}") @@ -91,6 +94,21 @@ def _munge_params_text_clf(self): _params["valid_split"] = self.valid_split return TextClassificationParams(**_params) + def _munge_params_text_reg(self): + _params = self._munge_common_params() + _params["model"] = self.base_model + _params["log"] = "tensorboard" + if not self.using_hub_dataset: + _params["text_column"] = "autotrain_text" + _params["target_column"] = "autotrain_label" + _params["valid_split"] = "validation" + else: + _params["text_column"] = self.column_mapping.get("text", "text") + _params["target_column"] = self.column_mapping.get("label", "label") + _params["train_split"] = self.train_split + _params["valid_split"] = self.valid_split + return TextRegressionParams(**_params) + def _munge_params_token_clf(self): _params = self._munge_common_params() _params["model"] = self.base_model diff --git a/src/autotrain/app_utils.py b/src/autotrain/app_utils.py index bbcd465e81..7742f26cbe 100644 --- a/src/autotrain/app_utils.py +++ b/src/autotrain/app_utils.py @@ -15,6 +15,7 @@ from autotrain.trainers.seq2seq.params import Seq2SeqParams from autotrain.trainers.tabular.params import TabularParams from autotrain.trainers.text_classification.params import TextClassificationParams +from autotrain.trainers.text_regression.params import TextRegressionParams from autotrain.trainers.token_classification.params import TokenClassificationParams @@ -124,6 +125,8 @@ def run_training(params, task_id, local=False, wait=False): params = ImageClassificationParams(**params) elif task_id == 4: params = TokenClassificationParams(**params) + elif task_id == 10: + params = TextRegressionParams(**params) else: raise NotImplementedError diff --git a/src/autotrain/backend.py b/src/autotrain/backend.py index b4224643fa..ff4c8c47cc 100644 --- a/src/autotrain/backend.py +++ b/src/autotrain/backend.py @@ -21,6 +21,7 @@ from autotrain.trainers.seq2seq.params import Seq2SeqParams from autotrain.trainers.tabular.params import TabularParams from autotrain.trainers.text_classification.params import TextClassificationParams +from autotrain.trainers.text_regression.params import TextRegressionParams from autotrain.trainers.token_classification.params import TokenClassificationParams @@ -45,6 +46,7 @@ class SpaceRunner: DreamBoothTrainingParams, Seq2SeqParams, TokenClassificationParams, + TextRegressionParams, ] backend: str @@ -96,6 +98,8 @@ def __post_init__(self): self.task_id = 18 elif isinstance(self.params, TokenClassificationParams): self.task_id = 4 + elif 
isinstance(self.params, TextRegressionParams): + self.task_id = 10 else: raise NotImplementedError @@ -132,6 +136,10 @@ def prepare(self): self.task_id = 4 space_id = self._create_space() return space_id + if isinstance(self.params, TextRegressionParams): + self.task_id = 10 + space_id = self._create_space() + return space_id raise NotImplementedError def _create_readme(self): diff --git a/src/autotrain/cli/autotrain.py b/src/autotrain/cli/autotrain.py index 25b016f57d..0efec056de 100644 --- a/src/autotrain/cli/autotrain.py +++ b/src/autotrain/cli/autotrain.py @@ -11,6 +11,7 @@ from .run_spacerunner import RunAutoTrainSpaceRunnerCommand from .run_tabular import RunAutoTrainTabularCommand from .run_text_classification import RunAutoTrainTextClassificationCommand +from .run_text_regression import RunAutoTrainTextRegressionCommand from .run_token_classification import RunAutoTrainTokenClassificationCommand from .run_tools import RunAutoTrainToolsCommand @@ -37,6 +38,7 @@ def main(): RunAutoTrainSeq2SeqCommand.register_subcommand(commands_parser) RunAutoTrainTokenClassificationCommand.register_subcommand(commands_parser) RunAutoTrainToolsCommand.register_subcommand(commands_parser) + RunAutoTrainTextRegressionCommand.register_subcommand(commands_parser) args = parser.parse_args() diff --git a/src/autotrain/cli/run_text_regression.py b/src/autotrain/cli/run_text_regression.py new file mode 100644 index 0000000000..5dfeb78a9f --- /dev/null +++ b/src/autotrain/cli/run_text_regression.py @@ -0,0 +1,170 @@ +from argparse import ArgumentParser + +from autotrain import logger +from autotrain.cli.utils import common_args, text_reg_munge_data +from autotrain.project import AutoTrainProject +from autotrain.trainers.text_regression.params import TextRegressionParams + +from . import BaseAutoTrainCommand + + +def run_text_regression_command_factory(args): + return RunAutoTrainTextRegressionCommand(args) + + +class RunAutoTrainTextRegressionCommand(BaseAutoTrainCommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + arg_list = [ + { + "arg": "--text-column", + "help": "Specify the column name in the dataset that contains the text data. Useful for distinguishing between multiple text fields. Default is 'text'.", + "required": False, + "type": str, + "default": "text", + }, + { + "arg": "--target-column", + "help": "Specify the column name that holds the target or label data for training. Helps in distinguishing different potential outputs. Default is 'target'.", + "required": False, + "type": str, + "default": "target", + }, + { + "arg": "--max-seq-length", + "help": "Set the maximum sequence length (number of tokens) that the model should handle in a single input. Longer sequences are truncated. Affects both memory usage and computational requirements. Default is 128 tokens.", + "required": False, + "type": int, + "default": 128, + }, + { + "arg": "--warmup-ratio", + "help": "Define the proportion of training to be dedicated to a linear warmup where learning rate gradually increases. This can help in stabilizing the training process early on. Default ratio is 0.1.", + "required": False, + "type": float, + "default": 0.1, + }, + { + "arg": "--optimizer", + "help": "Choose the optimizer algorithm for training the model. Different optimizers can affect the training speed and model performance. 
'adamw_torch' is used by default.", + "required": False, + "type": str, + "default": "adamw_torch", + }, + { + "arg": "--scheduler", + "help": "Select the learning rate scheduler to adjust the learning rate based on the number of epochs. 'linear' decreases the learning rate linearly from the initial lr set. Default is 'linear'. Try 'cosine' for a cosine annealing schedule.", + "required": False, + "type": str, + "default": "linear", + }, + { + "arg": "--weight-decay", + "help": "Set the weight decay rate to apply for regularization. Helps in preventing the model from overfitting by penalizing large weights. Default is 0.0, meaning no weight decay is applied.", + "required": False, + "type": float, + "default": 0.0, + }, + { + "arg": "--max-grad-norm", + "help": "Specify the maximum norm of the gradients for gradient clipping. Gradient clipping is used to prevent the exploding gradient problem in deep neural networks. Default is 1.0.", + "required": False, + "type": float, + "default": 1.0, + }, + { + "arg": "--logging-steps", + "help": "Determine how often to log training progress. Set this to the number of steps between each log output. -1 determines logging steps automatically. Default is -1.", + "required": False, + "type": int, + "default": -1, + }, + { + "arg": "--evaluation-strategy", + "help": "Specify how often to evaluate the model performance. Options include 'no', 'steps', 'epoch'. 'epoch' evaluates at the end of each training epoch by default.", + "required": False, + "type": str, + "default": "epoch", + "choices": ["steps", "epoch", "no"], + }, + { + "arg": "--save-total-limit", + "help": "Limit the total number of model checkpoints to save. Helps manage disk space by retaining only the most recent checkpoints. Default is to save only the latest one.", + "required": False, + "type": int, + "default": 1, + }, + { + "arg": "--auto-find-batch-size", + "help": "Enable automatic batch size determination based on your hardware capabilities. When set, it tries to find the largest batch size that fits in memory.", + "required": False, + "action": "store_true", + }, + { + "arg": "--mixed-precision", + "help": "Choose the precision mode for training to optimize performance and memory usage. Options are 'fp16', 'bf16', or None for default precision. 
Default is None.", + "required": False, + "type": str, + "default": None, + "choices": ["fp16", "bf16", None], + }, + ] + arg_list = common_args() + arg_list + arg_list = [arg for arg in arg_list if arg["arg"] != "--disable-gradient-checkpointing"] + run_text_regression_parser = parser.add_parser( + "text-regression", description="✨ Run AutoTrain Text Regression" + ) + for arg in arg_list: + if "action" in arg: + run_text_regression_parser.add_argument( + arg["arg"], + help=arg["help"], + required=arg.get("required", False), + action=arg.get("action"), + default=arg.get("default"), + ) + else: + run_text_regression_parser.add_argument( + arg["arg"], + help=arg["help"], + required=arg.get("required", False), + type=arg.get("type"), + default=arg.get("default"), + choices=arg.get("choices"), + ) + run_text_regression_parser.set_defaults(func=run_text_regression_command_factory) + + def __init__(self, args): + self.args = args + + store_true_arg_names = [ + "train", + "deploy", + "inference", + "auto_find_batch_size", + "push_to_hub", + ] + for arg_name in store_true_arg_names: + if getattr(self.args, arg_name) is None: + setattr(self.args, arg_name, False) + + if self.args.train: + if self.args.project_name is None: + raise ValueError("Project name must be specified") + if self.args.data_path is None: + raise ValueError("Data path must be specified") + if self.args.model is None: + raise ValueError("Model must be specified") + if self.args.push_to_hub: + if self.args.username is None: + raise ValueError("Username must be specified for push to hub") + else: + raise ValueError("Must specify --train, --deploy or --inference") + + def run(self): + logger.info("Running Text Regression") + if self.args.train: + params = TextRegressionParams(**vars(self.args)) + params = text_reg_munge_data(params, local=self.args.backend.startswith("local")) + project = AutoTrainProject(params=params, backend=self.args.backend) + _ = project.create() diff --git a/src/autotrain/cli/utils.py b/src/autotrain/cli/utils.py index 3930335f36..bb87372210 100644 --- a/src/autotrain/cli/utils.py +++ b/src/autotrain/cli/utils.py @@ -308,6 +308,41 @@ def text_clf_munge_data(params, local): return params +def text_reg_munge_data(params, local): + exts = ["csv", "jsonl"] + ext_to_use = None + for ext in exts: + path = f"{params.data_path}/{params.train_split}.{ext}" + if os.path.exists(path): + ext_to_use = ext + break + + train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}" + if params.valid_split is not None: + valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}" + else: + valid_data_path = None + if os.path.exists(train_data_path): + dset = AutoTrainDataset( + train_data=[train_data_path], + valid_data=[valid_data_path] if valid_data_path is not None else None, + task="text_single_column_regression", + token=params.token, + project_name=params.project_name, + username=params.username, + column_mapping={"text": params.text_column, "label": params.target_column}, + percent_valid=None, # TODO: add to UI + local=local, + convert_to_class_label=False, + ext=ext_to_use, + ) + params.data_path = dset.prepare() + params.valid_split = "validation" + params.text_column = "autotrain_text" + params.target_column = "autotrain_label" + return params + + def token_clf_munge_data(params, local): exts = ["csv", "jsonl"] ext_to_use = None diff --git a/src/autotrain/commands.py b/src/autotrain/commands.py index 00af7f8e29..1b50627734 100644 --- a/src/autotrain/commands.py +++ 
b/src/autotrain/commands.py @@ -11,6 +11,7 @@ from autotrain.trainers.seq2seq.params import Seq2SeqParams from autotrain.trainers.tabular.params import TabularParams from autotrain.trainers.text_classification.params import TextClassificationParams +from autotrain.trainers.text_regression.params import TextRegressionParams from autotrain.trainers.token_classification.params import TokenClassificationParams @@ -124,7 +125,7 @@ def launch_command(params): "--training_config", os.path.join(params.project_name, "training_params.json"), ] - elif isinstance(params, TextClassificationParams): + elif isinstance(params, TextClassificationParams) or isinstance(params, TextRegressionParams): if num_gpus == 0: cmd = [ "accelerate", @@ -160,14 +161,24 @@ def launch_command(params): else: cmd.append("no") - cmd.extend( - [ - "-m", - "autotrain.trainers.text_classification", - "--training_config", - os.path.join(params.project_name, "training_params.json"), - ] - ) + if isinstance(params, TextRegressionParams): + cmd.extend( + [ + "-m", + "autotrain.trainers.text_regression", + "--training_config", + os.path.join(params.project_name, "training_params.json"), + ] + ) + else: + cmd.extend( + [ + "-m", + "autotrain.trainers.text_classification", + "--training_config", + os.path.join(params.project_name, "training_params.json"), + ] + ) elif isinstance(params, TokenClassificationParams): if num_gpus == 0: cmd = [ diff --git a/src/autotrain/models.py b/src/autotrain/models.py index 52abac96ea..54642de5d6 100644 --- a/src/autotrain/models.py +++ b/src/autotrain/models.py @@ -244,6 +244,7 @@ def fetch_models(): _mc["dreambooth"] = _fetch_dreambooth_models() _mc["seq2seq"] = _fetch_seq2seq_models() _mc["token-classification"] = _fetch_token_classification_models() + _mc["text-regression"] = _fetch_text_classification_models() # tabular-classification _mc["tabular-classification"] = [ diff --git a/src/autotrain/preprocessor/text.py b/src/autotrain/preprocessor/text.py index e299a24d08..f8a18aaf55 100644 --- a/src/autotrain/preprocessor/text.py +++ b/src/autotrain/preprocessor/text.py @@ -139,6 +139,39 @@ def split(self): valid_df = valid_df.reset_index(drop=True) return train_df, valid_df + def prepare(self): + train_df, valid_df = self.split() + train_df, valid_df = self.prepare_columns(train_df, valid_df) + + train_df = Dataset.from_pandas(train_df) + valid_df = Dataset.from_pandas(valid_df) + + if self.local: + dataset = DatasetDict( + { + "train": train_df, + "validation": valid_df, + } + ) + dataset.save_to_disk(f"{self.project_name}/autotrain-data") + else: + train_df.push_to_hub( + f"{self.username}/autotrain-data-{self.project_name}", + split="train", + private=True, + token=self.token, + ) + valid_df.push_to_hub( + f"{self.username}/autotrain-data-{self.project_name}", + split="validation", + private=True, + token=self.token, + ) + + if self.local: + return f"{self.project_name}/autotrain-data" + return f"{self.username}/autotrain-data-{self.project_name}" + class TextTokenClassificationPreprocessor(TextBinaryClassificationPreprocessor): def split(self): diff --git a/src/autotrain/project.py b/src/autotrain/project.py index e9910df930..029c4e7b40 100644 --- a/src/autotrain/project.py +++ b/src/autotrain/project.py @@ -12,6 +12,7 @@ from autotrain.trainers.seq2seq.params import Seq2SeqParams from autotrain.trainers.tabular.params import TabularParams from autotrain.trainers.text_classification.params import TextClassificationParams +from autotrain.trainers.text_regression.params import 
TextRegressionParams @dataclass @@ -25,6 +26,7 @@ class AutoTrainProject: DreamBoothTrainingParams, Seq2SeqParams, ImageClassificationParams, + TextRegressionParams, ] ], LLMTrainingParams, @@ -33,6 +35,7 @@ class AutoTrainProject: DreamBoothTrainingParams, Seq2SeqParams, ImageClassificationParams, + TextRegressionParams, ] backend: str diff --git a/src/autotrain/tasks.py b/src/autotrain/tasks.py index 43e9fb7e9d..16f5fb8592 100644 --- a/src/autotrain/tasks.py +++ b/src/autotrain/tasks.py @@ -32,43 +32,3 @@ **VISION_TASKS, **TABULAR_TASKS, } - -COLUMN_MAPPING = { - "text_binary_classification": ("text", "label"), - "text_multi_class_classification": ("text", "label"), - "text_token_classification": ("tokens", "tags"), - "text_extractive_question_answering": ("text", "context", "question", "answer"), - "text_summarization": ("text", "summary"), - "text_single_column_regression": ("text", "label"), - "speech_recognition": ("audio", "text"), - "natural_language_inference": ("premise", "hypothesis", "label"), - "image_binary_classification": ("image", "label"), - "image_multi_class_classification": ("image", "label"), - "image_single_column_regression": ("image", "label"), - # "dreambooth": ("image", "label"), - "tabular_binary_classification": ("id", "label"), - "tabular_multi_class_classification": ("id", "label"), - "tabular_multi_label_classification": ("id", "label"), - "tabular_single_column_regression": ("id", "label"), - "lm_training": ("text", "prompt_start", "prompt", "context", "response"), -} - -TASK_TYPE_MAPPING = { - "text_binary_classification": "Natural Language Processing", - "text_multi_class_classification": "Natural Language Processing", - "text_token_classification": "Natural Language Processing", - "text_extractive_question_answering": "Natural Language Processing", - "text_summarization": "Natural Language Processing", - "text_single_column_regression": "Natural Language Processing", - "lm_training": "Natural Language Processing", - "speech_recognition": "Natural Language Processing", - "natural_language_inference": "Natural Language Processing", - "image_binary_classification": "Computer Vision", - "image_multi_class_classification": "Computer Vision", - "image_single_column_regression": "Computer Vision", - "dreambooth": "Computer Vision", - "tabular_binary_classification": "Tabular", - "tabular_multi_class_classification": "Tabular", - "tabular_multi_label_classification": "Tabular", - "tabular_single_column_regression": "Tabular", -} diff --git a/src/autotrain/templates/index.html b/src/autotrain/templates/index.html index e2c0290bd1..c566261f33 100644 --- a/src/autotrain/templates/index.html +++ b/src/autotrain/templates/index.html @@ -303,6 +303,7 @@ + @@ -741,6 +742,11 @@
document.getElementById("hub-dataset-radio").disabled = false; document.getElementById("valid_split").disabled = false; break; + case 'text-regression': + placeholderText = '{"text": "text", "label": "target"}'; + document.getElementById("hub-dataset-radio").disabled = false; + document.getElementById("valid_split").disabled = false; + break; case 'token-classification': placeholderText = '{"text": "tokens", "label": "tags"}'; document.getElementById("hub-dataset-radio").disabled = false; diff --git a/src/autotrain/trainers/text_regression/__init__.py b/src/autotrain/trainers/text_regression/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/autotrain/trainers/text_regression/__main__.py b/src/autotrain/trainers/text_regression/__main__.py new file mode 100644 index 0000000000..a8bc0212a7 --- /dev/null +++ b/src/autotrain/trainers/text_regression/__main__.py @@ -0,0 +1,199 @@ +import argparse +import json + +from accelerate.state import PartialState +from datasets import load_dataset, load_from_disk +from huggingface_hub import HfApi +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + EarlyStoppingCallback, + Trainer, + TrainingArguments, +) +from transformers.trainer_callback import PrinterCallback + +from autotrain import logger +from autotrain.trainers.common import ( + ALLOW_REMOTE_CODE, + LossLoggingCallback, + TrainStartCallback, + UploadLogs, + monitor, + pause_space, + remove_autotrain_data, + save_training_params, +) +from autotrain.trainers.text_regression import utils +from autotrain.trainers.text_regression.dataset import TextRegressionDataset +from autotrain.trainers.text_regression.params import TextRegressionParams + + +def parse_args(): + # get training_config.json from the end user + parser = argparse.ArgumentParser() + parser.add_argument("--training_config", type=str, required=True) + return parser.parse_args() + + +@monitor +def train(config): + if isinstance(config, dict): + config = TextRegressionParams(**config) + + train_data = None + valid_data = None + # check if config.train_split.csv exists in config.data_path + if config.train_split is not None: + if config.data_path == f"{config.project_name}/autotrain-data": + logger.info("loading dataset from disk") + train_data = load_from_disk(config.data_path)[config.train_split] + else: + train_data = load_dataset( + config.data_path, + split=config.train_split, + token=config.token, + ) + + if config.valid_split is not None: + if config.data_path == f"{config.project_name}/autotrain-data": + logger.info("loading dataset from disk") + valid_data = load_from_disk(config.data_path)[config.valid_split] + else: + valid_data = load_dataset( + config.data_path, + split=config.valid_split, + token=config.token, + ) + + model_config = AutoConfig.from_pretrained(config.model, num_labels=1) + model_config._num_labels = 1 + label2id = {"target": 0} + model_config.label2id = label2id + model_config.id2label = {v: k for k, v in label2id.items()} + + try: + model = AutoModelForSequenceClassification.from_pretrained( + config.model, + config=model_config, + trust_remote_code=ALLOW_REMOTE_CODE, + token=config.token, + ignore_mismatched_sizes=True, + ) + except OSError: + model = AutoModelForSequenceClassification.from_pretrained( + config.model, + config=model_config, + from_tf=True, + trust_remote_code=ALLOW_REMOTE_CODE, + token=config.token, + ignore_mismatched_sizes=True, + ) + + tokenizer = AutoTokenizer.from_pretrained(config.model, token=config.token, 
trust_remote_code=ALLOW_REMOTE_CODE) + train_data = TextRegressionDataset(data=train_data, tokenizer=tokenizer, config=config) + if config.valid_split is not None: + valid_data = TextRegressionDataset(data=valid_data, tokenizer=tokenizer, config=config) + + if config.logging_steps == -1: + if config.valid_split is not None: + logging_steps = int(0.2 * len(valid_data) / config.batch_size) + else: + logging_steps = int(0.2 * len(train_data) / config.batch_size) + if logging_steps == 0: + logging_steps = 1 + if logging_steps > 25: + logging_steps = 25 + config.logging_steps = logging_steps + else: + logging_steps = config.logging_steps + + logger.info(f"Logging steps: {logging_steps}") + + training_args = dict( + output_dir=config.project_name, + per_device_train_batch_size=config.batch_size, + per_device_eval_batch_size=2 * config.batch_size, + learning_rate=config.lr, + num_train_epochs=config.epochs, + evaluation_strategy=config.evaluation_strategy if config.valid_split is not None else "no", + logging_steps=logging_steps, + save_total_limit=config.save_total_limit, + save_strategy=config.evaluation_strategy if config.valid_split is not None else "no", + gradient_accumulation_steps=config.gradient_accumulation, + report_to=config.log, + auto_find_batch_size=config.auto_find_batch_size, + lr_scheduler_type=config.scheduler, + optim=config.optimizer, + warmup_ratio=config.warmup_ratio, + weight_decay=config.weight_decay, + max_grad_norm=config.max_grad_norm, + push_to_hub=False, + load_best_model_at_end=True if config.valid_split is not None else False, + ddp_find_unused_parameters=False, + ) + + if config.mixed_precision == "fp16": + training_args["fp16"] = True + if config.mixed_precision == "bf16": + training_args["bf16"] = True + + if config.valid_split is not None: + early_stop = EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01) + callbacks_to_use = [early_stop] + else: + callbacks_to_use = [] + + callbacks_to_use.extend([UploadLogs(config=config), LossLoggingCallback(), TrainStartCallback()]) + + args = TrainingArguments(**training_args) + trainer_args = dict( + args=args, + model=model, + callbacks=callbacks_to_use, + compute_metrics=utils.single_column_regression_metrics, + ) + + trainer = Trainer( + **trainer_args, + train_dataset=train_data, + eval_dataset=valid_data, + ) + trainer.remove_callback(PrinterCallback) + trainer.train() + + logger.info("Finished training, saving model...") + trainer.save_model(config.project_name) + tokenizer.save_pretrained(config.project_name) + + model_card = utils.create_model_card(config, trainer) + + # save model card to output directory as README.md + with open(f"{config.project_name}/README.md", "w") as f: + f.write(model_card) + + if config.push_to_hub: + if PartialState().process_index == 0: + remove_autotrain_data(config) + save_training_params(config) + logger.info("Pushing model to hub...") + api = HfApi(token=config.token) + api.create_repo( + repo_id=f"{config.username}/{config.project_name}", repo_type="model", private=True, exist_ok=True + ) + api.upload_folder( + folder_path=config.project_name, + repo_id=f"{config.username}/{config.project_name}", + repo_type="model", + ) + + if PartialState().process_index == 0: + pause_space(config) + + +if __name__ == "__main__": + args = parse_args() + training_config = json.load(open(args.training_config)) + config = TextRegressionParams(**training_config) + train(config) diff --git a/src/autotrain/trainers/text_regression/dataset.py 
b/src/autotrain/trainers/text_regression/dataset.py new file mode 100644 index 0000000000..6a682b3a7c --- /dev/null +++ b/src/autotrain/trainers/text_regression/dataset.py @@ -0,0 +1,45 @@ +import torch + + +class TextRegressionDataset: + def __init__(self, data, tokenizer, config): + self.data = data + self.tokenizer = tokenizer + self.config = config + self.text_column = self.config.text_column + self.target_column = self.config.target_column + self.max_len = self.config.max_seq_length + + def __len__(self): + return len(self.data) + + def __getitem__(self, item): + text = str(self.data[item][self.text_column]) + target = float(self.data[item][self.target_column]) + inputs = self.tokenizer( + text, + max_length=self.max_len, + padding="max_length", + truncation=True, + ) + + ids = inputs["input_ids"] + mask = inputs["attention_mask"] + + if "token_type_ids" in inputs: + token_type_ids = inputs["token_type_ids"] + else: + token_type_ids = None + + if token_type_ids is not None: + return { + "input_ids": torch.tensor(ids, dtype=torch.long), + "attention_mask": torch.tensor(mask, dtype=torch.long), + "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long), + "labels": torch.tensor(target, dtype=torch.float), + } + return { + "input_ids": torch.tensor(ids, dtype=torch.long), + "attention_mask": torch.tensor(mask, dtype=torch.long), + "labels": torch.tensor(target, dtype=torch.float), + } diff --git a/src/autotrain/trainers/text_regression/params.py b/src/autotrain/trainers/text_regression/params.py new file mode 100644 index 0000000000..6b2c5af0cb --- /dev/null +++ b/src/autotrain/trainers/text_regression/params.py @@ -0,0 +1,35 @@ +from typing import Optional + +from pydantic import Field + +from autotrain.trainers.common import AutoTrainParams + + +class TextRegressionParams(AutoTrainParams): + data_path: str = Field(None, title="Data path") + model: str = Field("bert-base-uncased", title="Model name") + lr: float = Field(5e-5, title="Learning rate") + epochs: int = Field(3, title="Number of training epochs") + max_seq_length: int = Field(128, title="Max sequence length") + batch_size: int = Field(8, title="Training batch size") + warmup_ratio: float = Field(0.1, title="Warmup proportion") + gradient_accumulation: int = Field(1, title="Gradient accumulation steps") + optimizer: str = Field("adamw_torch", title="Optimizer") + scheduler: str = Field("linear", title="Scheduler") + weight_decay: float = Field(0.0, title="Weight decay") + max_grad_norm: float = Field(1.0, title="Max gradient norm") + seed: int = Field(42, title="Seed") + train_split: str = Field("train", title="Train split") + valid_split: Optional[str] = Field(None, title="Validation split") + text_column: str = Field("text", title="Text column") + target_column: str = Field("target", title="Target column(s)") + logging_steps: int = Field(-1, title="Logging steps") + project_name: str = Field("project-name", title="Output directory") + auto_find_batch_size: bool = Field(False, title="Auto find batch size") + mixed_precision: Optional[str] = Field(None, title="fp16, bf16, or None") + save_total_limit: int = Field(1, title="Save total limit") + token: Optional[str] = Field(None, title="Hub Token") + push_to_hub: bool = Field(False, title="Push to hub") + evaluation_strategy: str = Field("epoch", title="Evaluation strategy") + username: Optional[str] = Field(None, title="Hugging Face Username") + log: str = Field("none", title="Logging using experiment tracking") diff --git 
a/src/autotrain/trainers/text_regression/utils.py b/src/autotrain/trainers/text_regression/utils.py new file mode 100644 index 0000000000..49adb9cd26 --- /dev/null +++ b/src/autotrain/trainers/text_regression/utils.py @@ -0,0 +1,80 @@ +import numpy as np +from sklearn import metrics + + +SINGLE_COLUMN_REGRESSION_EVAL_METRICS = ( + "eval_loss", + "eval_mse", + "eval_mae", + "eval_r2", + "eval_rmse", + "eval_explained_variance", +) + + +MODEL_CARD = """ +--- +tags: +- autotrain +- text-regression +widget: +- text: "I love AutoTrain" +datasets: +- {dataset} +--- + +# Model Trained Using AutoTrain + +- Problem type: Text Regression + +## Validation Metrics +{validation_metrics} +""" + + +def single_column_regression_metrics(pred): + raw_predictions, labels = pred + + # try: + # raw_predictions = [r for preds in raw_predictions for r in preds] + # except TypeError as err: + # if "numpy.float32" not in str(err): + # raise Exception(err) + + def safe_compute(metric_func, default=-999): + try: + return metric_func(labels, raw_predictions) + except Exception: + return default + + pred_dict = { + "mse": safe_compute(lambda labels, predictions: metrics.mean_squared_error(labels, predictions)), + "mae": safe_compute(lambda labels, predictions: metrics.mean_absolute_error(labels, predictions)), + "r2": safe_compute(lambda labels, predictions: metrics.r2_score(labels, predictions)), + "rmse": safe_compute(lambda labels, predictions: np.sqrt(metrics.mean_squared_error(labels, predictions))), + "explained_variance": safe_compute( + lambda labels, predictions: metrics.explained_variance_score(labels, predictions) + ), + } + + for key, value in pred_dict.items(): + pred_dict[key] = float(value) + return pred_dict + + +def create_model_card(config, trainer): + if config.valid_split is not None: + eval_scores = trainer.evaluate() + eval_scores = [ + f"{k[len('eval_'):]}: {v}" for k, v in eval_scores.items() if k in SINGLE_COLUMN_REGRESSION_EVAL_METRICS + ] + eval_scores = "\n\n".join(eval_scores) + + else: + eval_scores = "No validation metrics available" + + model_card = MODEL_CARD.format( + dataset=config.data_path, + validation_metrics=eval_scores, + ) + return model_card
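As a quick illustration of the `compute_metrics` hook defined above, here is a minimal sketch that calls `single_column_regression_metrics` directly on dummy arrays (the values are made up; the tuple mirrors the `(raw_predictions, labels)` pair the `Trainer` passes in):

```python
import numpy as np

from autotrain.trainers.text_regression.utils import single_column_regression_metrics

# dummy (raw_predictions, labels) pair, in the shape received from the Trainer
eval_pred = (np.array([2.5, 4.0, 1.0]), np.array([3.0, 5.0, 1.0]))
print(single_column_regression_metrics(eval_pred))
# -> {'mse': 0.4166..., 'mae': 0.5, 'r2': 0.84375, 'rmse': 0.6454..., 'explained_variance': 0.9375}
```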