From 86528ad60a5f94a4f4f4fb584324411a7df5bb05 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Fri, 20 Oct 2023 14:40:27 +0200 Subject: [PATCH] seq2seq --- src/autotrain/cli/autotrain.py | 2 + src/autotrain/cli/run_seq2seq.py | 440 +++++++++++++++++++ src/autotrain/cli/run_text_classification.py | 11 +- src/autotrain/trainers/seq2seq/__init__.py | 0 src/autotrain/trainers/seq2seq/__main__.py | 248 +++++++++++ src/autotrain/trainers/seq2seq/dataset.py | 21 + src/autotrain/trainers/seq2seq/params.py | 45 ++ src/autotrain/trainers/seq2seq/utils.py | 60 +++ 8 files changed, 825 insertions(+), 2 deletions(-) create mode 100644 src/autotrain/cli/run_seq2seq.py create mode 100644 src/autotrain/trainers/seq2seq/__init__.py create mode 100644 src/autotrain/trainers/seq2seq/__main__.py create mode 100644 src/autotrain/trainers/seq2seq/dataset.py create mode 100644 src/autotrain/trainers/seq2seq/params.py create mode 100644 src/autotrain/trainers/seq2seq/utils.py diff --git a/src/autotrain/cli/autotrain.py b/src/autotrain/cli/autotrain.py index 697ec64a1e..3514dfb9e0 100644 --- a/src/autotrain/cli/autotrain.py +++ b/src/autotrain/cli/autotrain.py @@ -6,6 +6,7 @@ from .run_dreambooth import RunAutoTrainDreamboothCommand from .run_image_classification import RunAutoTrainImageClassificationCommand from .run_llm import RunAutoTrainLLMCommand +from .run_seq2seq import RunAutoTrainSeq2SeqCommand from .run_setup import RunSetupCommand from .run_spacerunner import RunAutoTrainSpaceRunnerCommand from .run_tabular import RunAutoTrainTabularCommand @@ -31,6 +32,7 @@ def main(): RunAutoTrainImageClassificationCommand.register_subcommand(commands_parser) RunAutoTrainTabularCommand.register_subcommand(commands_parser) RunAutoTrainSpaceRunnerCommand.register_subcommand(commands_parser) + RunAutoTrainSeq2SeqCommand.register_subcommand(commands_parser) args = parser.parse_args() diff --git a/src/autotrain/cli/run_seq2seq.py b/src/autotrain/cli/run_seq2seq.py new file mode 100644 index 0000000000..de7895ae26 --- /dev/null +++ b/src/autotrain/cli/run_seq2seq.py @@ -0,0 +1,440 @@ +import os +import subprocess +import sys +from argparse import ArgumentParser + +import torch + +from autotrain import logger +from autotrain.backend import EndpointsRunner, SpaceRunner + +from . import BaseAutoTrainCommand + + +def run_seq2seq_command_factory(args): + return RunAutoTrainSeq2SeqCommand(args) + + +class RunAutoTrainSeq2SeqCommand(BaseAutoTrainCommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + arg_list = [ + { + "arg": "--train", + "help": "Train the model", + "required": False, + "action": "store_true", + }, + { + "arg": "--deploy", + "help": "Deploy the model", + "required": False, + "action": "store_true", + }, + { + "arg": "--inference", + "help": "Run inference", + "required": False, + "action": "store_true", + }, + { + "arg": "--data-path", + "help": "Train dataset to use", + "required": False, + "type": str, + }, + { + "arg": "--train-split", + "help": "Test dataset split to use", + "required": False, + "type": str, + "default": "train", + }, + { + "arg": "--valid-split", + "help": "Validation dataset split to use", + "required": False, + "type": str, + "default": None, + }, + { + "arg": "--text-column", + "help": "Text column to use", + "required": False, + "type": str, + "default": "text", + }, + { + "arg": "--target-column", + "help": "Target column to use", + "required": False, + "type": str, + "default": "target", + }, + { + "arg": "--model", + "help": "Model to use", + "required": False, + "type": str, + }, + { + "arg": "--lr", + "help": "Learning rate to use", + "required": False, + "type": float, + "default": 3e-5, + }, + { + "arg": "--epochs", + "help": "Number of training epochs to use", + "required": False, + "type": int, + "default": 1, + }, + { + "arg": "--max-seq-length", + "help": "Maximum number of tokens in a sequence to use", + "required": False, + "type": int, + "default": 128, + }, + { + "arg": "--max-target-length", + "help": "Maximum number of tokens in a target sequence to use", + "required": False, + "type": int, + "default": 128, + }, + { + "arg": "--batch-size", + "help": "Training batch size to use", + "required": False, + "type": int, + "default": 2, + }, + { + "arg": "--warmup-ratio", + "help": "Warmup proportion to use", + "required": False, + "type": float, + "default": 0.1, + }, + { + "arg": "--gradient-accumulation", + "help": "Gradient accumulation steps to use", + "required": False, + "type": int, + "default": 1, + }, + { + "arg": "--optimizer", + "help": "Optimizer to use", + "required": False, + "type": str, + "default": "adamw_torch", + }, + { + "arg": "--scheduler", + "help": "Scheduler to use", + "required": False, + "type": str, + "default": "linear", + }, + { + "arg": "--weight-decay", + "help": "Weight decay to use", + "required": False, + "type": float, + "default": 0.0, + }, + { + "arg": "--max-grad-norm", + "help": "Max gradient norm to use", + "required": False, + "type": float, + "default": 1.0, + }, + { + "arg": "--seed", + "help": "Seed to use", + "required": False, + "type": int, + "default": 42, + }, + { + "arg": "--logging-steps", + "help": "Logging steps to use", + "required": False, + "type": int, + "default": -1, + }, + { + "arg": "--project-name", + "help": "Output directory", + "required": False, + "type": str, + }, + { + "arg": "--evaluation-strategy", + "help": "Evaluation strategy to use", + "required": False, + "type": str, + "default": "epoch", + }, + { + "arg": "--save-total-limit", + "help": "Save total limit to use", + "required": False, + "type": int, + "default": 1, + }, + { + "arg": "--save-strategy", + "help": "Save strategy to use", + "required": False, + "type": str, + "default": "epoch", + }, + { + "arg": "--auto-find-batch-size", + "help": "Auto find batch size True/False", + "required": False, + "action": "store_true", + }, + { + "arg": "--fp16", + "help": "FP16 True/False", + "required": False, + "action": "store_true", + }, + { + "arg": "--token", + "help": "Hugging face token", + "required": False, + "type": str, + "default": "", + }, + { + "arg": "--push-to-hub", + "help": "Push to hub True/False. In case you want to push the trained model to huggingface hub", + "required": False, + "action": "store_true", + }, + { + "arg": "--repo-id", + "help": "Repo id for hugging face hub", + "required": False, + "type": str, + }, + { + "arg": "--backend", + "help": "Backend to use: default or spaces. Spaces backend requires push_to_hub and repo_id", + "required": False, + "type": str, + "default": "default", + }, + { + "arg": "--username", + "help": "Huggingface username to use", + "required": False, + "type": str, + }, + { + "arg": "--use-peft", + "help": "Use PEFT", + "required": False, + "action": "store_true", + }, + { + "arg": "--use-int8", + "help": "Use INT8", + "required": False, + "action": "store_true", + }, + { + "arg": "--lora-r", + "help": "LoRA-R", + "required": False, + "type": int, + "default": 16, + }, + { + "arg": "--lora-alpha", + "help": "LoRA-Alpha", + "required": False, + "type": int, + "default": 32, + }, + { + "arg": "--lora-dropout", + "help": "LoRA-Dropout", + "required": False, + "type": float, + "default": 0.05, + }, + { + "arg": "--target-modules", + "help": "Target modules", + "required": False, + "type": str, + "default": "", + }, + ] + run_seq2seq_parser = parser.add_parser("seq2seq", description="✨ Run AutoTrain Seq2Seq") + for arg in arg_list: + if "action" in arg: + run_seq2seq_parser.add_argument( + arg["arg"], + help=arg["help"], + required=arg.get("required", False), + action=arg.get("action"), + default=arg.get("default"), + ) + else: + run_seq2seq_parser.add_argument( + arg["arg"], + help=arg["help"], + required=arg.get("required", False), + type=arg.get("type"), + default=arg.get("default"), + ) + run_seq2seq_parser.set_defaults(func=run_seq2seq_command_factory) + + def __init__(self, args): + self.args = args + + store_true_arg_names = [ + "train", + "deploy", + "inference", + "auto_find_batch_size", + "fp16", + "push_to_hub", + "use_peft", + "use_int8", + ] + for arg_name in store_true_arg_names: + if getattr(self.args, arg_name) is None: + setattr(self.args, arg_name, False) + + if self.args.train: + if self.args.project_name is None: + raise ValueError("Project name must be specified") + if self.args.data_path is None: + raise ValueError("Data path must be specified") + if self.args.model is None: + raise ValueError("Model must be specified") + if self.args.push_to_hub: + if self.args.repo_id is None: + raise ValueError("Repo id must be specified for push to hub") + else: + raise ValueError("Must specify --train, --deploy or --inference") + + if not torch.cuda.is_available(): + self.device = "cpu" + + self.num_gpus = torch.cuda.device_count() + + if len(str(self.args.token)) < 6: + self.args.token = os.environ.get("HF_TOKEN", None) + + if len(self.args.target_modules.strip()) == 0: + self.args.target_modules = [] + else: + self.args.target_modules = self.args.target_modules.split(",") + + def run(self): + from autotrain.trainers.seq2seq.__main__ import train as train_seq2seq + from autotrain.trainers.seq2seq.params import Seq2SeqParams + + logger.info("Running Seq2Seq Classification") + if self.args.train: + params = Seq2SeqParams( + data_path=self.args.data_path, + train_split=self.args.train_split, + valid_split=self.args.valid_split, + text_column=self.args.text_column, + target_column=self.args.target_column, + model=self.args.model, + lr=self.args.lr, + epochs=self.args.epochs, + max_seq_length=self.args.max_seq_length, + max_target_length=self.args.max_target_length, + batch_size=self.args.batch_size, + warmup_ratio=self.args.warmup_ratio, + gradient_accumulation=self.args.gradient_accumulation, + optimizer=self.args.optimizer, + scheduler=self.args.scheduler, + weight_decay=self.args.weight_decay, + max_grad_norm=self.args.max_grad_norm, + seed=self.args.seed, + logging_steps=self.args.logging_steps, + project_name=self.args.project_name, + evaluation_strategy=self.args.evaluation_strategy, + save_total_limit=self.args.save_total_limit, + save_strategy=self.args.save_strategy, + auto_find_batch_size=self.args.auto_find_batch_size, + fp16=self.args.fp16, + push_to_hub=self.args.push_to_hub, + repo_id=self.args.repo_id, + token=self.args.token, + username=self.args.username, + use_peft=self.args.use_peft, + use_int8=self.args.use_int8, + lora_r=self.args.lora_r, + lora_alpha=self.args.lora_alpha, + lora_dropout=self.args.lora_dropout, + target_modules=self.args.target_modules, + ) + + if self.args.backend.startswith("spaces"): + logger.info("Creating space...") + sr = SpaceRunner( + params=params, + backend=self.args.backend, + ) + space_id = sr.prepare() + logger.info(f"Training Space created. Check progress at https://hf.co/spaces/{space_id}") + sys.exit(0) + + if self.args.backend.startswith("ep-"): + logger.info("Creating training endpoint...") + sr = EndpointsRunner( + params=params, + backend=self.args.backend, + ) + sr.prepare() + logger.info("Training endpoint created.") + sys.exit(0) + + params.save(output_dir=self.args.project_name) + if self.num_gpus == 1: + train_seq2seq(params) + else: + cmd = [ + "accelerate", + "launch", + "--multi_gpu", + "--num_machines", + "1", + "--num_processes", + ] + cmd.append(str(self.num_gpus)) + cmd.append("--mixed_precision") + if self.args.fp16: + cmd.append("fp16") + else: + cmd.append("no") + + cmd.extend( + [ + "-m", + "autotrain.trainers.seq2seq", + "--training_config", + os.path.join(self.args.project_name, "training_params.json"), + ] + ) + + env = os.environ.copy() + process = subprocess.Popen(cmd, env=env) + process.wait() diff --git a/src/autotrain/cli/run_text_classification.py b/src/autotrain/cli/run_text_classification.py index cdb92d6dbb..992692ea49 100644 --- a/src/autotrain/cli/run_text_classification.py +++ b/src/autotrain/cli/run_text_classification.py @@ -303,7 +303,7 @@ def run(self): valid_split=self.args.valid_split, text_column=self.args.text_column, target_column=self.args.target_column, - model_name=self.args.model, + model=self.args.model, lr=self.args.lr, epochs=self.args.epochs, max_seq_length=self.args.max_seq_length, @@ -352,7 +352,14 @@ def run(self): if self.num_gpus == 1: train_text_classification(params) else: - cmd = ["accelerate", "launch", "--multi_gpu", "--num_machines", "1", "--num_processes"] + cmd = [ + "accelerate", + "launch", + "--multi_gpu", + "--num_machines", + "1", + "--num_processes", + ] cmd.append(str(self.num_gpus)) cmd.append("--mixed_precision") if self.args.fp16: diff --git a/src/autotrain/trainers/seq2seq/__init__.py b/src/autotrain/trainers/seq2seq/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/autotrain/trainers/seq2seq/__main__.py b/src/autotrain/trainers/seq2seq/__main__.py new file mode 100644 index 0000000000..01a08f6eb3 --- /dev/null +++ b/src/autotrain/trainers/seq2seq/__main__.py @@ -0,0 +1,248 @@ +import argparse +import json +import os +from functools import partial + +import pandas as pd +import torch +from accelerate import Accelerator +from accelerate.state import PartialState +from datasets import Dataset, load_dataset +from huggingface_hub import HfApi +from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_int8_training +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + BitsAndBytesConfig, + DataCollatorForSeq2Seq, + EarlyStoppingCallback, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, +) + +from autotrain import logger +from autotrain.trainers.seq2seq import utils +from autotrain.trainers.seq2seq.dataset import Seq2SeqDataset +from autotrain.trainers.seq2seq.params import Seq2SeqParams +from autotrain.utils import monitor + + +def parse_args(): + # get training_config.json from the end user + parser = argparse.ArgumentParser() + parser.add_argument("--training_config", type=str, required=True) + return parser.parse_args() + + +@monitor +def train(config): + if isinstance(config, dict): + config = Seq2SeqParams(**config) + + if config.repo_id is None and config.username is not None: + config.repo_id = f"{config.username}/{config.project_name}" + + if PartialState().process_index == 0: + logger.info("Starting training...") + logger.info(f"Training config: {config}") + + train_data = None + valid_data = None + # check if config.train_split.csv exists in config.data_path + if config.train_split is not None: + train_path = f"{config.data_path}/{config.train_split}.csv" + if os.path.exists(train_path): + logger.info("loading dataset from csv") + train_data = pd.read_csv(train_path) + train_data = Dataset.from_pandas(train_data) + else: + train_data = load_dataset( + config.data_path, + split=config.train_split, + token=config.token, + ) + + if config.valid_split is not None: + valid_path = f"{config.data_path}/{config.valid_split}.csv" + if os.path.exists(valid_path): + logger.info("loading dataset from csv") + valid_data = pd.read_csv(valid_path) + valid_data = Dataset.from_pandas(valid_data) + else: + valid_data = load_dataset( + config.data_path, + split=config.valid_split, + token=config.token, + ) + + model_config = AutoConfig.from_pretrained(config.model, token=config.token, trust_remote_code=True) + + if config.use_peft: + # if config.use_int4: + # bnb_config = BitsAndBytesConfig( + # load_in_4bit=config.use_int4, + # bnb_4bit_quant_type="nf4", + # bnb_4bit_compute_dtype=torch.float16, + # bnb_4bit_use_double_quant=False, + # ) + # config.fp16 = True + if config.use_int8: + bnb_config = BitsAndBytesConfig(load_in_8bit=config.use_int8) + config.fp16 = True + else: + bnb_config = None + + model = AutoModelForSeq2SeqLM.from_pretrained( + config.model, + config=model_config, + token=config.token, + quantization_config=bnb_config, + torch_dtype=torch.float16, + device_map={"": Accelerator().process_index} if torch.cuda.is_available() else None, + trust_remote_code=True, + ) + else: + model = AutoModelForSeq2SeqLM.from_pretrained( + config.model, + config=model_config, + token=config.token, + trust_remote_code=True, + ) + + tokenizer = AutoTokenizer.from_pretrained(config.model, token=config.token, trust_remote_code=True) + + embedding_size = model.get_input_embeddings().weight.shape[0] + if len(tokenizer) > embedding_size: + model.resize_token_embeddings(len(tokenizer)) + + if config.use_peft: + lora_config = LoraConfig( + r=config.lora_r, + lora_alpha=config.lora_alpha, + target_modules=None if len(config.target_modules) == 0 else config.target_modules, + lora_dropout=config.lora_dropout, + bias="none", + task_type=TaskType.SEQ_2_SEQ_LM, + ) + if config.use_int8: + model = prepare_model_for_int8_training(model) + + model = get_peft_model(model, lora_config) + + train_data = Seq2SeqDataset(data=train_data, tokenizer=tokenizer, config=config) + if config.valid_split is not None: + valid_data = Seq2SeqDataset(data=valid_data, tokenizer=tokenizer, config=config) + + if config.logging_steps == -1: + if config.valid_split is not None: + logging_steps = int(0.2 * len(valid_data) / config.batch_size) + else: + logging_steps = int(0.2 * len(train_data) / config.batch_size) + if logging_steps == 0: + logging_steps = 1 + + else: + logging_steps = config.logging_steps + + training_args = dict( + output_dir=config.project_name, + per_device_train_batch_size=config.batch_size, + per_device_eval_batch_size=2 * config.batch_size, + learning_rate=config.lr, + num_train_epochs=config.epochs, + fp16=config.fp16, + evaluation_strategy=config.evaluation_strategy if config.valid_split is not None else "no", + logging_steps=logging_steps, + save_total_limit=config.save_total_limit, + save_strategy=config.save_strategy, + gradient_accumulation_steps=config.gradient_accumulation, + report_to="tensorboard", + auto_find_batch_size=config.auto_find_batch_size, + lr_scheduler_type=config.scheduler, + optim=config.optimizer, + warmup_ratio=config.warmup_ratio, + weight_decay=config.weight_decay, + max_grad_norm=config.max_grad_norm, + push_to_hub=False, + load_best_model_at_end=True if config.valid_split is not None else False, + ddp_find_unused_parameters=False, + predict_with_generate=True, + seed=config.seed, + ) + + if config.valid_split is not None: + early_stop = EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01) + callbacks_to_use = [early_stop] + else: + callbacks_to_use = [] + + args = Seq2SeqTrainingArguments(**training_args) + _s2s_metrics = partial(utils._seq2seq_metrics, tokenizer=tokenizer) + trainer_args = dict( + args=args, + model=model, + callbacks=callbacks_to_use, + compute_metrics=_s2s_metrics, + ) + data_collator = DataCollatorForSeq2Seq(tokenizer, model=model) + trainer = Seq2SeqTrainer( + **trainer_args, + train_dataset=train_data, + eval_dataset=valid_data, + data_collator=data_collator, + tokenizer=tokenizer, + ) + + trainer.train() + + logger.info("Finished training, saving model...") + trainer.save_model(config.project_name) + tokenizer.save_pretrained(config.project_name) + + model_card = utils.create_model_card(config, trainer) + + # remove token key from training_params.json located in output directory + # first check if file exists + if os.path.exists(f"{config.project_name}/training_params.json"): + training_params = json.load(open(f"{config.project_name}/training_params.json")) + if "token" in training_params: + training_params.pop("token") + json.dump( + training_params, + open(f"{config.project_name}/training_params.json", "w"), + ) + + # save model card to output directory as README.md + with open(f"{config.project_name}/README.md", "w") as f: + f.write(model_card) + + if config.push_to_hub: + if PartialState().process_index == 0: + logger.info("Pushing model to hub...") + api = HfApi(token=config.token) + api.create_repo(repo_id=config.repo_id, repo_type="model", private=True) + api.upload_folder( + folder_path=config.project_name, + repo_id=config.repo_id, + repo_type="model", + ) + + if PartialState().process_index == 0: + if "SPACE_ID" in os.environ: + # shut down the space + logger.info("Pausing space...") + api = HfApi(token=config.token) + api.pause_space(repo_id=os.environ["SPACE_ID"]) + + if "ENDPOINT_ID" in os.environ: + # shut down the endpoint + logger.info("Pausing endpoint...") + utils.pause_endpoint(config) + + +if __name__ == "__main__": + args = parse_args() + training_config = json.load(open(args.training_config)) + config = Seq2SeqParams(**training_config) + train(config) diff --git a/src/autotrain/trainers/seq2seq/dataset.py b/src/autotrain/trainers/seq2seq/dataset.py new file mode 100644 index 0000000000..f0ca839afe --- /dev/null +++ b/src/autotrain/trainers/seq2seq/dataset.py @@ -0,0 +1,21 @@ +class Seq2SeqDataset: + def __init__(self, data, tokenizer, config): + self.data = data + self.tokenizer = tokenizer + self.config = config + self.max_len_input = self.config.max_seq_length + self.max_len_target = self.config.max_target_length + + def __len__(self): + return len(self.data) + + def __getitem__(self, item): + text = str(self.data[item][self.config.text_column]) + target = str(self.data[item][self.config.target_column]) + + model_inputs = self.tokenizer(text, max_length=self.max_len_input, truncation=True) + + labels = self.tokenizer(text_target=target, max_length=self.max_len_target, truncation=True) + + model_inputs["labels"] = labels["input_ids"] + return model_inputs diff --git a/src/autotrain/trainers/seq2seq/params.py b/src/autotrain/trainers/seq2seq/params.py new file mode 100644 index 0000000000..b22b685d4c --- /dev/null +++ b/src/autotrain/trainers/seq2seq/params.py @@ -0,0 +1,45 @@ +from typing import List + +from pydantic import Field + +from autotrain.trainers.common import AutoTrainParams + + +class Seq2SeqParams(AutoTrainParams): + data_path: str = Field(None, title="Data path") + model: str = Field("google/flan-t5-base", title="Model name") + username: str = Field(None, title="Hugging Face Username") + seed: int = Field(42, title="Seed") + train_split: str = Field("train", title="Train split") + valid_split: str = Field(None, title="Validation split") + project_name: str = Field("Project Name", title="Output directory") + token: str = Field(None, title="Hub Token") + push_to_hub: bool = Field(False, title="Push to hub") + text_column: str = Field("text", title="Text column") + target_column: str = Field("target", title="Target text column") + repo_id: str = Field(None, title="Repo ID") + lr: float = Field(5e-5, title="Learning rate") + epochs: int = Field(3, title="Number of training epochs") + max_seq_length: int = Field(128, title="Max sequence length") + max_target_length: int = Field(128, title="Max target sequence length") + batch_size: int = Field(8, title="Training batch size") + warmup_ratio: float = Field(0.1, title="Warmup proportion") + gradient_accumulation: int = Field(1, title="Gradient accumulation steps") + optimizer: str = Field("adamw_torch", title="Optimizer") + scheduler: str = Field("linear", title="Scheduler") + weight_decay: float = Field(0.0, title="Weight decay") + max_grad_norm: float = Field(1.0, title="Max gradient norm") + logging_steps: int = Field(-1, title="Logging steps") + evaluation_strategy: str = Field("epoch", title="Evaluation strategy") + auto_find_batch_size: bool = Field(False, title="Auto find batch size") + fp16: bool = Field(False, title="Enable fp16") + save_total_limit: int = Field(1, title="Save total limit") + save_strategy: str = Field("epoch", title="Save strategy") + token: str = Field(None, title="Hub Token") + push_to_hub: bool = Field(False, title="Push to hub") + use_peft: bool = Field(False, title="Use PEFT") + use_int8: bool = Field(False, title="Use INT8") + lora_r: int = Field(16, title="LoRA-R") + lora_alpha: int = Field(32, title="LoRA-Alpha") + lora_dropout: float = Field(0.05, title="LoRA-Dropout") + target_modules: List[str] = Field([], title="Target modules for PEFT") diff --git a/src/autotrain/trainers/seq2seq/utils.py b/src/autotrain/trainers/seq2seq/utils.py new file mode 100644 index 0000000000..019f2e5ff9 --- /dev/null +++ b/src/autotrain/trainers/seq2seq/utils.py @@ -0,0 +1,60 @@ +import evaluate +import nltk +import numpy as np + + +ROUGE_METRIC = evaluate.load("rouge") + +MODEL_CARD = """ +--- +tags: +- autotrain +- text2text-generation +widget: +- text: "I love AutoTrain" +datasets: +- {dataset} +--- + +# Model Trained Using AutoTrain + +- Problem type: Seq2Seq + +## Validation Metrics +{validation_metrics} +""" + + +def _seq2seq_metrics(pred, tokenizer): + predictions, labels = pred + decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) + + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds] + decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels] + + result = ROUGE_METRIC.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) + result = {key: value.mid.fmeasure * 100 for key, value in result.items()} + + prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions] + result["gen_len"] = np.mean(prediction_lens) + + return {k: round(v, 4) for k, v in result.items()} + + +def create_model_card(config, trainer): + if config.valid_split is not None: + eval_scores = trainer.evaluate() + eval_scores = [f"{k[len('eval_'):]}: {v}" for k, v in eval_scores.items()] + eval_scores = "\n\n".join(eval_scores) + + else: + eval_scores = "No validation metrics available" + + model_card = MODEL_CARD.format( + dataset=config.data_path, + validation_metrics=eval_scores, + ) + return model_card