Skip to content

Commit

Permalink
orpo improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed May 9, 2024
1 parent af5357d commit 5b3cb77
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 9 deletions.
1 change: 1 addition & 0 deletions configs/llm_finetuning/llama3-70b-orpo-v1.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ data:
column_mapping:
text_column: chosen
rejected_text_column: rejected
prompt_text_column: prompt

params:
trainer: orpo
Expand Down
37 changes: 37 additions & 0 deletions configs/llm_finetuning/llama3-8b-dpo-qlora.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# AutoTrain LLM fine-tuning config: DPO training of Meta-Llama-3-8B-Instruct
# with PEFT adapters on an int4-quantized base model.
task: llm
base_model: meta-llama/Meta-Llama-3-8B-Instruct
project_name: autotrain-llama3-8b-dpo-qlora
log: tensorboard
backend: local

data:
  path: mlabonne/orpo-dpo-mix-40k
  train_split: train
  valid_split: null  # no validation split is used
  chat_template: chatml
  # Maps AutoTrain's column roles to the dataset's column names.
  # The DPO trainer needs all three: prompt, chosen text and rejected text.
  column_mapping:
    text_column: chosen
    rejected_text_column: rejected
    prompt_text_column: prompt

params:
  trainer: dpo
  block_size: 1024
  model_max_length: 2048
  max_prompt_length: 512
  epochs: 3
  batch_size: 2
  lr: 3e-5
  peft: true
  quantization: int4
  target_modules: all-linear
  padding: right
  optimizer: adamw_torch
  scheduler: linear
  gradient_accumulation: 4
  mixed_precision: fp16

hub:
  # NOTE(review): ${...} placeholders are presumably substituted from
  # environment variables when the config is loaded - confirm.
  username: ${HF_USERNAME}
  token: ${HF_TOKEN}
  push_to_hub: true
1 change: 1 addition & 0 deletions configs/llm_finetuning/llama3-8b-orpo-space.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ data:
column_mapping:
text_column: chosen
rejected_text_column: rejected
prompt_text_column: prompt

params:
trainer: orpo
Expand Down
1 change: 1 addition & 0 deletions configs/llm_finetuning/llama3-8b-orpo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ data:
column_mapping:
text_column: chosen
rejected_text_column: rejected
prompt_text_column: prompt

params:
trainer: orpo
Expand Down
4 changes: 2 additions & 2 deletions docs/source/col_map.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ should use `chat_template` parameter. Read more about it in LLM Parameters Secti
`text`: The column in your dataset that contains the text data.


### Reward / ORPO Trainer
### Reward Trainer

```
{"text": "text", "rejected_text": "rejected_text"}
Expand All @@ -77,7 +77,7 @@ should use `chat_template` parameter. Read more about it in LLM Parameters Secti

`rejected_text`: The column in your dataset that contains the rejected text data.

### DPO Trainer
### DPO / ORPO Trainer

```
{"prompt": "prompt", "text": "text", "rejected_text": "rejected_text"}
Expand Down
2 changes: 1 addition & 1 deletion src/autotrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@


