From b5348e31a8ddef7349117df18aba810d23ae566a Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Fri, 19 Apr 2024 15:26:05 +0200 Subject: [PATCH] add orpo trainer --- docs/source/index.mdx | 32 +++++++++++--------------- src/autotrain/__init__.py | 2 +- src/autotrain/app.py | 14 +++++++++++ src/autotrain/templates/index.html | 6 +++++ src/autotrain/trainers/clm/__main__.py | 24 ++++++++++++++----- src/autotrain/trainers/clm/params.py | 4 ++++ src/autotrain/trainers/clm/utils.py | 4 ++-- 7 files changed, 58 insertions(+), 28 deletions(-) diff --git a/docs/source/index.mdx b/docs/source/index.mdx index d03e10f6b6..ca89078eea 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -1,24 +1,14 @@ -![autotrain-homepage](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/ui.png) +![autotrain-homepage](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/autotrain_homepage.png) -🤗 AutoTrain is a no-code tool for training state-of-the-art models for Natural Language Processing (NLP) tasks, for Computer Vision (CV) tasks, and for Speech tasks and even for Tabular tasks. It is built on top of the awesome tools developed by the Hugging Face team, and it is designed to be easy to use. +🤗 AutoTrain Advanced (or simply AutoTrain) is a no-code tool for training state-of-the-art models for Natural Language Processing (NLP) tasks, for Computer Vision (CV) tasks, and for Speech tasks and even for Tabular tasks. It is built on top of the awesome tools developed by the Hugging Face team, and it is designed to be easy to use. ## Who should use AutoTrain? -AutoTrain is for anyone who wants to train a state-of-the-art model for a NLP, CV, Speech or even Tabular task, but doesn't want to spend time on the technical details of training a model. AutoTrain is also for anyone who wants to train a model for a custom dataset, but doesn't want to spend time on the technical details of training a model. Our goal is to make it easy for anyone to train a state-of-the-art model for any task and our focus is not just data scientists or machine learning engineers, but also non-technical users. - - -## What is AutoTrain Advanced? - -AutoTrain Advanced processes your data either in a Hugging Face Space or locally (if installed locally using pip). This saves one time since the data processing is not done by the AutoTrain backend, resulting in your job not being queued. AutoTrain Advanced also allows you to use your own hardware (better CPU and RAM) to process the data, thus, making the data processing faster. - -Using AutoTrain Advanced, advanced users can also control the hyperparameters used for training per job. This allows you to train multiple models with different hyperparameters and compare the results. - -Everything else is the same as AutoTrain. You can use AutoTrain Advanced to train models for NLP, CV, Speech and Tabular tasks. - -We recommend using [AutoTrain Advanced](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced?duplicate=true) -since it is faster, more flexible and will have more supported tasks and features in the future. +AutoTrain is for anyone who wants to train a state-of-the-art model for a NLP, CV, Speech or even Tabular task, but doesn't want to spend time on the technical details of training a model. +AutoTrain is also for anyone who wants to train a model for a custom dataset, but doesn't want to spend time on the technical details of training a model. +Our goal is to make it easy for anyone to train a state-of-the-art model for any task and our focus is not just data scientists or machine learning engineers, but also non-technical users. ## How to use AutoTrain? @@ -26,18 +16,22 @@ since it is faster, more flexible and will have more supported tasks and feature We offer several ways to use AutoTrain: - No code users can use `AutoTrain Advanced` by creating a new space with AutoTrain Docker image: -[Click here](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced?duplicate=true) to create AutoTrain Space. +[Click here](https://huggingface.co/login?next=/spaces/autotrain-projects/autotrain-advanced?duplicate=true) to create AutoTrain Space. Please make sure you keep the space private and attach appropriate hardware to the space. -- Developers can access and build on top of AutoTrain using python api or run AutoTrain Advanced UI locally. The python api is available in the `autotrain-advanced` package. You can install it using pip: +- Developers can access and build on top of AutoTrain using python api or run AutoTrain Advanced UI locally. +The python api is available in the `autotrain-advanced` package. + + +You can install it using pip: ```bash $ pip install autotrain-advanced ``` -# Running the app locally +# Running AutoTrain Locally -To run the app locally, you can use the following command: +To run the autotrain app locally, you can use the following command: ```bash $ export HF_TOKEN=your_hugging_face_write_token diff --git a/src/autotrain/__init__.py b/src/autotrain/__init__.py index 49d8550e54..a16f8067d4 100644 --- a/src/autotrain/__init__.py +++ b/src/autotrain/__init__.py @@ -41,4 +41,4 @@ logger = Logger().get_logger() -__version__ = "0.7.54.dev0" +__version__ = "0.7.55.dev0" diff --git a/src/autotrain/app.py b/src/autotrain/app.py index 17d4cafcfa..cd8f45e5f5 100644 --- a/src/autotrain/app.py +++ b/src/autotrain/app.py @@ -97,6 +97,7 @@ padding="right", chat_template="none", save_strategy="no", + max_completion_length=128, ).model_dump() PARAMS["text-classification"] = TextClassificationParams( @@ -217,6 +218,14 @@ async def fetch_params(task: str, param_type: str): if task == "llm": more_hidden_params = [] if trainer in ("sft", "reward"): + more_hidden_params = [ + "model_ref", + "dpo_beta", + "add_eos_token", + "max_prompt_length", + "max_completion_length", + ] + elif trainer == "orpo": more_hidden_params = [ "model_ref", "dpo_beta", @@ -226,10 +235,14 @@ async def fetch_params(task: str, param_type: str): more_hidden_params = [ "model_ref", "dpo_beta", + "max_prompt_length", + "max_completion_length", ] elif trainer == "dpo": more_hidden_params = [ "add_eos_token", + "max_prompt_length", + "max_completion_length", ] if param_type == "basic": more_hidden_params.extend( @@ -251,6 +264,7 @@ async def fetch_params(task: str, param_type: str): "lora_r", "lora_alpha", "lora_dropout", + "max_completion_length", ] ) task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params} diff --git a/src/autotrain/templates/index.html b/src/autotrain/templates/index.html index 3a5e608453..e2c0290bd1 100644 --- a/src/autotrain/templates/index.html +++ b/src/autotrain/templates/index.html @@ -296,6 +296,7 @@ class="mt-1 block w-full border border-gray-300 px-3 py-2 bg-white rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500"> + @@ -730,6 +731,11 @@

