Commit b5348e3: add orpo trainer

abhishekkrthakur committed Apr 19, 2024 (parent: 213adba)

Showing 7 changed files with 58 additions and 28 deletions.
docs/source/index.mdx (32 changes: 13 additions & 19 deletions)

````diff
@@ -1,43 +1,37 @@
-![autotrain-homepage](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/ui.png)
+![autotrain-homepage](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/autotrain_homepage.png)
 
 
 
-🤗 AutoTrain is a no-code tool for training state-of-the-art models for Natural Language Processing (NLP) tasks, for Computer Vision (CV) tasks, and for Speech tasks and even for Tabular tasks. It is built on top of the awesome tools developed by the Hugging Face team, and it is designed to be easy to use.
+🤗 AutoTrain Advanced (or simply AutoTrain) is a no-code tool for training state-of-the-art models for Natural Language Processing (NLP) tasks, for Computer Vision (CV) tasks, and for Speech tasks and even for Tabular tasks. It is built on top of the awesome tools developed by the Hugging Face team, and it is designed to be easy to use.
 
 ## Who should use AutoTrain?
 
-AutoTrain is for anyone who wants to train a state-of-the-art model for a NLP, CV, Speech or even Tabular task, but doesn't want to spend time on the technical details of training a model. AutoTrain is also for anyone who wants to train a model for a custom dataset, but doesn't want to spend time on the technical details of training a model. Our goal is to make it easy for anyone to train a state-of-the-art model for any task and our focus is not just data scientists or machine learning engineers, but also non-technical users.
-
-
-## What is AutoTrain Advanced?
-
-AutoTrain Advanced processes your data either in a Hugging Face Space or locally (if installed locally using pip). This saves one time since the data processing is not done by the AutoTrain backend, resulting in your job not being queued. AutoTrain Advanced also allows you to use your own hardware (better CPU and RAM) to process the data, thus, making the data processing faster.
-
-Using AutoTrain Advanced, advanced users can also control the hyperparameters used for training per job. This allows you to train multiple models with different hyperparameters and compare the results.
-
-Everything else is the same as AutoTrain. You can use AutoTrain Advanced to train models for NLP, CV, Speech and Tabular tasks.
-
-We recommend using [AutoTrain Advanced](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced?duplicate=true)
-since it is faster, more flexible and will have more supported tasks and features in the future.
+AutoTrain is for anyone who wants to train a state-of-the-art model for a NLP, CV, Speech or even Tabular task, but doesn't want to spend time on the technical details of training a model.
+AutoTrain is also for anyone who wants to train a model for a custom dataset, but doesn't want to spend time on the technical details of training a model.
+Our goal is to make it easy for anyone to train a state-of-the-art model for any task and our focus is not just data scientists or machine learning engineers, but also non-technical users.
 
 
 ## How to use AutoTrain?
 
 We offer several ways to use AutoTrain:
 
 - No code users can use `AutoTrain Advanced` by creating a new space with AutoTrain Docker image:
-[Click here](https://huggingface.co/spaces/autotrain-projects/autotrain-advanced?duplicate=true) to create AutoTrain Space.
+[Click here](https://huggingface.co/login?next=/spaces/autotrain-projects/autotrain-advanced?duplicate=true) to create AutoTrain Space.
 Please make sure you keep the space private and attach appropriate hardware to the space.
 
-- Developers can access and build on top of AutoTrain using python api or run AutoTrain Advanced UI locally. The python api is available in the `autotrain-advanced` package. You can install it using pip:
+- Developers can access and build on top of AutoTrain using python api or run AutoTrain Advanced UI locally.
+The python api is available in the `autotrain-advanced` package.
+
+You can install it using pip:
 
 ```bash
 $ pip install autotrain-advanced
 ```
 
-# Running the app locally
+# Running AutoTrain Locally
 
-To run the app locally, you can use the following command:
+To run the autotrain app locally, you can use the following command:
 
 ```bash
 $ export HF_TOKEN=your_hugging_face_write_token
````
src/autotrain/__init__.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -41,4 +41,4 @@
 
 
 logger = Logger().get_logger()
-__version__ = "0.7.54.dev0"
+__version__ = "0.7.55.dev0"
```
src/autotrain/app.py (14 changes: 14 additions & 0 deletions)

```diff
@@ -97,6 +97,7 @@
     padding="right",
     chat_template="none",
     save_strategy="no",
+    max_completion_length=128,
 ).model_dump()
 
 PARAMS["text-classification"] = TextClassificationParams(
@@ -217,6 +218,14 @@ async def fetch_params(task: str, param_type: str):
     if task == "llm":
         more_hidden_params = []
         if trainer in ("sft", "reward"):
+            more_hidden_params = [
+                "model_ref",
+                "dpo_beta",
+                "add_eos_token",
+                "max_prompt_length",
+                "max_completion_length",
+            ]
+        elif trainer == "orpo":
             more_hidden_params = [
                 "model_ref",
                 "dpo_beta",
@@ -226,10 +235,14 @@ async def fetch_params(task: str, param_type: str):
             more_hidden_params = [
                 "model_ref",
                 "dpo_beta",
+                "max_prompt_length",
+                "max_completion_length",
             ]
         elif trainer == "dpo":
             more_hidden_params = [
                 "add_eos_token",
+                "max_prompt_length",
+                "max_completion_length",
             ]
         if param_type == "basic":
             more_hidden_params.extend(
@@ -251,6 +264,7 @@ async def fetch_params(task: str, param_type: str):
                     "lora_r",
                     "lora_alpha",
                     "lora_dropout",
+                    "max_completion_length",
                 ]
             )
         task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}
```
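These `more_hidden_params` lists feed the dict-comprehension filter shown on the last line of the diff above. A minimal, self-contained sketch of that filtering with invented values (the real `task_params` comes from the params dataclasses); note that for the new `orpo` trainer, `model_ref` and `dpo_beta` are hidden while the two new length fields stay visible in the UI:

```python
# Invented defaults standing in for the real task_params dict.
task_params = {
    "lr": 3e-5,
    "model_ref": None,
    "dpo_beta": 0.1,
    "max_prompt_length": 128,
    "max_completion_length": 128,
}
# For trainer == "orpo", model_ref and dpo_beta are hidden (per the hunk above).
more_hidden_params = ["model_ref", "dpo_beta"]

# Same comprehension as in fetch_params above.
task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}
print(task_params)  # lr, max_prompt_length and max_completion_length remain
```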
src/autotrain/templates/index.html (6 changes: 6 additions & 0 deletions)

```diff
@@ -296,6 +296,7 @@
             class="mt-1 block w-full border border-gray-300 px-3 py-2 bg-white rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500">
             <optgroup label="LLM Finetuning">
                 <option value="llm:sft">LLM SFT</option>
+                <option value="llm:orpo">LLM ORPO</option>
                 <option value="llm:generic">LLM Generic</option>
                 <option value="llm:dpo">LLM DPO</option>
                 <option value="llm:reward">LLM Reward</option>
@@ -730,6 +731,11 @@ <h3 class="mb-5 text-sm font-normal text-gray-800"></h3>
                 document.getElementById("hub-dataset-radio").disabled = false;
                 document.getElementById("valid_split").disabled = true;
                 break;
+            case 'llm:orpo':
+                placeholderText = '{"text": "text", "rejected_text": "rejected_text"}';
+                document.getElementById("hub-dataset-radio").disabled = false;
+                document.getElementById("valid_split").disabled = true;
+                break;
             case 'text-classification':
                 placeholderText = '{"text": "text", "label": "target"}';
                 document.getElementById("hub-dataset-radio").disabled = false;
```
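The new `llm:orpo` case sets the column-mapping placeholder so that `text` holds the preferred completion and `rejected_text` the dispreferred one. A hypothetical pair of rows under that mapping (content invented for illustration; the trainer later renames these columns to `chosen`/`rejected`, see `clm/__main__.py` below):

```python
# Hypothetical ORPO training rows matching the placeholder mapping above.
orpo_rows = [
    {"text": "Q: What is 2+2?\nA: 4", "rejected_text": "Q: What is 2+2?\nA: 5"},
    {"text": "Q: Capital of France?\nA: Paris", "rejected_text": "Q: Capital of France?\nA: Berlin"},
]
```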
src/autotrain/trainers/clm/__main__.py (24 changes: 18 additions & 6 deletions)

```diff
@@ -21,7 +21,7 @@
     default_data_collator,
 )
 from transformers.trainer_callback import PrinterCallback
-from trl import DPOTrainer, RewardConfig, RewardTrainer, SFTTrainer
+from trl import DPOTrainer, ORPOConfig, ORPOTrainer, RewardConfig, RewardTrainer, SFTTrainer
 
 from autotrain import logger
 from autotrain.trainers.clm import utils
@@ -82,7 +82,7 @@ def process_input_data(config):
         token=config.token,
     )
     # rename columns for reward trainer
-    if config.trainer in ("dpo", "reward"):
+    if config.trainer in ("dpo", "reward", "orpo"):
         if not (config.text_column == "chosen" and config.text_column in train_data.column_names):
             train_data = train_data.rename_column(config.text_column, "chosen")
         if not (config.rejected_text_column == "rejected" and config.rejected_text_column in train_data.column_names):
@@ -101,7 +101,7 @@ def process_input_data(config):
             token=config.token,
         )
 
-        if config.trainer in ("dpo", "reward"):
+        if config.trainer in ("dpo", "reward", "orpo"):
             if not (config.text_column == "chosen" and config.text_column in valid_data.column_names):
                 valid_data = valid_data.rename_column(config.text_column, "chosen")
             if not (
```
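A runnable sketch of the normalization these two hunks extend to `orpo`: whatever columns the user mapped are renamed to the `chosen`/`rejected` names the preference trainers expect (column names illustrative, config plumbing elided):

```python
from datasets import Dataset

# Stand-ins for config.text_column / config.rejected_text_column.
text_column, rejected_text_column = "text", "rejected_text"

train_data = Dataset.from_dict({"text": ["good answer"], "rejected_text": ["bad answer"]})

# Same renaming logic as process_input_data above.
if not (text_column == "chosen" and text_column in train_data.column_names):
    train_data = train_data.rename_column(text_column, "chosen")
if not (rejected_text_column == "rejected" and rejected_text_column in train_data.column_names):
    train_data = train_data.rename_column(rejected_text_column, "rejected")

print(train_data.column_names)  # ['chosen', 'rejected']
```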
```diff
@@ -242,7 +242,6 @@ def train(config):
         ddp_find_unused_parameters=False,
         gradient_checkpointing=not config.disable_gradient_checkpointing,
         remove_unused_columns=False,
-        disable_tqdm=True,
     )
 
     if not config.disable_gradient_checkpointing:
@@ -259,6 +258,11 @@
     if config.trainer == "reward":
         training_args["max_length"] = config.block_size
         args = RewardConfig(**training_args)
+    elif config.trainer == "orpo":
+        training_args["max_length"] = config.block_size
+        training_args["max_prompt_length"] = config.max_prompt_length
+        training_args["max_completion_length"] = config.max_completion_length
+        args = ORPOConfig(**training_args)
     else:
         args = TrainingArguments(**training_args)
```
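`ORPOConfig` subclasses `TrainingArguments`, so the hunk above can pass the shared `training_args` dict through unchanged and only add the sequence-budget knobs. An isolated sketch with invented values (in the real code, `training_args` is assembled earlier in `train()`):

```python
from trl import ORPOConfig

training_args = {"output_dir": "out", "per_device_train_batch_size": 2}  # illustrative subset
block_size = 1024  # config.block_size

args = ORPOConfig(
    **training_args,
    max_length=block_size,       # total token budget per (prompt + completion) pair
    max_prompt_length=128,       # config.max_prompt_length (default added in params.py)
    max_completion_length=None,  # config.max_completion_length (defaults to None in params.py)
)
```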

```diff
@@ -330,7 +334,7 @@ def train(config):
         use_flash_attention_2=config.use_flash_attention_2,
         torch_dtype=torch_dtype,
     )
-    if config.model_ref is not None:
+    if config.model_ref is not None and config.trainer == "dpo":
         model_ref = AutoModelForCausalLM.from_pretrained(
             config.model_ref,
             config=model_config,
@@ -343,7 +347,7 @@ def train(config):
         model_ref = None
 
     logger.info(f"model dtype: {model.dtype}")
-    if config.model_ref is not None:
+    if config.model_ref is not None and config.trainer == "dpo":
         logger.info(f"model_ref dtype: {model_ref.dtype}")
 
     model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
```
```diff
@@ -515,6 +519,14 @@ def train(config):
             peft_config=peft_config if config.peft else None,
             tokenizer=tokenizer,
         )
+    elif config.trainer == "orpo":
+        trainer = ORPOTrainer(
+            **trainer_args,
+            train_dataset=train_data,
+            eval_dataset=valid_data if config.valid_split is not None else None,
+            tokenizer=tokenizer,
+            peft_config=peft_config if config.peft else None,
+        )
     elif config.trainer == "dpo":
         if isinstance(config.block_size, int):
             max_length = config.block_size
```
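Note that the two `model_ref` hunks above now load a reference model only for DPO: ORPO adds its odds-ratio preference penalty directly to the policy model's own NLL loss, so no frozen reference copy is needed. For orientation, a minimal end-to-end `ORPOTrainer` sketch outside AutoTrain (model and data invented; trl ~0.8 signature, where `tokenizer=` is accepted and later versions rename it `processing_class=`):

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig, ORPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships no pad token

# ORPOTrainer consumes prompt/chosen/rejected columns.
train_data = Dataset.from_list(
    [{"prompt": "2+2=", "chosen": " 4", "rejected": " 5"}] * 16
)

args = ORPOConfig(
    output_dir="orpo-sketch",
    max_length=64,
    max_prompt_length=32,
    per_device_train_batch_size=2,
    num_train_epochs=1,
    report_to="none",
)

trainer = ORPOTrainer(model=model, args=args, train_dataset=train_data, tokenizer=tokenizer)
trainer.train()
```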
src/autotrain/trainers/clm/params.py (4 changes: 4 additions & 0 deletions)

```diff
@@ -54,6 +54,10 @@ class LLMTrainingParams(AutoTrainParams):
     model_ref: Optional[str] = Field(None, title="Reference, for DPO trainer")
     dpo_beta: float = Field(0.1, title="Beta for DPO trainer")
 
+    # orpo
+    max_prompt_length: int = Field(128, title="Prompt length")
+    max_completion_length: Optional[int] = Field(None, title="Completion length")
+
     # column mappings
     prompt_text_column: Optional[str] = Field(None, title="Prompt text column")
     text_column: str = Field("text", title="Text column")
```
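A hedged construction sketch for the two new fields, assuming the remaining `LLMTrainingParams` fields keep their defaults (the model name is illustrative):

```python
from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="meta-llama/Llama-2-7b-hf",  # illustrative
    trainer="orpo",
    block_size=1024,
    max_prompt_length=128,       # new in this commit: prompt-side budget
    max_completion_length=None,  # new in this commit: optional completion budget
)
```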
src/autotrain/trainers/clm/utils.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -193,7 +193,7 @@ def apply_chat_template(
             messages, tokenize=False, add_generation_prompt=False
         )
 
-    elif config.trainer == "reward":
+    elif config.trainer in ("reward", "orpo"):
         if all(k in example.keys() for k in ("chosen", "rejected")):
             chosen_messages = example["chosen"]
             rejected_messages = example["rejected"]
@@ -205,7 +205,7 @@ def apply_chat_template(
             example["rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
         else:
             raise ValueError(
-                f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
+                f"Could not format example as dialogue for `rm/orpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
             )
     elif config.trainer == "dpo":
         if all(k in example.keys() for k in ("chosen", "rejected")):
```
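A sketch of what the extended `reward`/`orpo` branch does to a chat-formatted example: both message lists are flattened to plain strings with the tokenizer's chat template (model choice illustrative; any tokenizer that ships a chat template works):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

example = {
    "chosen": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "4"},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "5"},
    ],
}

# Mirrors the branch above: serialize both sides without tokenizing.
example["chosen"] = tokenizer.apply_chat_template(example["chosen"], tokenize=False)
example["rejected"] = tokenizer.apply_chat_template(example["rejected"], tokenize=False)
print(example["chosen"])
```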