logger = Logger().get_logger()
__version__ = "0.7.84.dev0"
__version__ = "0.7.84"
3 changes: 2 additions & 1 deletion src/autotrain/cli/run_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def register_subcommand(parser: ArgumentParser):
"alias": ["--rejected-text-column"],
},
{
"arg": "--prompt-text-column",
"arg": "--prompt_text_column",
"help": "Identify the column that contains prompt text for tasks requiring contextual inputs, such as conversation or completion generation. Default is 'prompt'. Used only for dpo trainer",
"required": False,
"type": str,
Expand All @@ -45,6 +45,7 @@ def register_subcommand(parser: ArgumentParser):
"help": "Reference model to use for DPO when not using PEFT",
"required": False,
"type": str,
"alias": ["--model-ref"],
},
{
"arg": "--warmup_ratio",
Expand Down
2 changes: 1 addition & 1 deletion src/autotrain/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ <h3 class="mb-5 text-sm font-normal text-gray-800"></h3>
document.getElementById("valid_split").disabled = true;
break;
case 'llm:orpo':
placeholderText = '{"text": "text", "rejected_text": "rejected_text"}';
placeholderText = '{"prompt": "prompt", "text": "text", "rejected_text": "rejected_text"}';
document.getElementById("hub-dataset-radio").disabled = false;
document.getElementById("valid_split").disabled = true;
break;
Expand Down
18 changes: 14 additions & 4 deletions src/autotrain/trainers/clm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,27 @@ def apply_chat_template(
messages, tokenize=False, add_generation_prompt=False
)

elif config.trainer in ("reward", "orpo"):
elif config.trainer == "reward":
if all(k in example.keys() for k in ("chosen", "rejected")):
chosen_messages = example["chosen"]
rejected_messages = example["rejected"]
if isinstance(chosen_messages, str):
chosen_messages = ast.literal_eval(chosen_messages)
if isinstance(rejected_messages, str):
rejected_messages = ast.literal_eval(rejected_messages)

if config.chat_template == "zephyr" and chosen_messages[0]["role"] != "system":
chosen_messages.insert(0, {"role": "system", "content": ""})
if config.chat_template == "zephyr" and rejected_messages[0]["role"] != "system":
rejected_messages.insert(0, {"role": "system", "content": ""})

example["chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
example["rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
else:
raise ValueError(
f"Could not format example as dialogue for `rm/orpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
)
elif config.trainer == "dpo":
elif config.trainer in ("dpo", "orpo"):
if all(k in example.keys() for k in ("chosen", "rejected")):
# For DPO and ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
# We therefore need to extract the N-1 turns to form the prompt
Expand All @@ -247,6 +253,8 @@ def apply_chat_template(
if isinstance(example["rejected"], str):
example["rejected"] = ast.literal_eval(example["rejected"])
prompt_messages = example["chosen"][:-1]
if config.chat_template == "zephyr" and example["chosen"][0]["role"] != "system":
prompt_messages.insert(0, {"role": "system", "content": ""})
chosen_messages = example["chosen"][-1:]
rejected_messages = example["rejected"][-1:]
example["chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
Expand Down Expand Up @@ -322,7 +330,7 @@ def process_input_data(config):
train_data = train_data.rename_column(config.text_column, "chosen")
if not (config.rejected_text_column == "rejected" and config.rejected_text_column in train_data.column_names):
train_data = train_data.rename_column(config.rejected_text_column, "rejected")
if config.trainer == "dpo":
if config.trainer in ("dpo", "orpo"):
if not (config.prompt_text_column == "prompt" and config.prompt_text_column in train_data.column_names):
train_data = train_data.rename_column(config.prompt_text_column, "prompt")

Expand All @@ -343,7 +351,7 @@ def process_input_data(config):
config.rejected_text_column == "rejected" and config.rejected_text_column in valid_data.column_names
):
valid_data = valid_data.rename_column(config.rejected_text_column, "rejected")
if config.trainer == "dpo":
if config.trainer in ("dpo", "reward"):
if not (config.prompt_text_column == "prompt" and config.prompt_text_column in valid_data.column_names):
valid_data = valid_data.rename_column(config.prompt_text_column, "prompt")
else:
Expand Down Expand Up @@ -401,6 +409,8 @@ def get_tokenizer(config):
def process_data_with_chat_template(config, tokenizer, train_data, valid_data):
valid_data = None
if config.chat_template in ("chatml", "zephyr", "tokenizer"):
logger.info("Applying chat template")
logger.info("For ORPO/DPO, `prompt` will be extracted from chosen messages")
train_data = train_data.map(
apply_chat_template,
fn_kwargs={
Expand Down

0 comments on commit 5b3cb77

Please sign in to comment.