document.getElementById("hub-dataset-radio").disabled = false; document.getElementById("valid_split").disabled = true; break; + case 'llm:orpo': + placeholderText = '{"text": "text", "rejected_text": "rejected_text"}'; + document.getElementById("hub-dataset-radio").disabled = false; + document.getElementById("valid_split").disabled = true; + break; case 'text-classification': placeholderText = '{"text": "text", "label": "target"}'; document.getElementById("hub-dataset-radio").disabled = false; diff --git a/src/autotrain/trainers/clm/__main__.py b/src/autotrain/trainers/clm/__main__.py index b72f75ffad..a3146aaf17 100644 --- a/src/autotrain/trainers/clm/__main__.py +++ b/src/autotrain/trainers/clm/__main__.py @@ -21,7 +21,7 @@ default_data_collator, ) from transformers.trainer_callback import PrinterCallback -from trl import DPOTrainer, RewardConfig, RewardTrainer, SFTTrainer +from trl import DPOTrainer, ORPOConfig, ORPOTrainer, RewardConfig, RewardTrainer, SFTTrainer from autotrain import logger from autotrain.trainers.clm import utils @@ -82,7 +82,7 @@ def process_input_data(config): token=config.token, ) # rename columns for reward trainer - if config.trainer in ("dpo", "reward"): + if config.trainer in ("dpo", "reward", "orpo"): if not (config.text_column == "chosen" and config.text_column in train_data.column_names): train_data = train_data.rename_column(config.text_column, "chosen") if not (config.rejected_text_column == "rejected" and config.rejected_text_column in train_data.column_names): @@ -101,7 +101,7 @@ def process_input_data(config): token=config.token, ) - if config.trainer in ("dpo", "reward"): + if config.trainer in ("dpo", "reward", "orpo"): if not (config.text_column == "chosen" and config.text_column in valid_data.column_names): valid_data = valid_data.rename_column(config.text_column, "chosen") if not ( @@ -242,7 +242,6 @@ def train(config): ddp_find_unused_parameters=False, gradient_checkpointing=not config.disable_gradient_checkpointing, remove_unused_columns=False, - disable_tqdm=True, ) if not config.disable_gradient_checkpointing: @@ -259,6 +258,11 @@ def train(config): if config.trainer == "reward": training_args["max_length"] = config.block_size args = RewardConfig(**training_args) + elif config.trainer == "orpo": + training_args["max_length"] = config.block_size + training_args["max_prompt_length"] = config.max_prompt_length + training_args["max_completion_length"] = config.max_completion_length + args = ORPOConfig(**training_args) else: args = TrainingArguments(**training_args) @@ -330,7 +334,7 @@ def train(config): use_flash_attention_2=config.use_flash_attention_2, torch_dtype=torch_dtype, ) - if config.model_ref is not None: + if config.model_ref is not None and config.trainer == "dpo": model_ref = AutoModelForCausalLM.from_pretrained( config.model_ref, config=model_config, @@ -343,7 +347,7 @@ def train(config): model_ref = None logger.info(f"model dtype: {model.dtype}") - if config.model_ref is not None: + if config.model_ref is not None and config.trainer == "dpo": logger.info(f"model_ref dtype: {model_ref.dtype}") model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8) @@ -515,6 +519,14 @@ def train(config): peft_config=peft_config if config.peft else None, tokenizer=tokenizer, ) + elif config.trainer == "orpo": + trainer = ORPOTrainer( + **trainer_args, + train_dataset=train_data, + eval_dataset=valid_data if config.valid_split is not None else None, + tokenizer=tokenizer, + peft_config=peft_config if config.peft else None, + ) elif config.trainer == "dpo": if isinstance(config.block_size, int): max_length = config.block_size diff --git a/src/autotrain/trainers/clm/params.py b/src/autotrain/trainers/clm/params.py index c36e1eb264..fda2b722b6 100644 --- a/src/autotrain/trainers/clm/params.py +++ b/src/autotrain/trainers/clm/params.py @@ -54,6 +54,10 @@ class LLMTrainingParams(AutoTrainParams): model_ref: Optional[str] = Field(None, title="Reference, for DPO trainer") dpo_beta: float = Field(0.1, title="Beta for DPO trainer") + # orpo + max_prompt_length: int = Field(128, title="Prompt length") + max_completion_length: Optional[int] = Field(None, title="Completion length") + # column mappings prompt_text_column: Optional[str] = Field(None, title="Prompt text column") text_column: str = Field("text", title="Text column") diff --git a/src/autotrain/trainers/clm/utils.py b/src/autotrain/trainers/clm/utils.py index b452d34aef..8109b17ed1 100644 --- a/src/autotrain/trainers/clm/utils.py +++ b/src/autotrain/trainers/clm/utils.py @@ -193,7 +193,7 @@ def apply_chat_template( messages, tokenize=False, add_generation_prompt=False ) - elif config.trainer == "reward": + elif config.trainer in ("reward", "orpo"): if all(k in example.keys() for k in ("chosen", "rejected")): chosen_messages = example["chosen"] rejected_messages = example["rejected"] @@ -205,7 +205,7 @@ def apply_chat_template( example["rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) else: raise ValueError( - f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" + f"Could not format example as dialogue for `rm/orpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" ) elif config.trainer == "dpo": if all(k in example.keys() for k in ("chosen", "rejected")):