From 5b3cb77421c741ded88b855204b86540f48c16c6 Mon Sep 17 00:00:00 2001
From: Abhishek Thakur
Date: Thu, 9 May 2024 14:47:10 +0200
Subject: [PATCH] orpo improvements

---
 configs/llm_finetuning/llama3-70b-orpo-v1.yml |  1 +
 .../llm_finetuning/llama3-8b-dpo-qlora.yml    | 37 +++++++++++++++++++
 .../llm_finetuning/llama3-8b-orpo-space.yml   |  1 +
 configs/llm_finetuning/llama3-8b-orpo.yml     |  1 +
 docs/source/col_map.mdx                       |  4 +-
 src/autotrain/__init__.py                     |  2 +-
 src/autotrain/cli/run_llm.py                  |  3 +-
 src/autotrain/templates/index.html            |  2 +-
 src/autotrain/trainers/clm/utils.py           | 18 +++++++--
 9 files changed, 60 insertions(+), 9 deletions(-)
 create mode 100644 configs/llm_finetuning/llama3-8b-dpo-qlora.yml

diff --git a/configs/llm_finetuning/llama3-70b-orpo-v1.yml b/configs/llm_finetuning/llama3-70b-orpo-v1.yml
index ac2e9d5e8e..33861aa455 100644
--- a/configs/llm_finetuning/llama3-70b-orpo-v1.yml
+++ b/configs/llm_finetuning/llama3-70b-orpo-v1.yml
@@ -12,6 +12,7 @@ data:
   column_mapping:
     text_column: chosen
     rejected_text_column: rejected
+    prompt_text_column: prompt
 
 params:
   trainer: orpo
diff --git a/configs/llm_finetuning/llama3-8b-dpo-qlora.yml b/configs/llm_finetuning/llama3-8b-dpo-qlora.yml
new file mode 100644
index 0000000000..15e019cdfd
--- /dev/null
+++ b/configs/llm_finetuning/llama3-8b-dpo-qlora.yml
@@ -0,0 +1,37 @@
+task: llm
+base_model: meta-llama/Meta-Llama-3-8B-Instruct
+project_name: autotrain-llama3-8b-dpo-qlora
+log: tensorboard
+backend: local
+
+data:
+  path: mlabonne/orpo-dpo-mix-40k
+  train_split: train
+  valid_split: null
+  chat_template: chatml
+  column_mapping:
+    text_column: chosen
+    rejected_text_column: rejected
+    prompt_text_column: prompt
+
+params:
+  trainer: dpo
+  block_size: 1024
+  model_max_length: 2048
+  max_prompt_length: 512
+  epochs: 3
+  batch_size: 2
+  lr: 3e-5
+  peft: true
+  quantization: int4
+  target_modules: all-linear
+  padding: right
+  optimizer: adamw_torch
+  scheduler: linear
+  gradient_accumulation: 4
+  mixed_precision: fp16
+
+hub:
+  username: ${HF_USERNAME}
+  token: ${HF_TOKEN}
+  push_to_hub: true
\ No newline at end of file
diff --git a/configs/llm_finetuning/llama3-8b-orpo-space.yml b/configs/llm_finetuning/llama3-8b-orpo-space.yml
index e96ce26632..c732212824 100644
--- a/configs/llm_finetuning/llama3-8b-orpo-space.yml
+++ b/configs/llm_finetuning/llama3-8b-orpo-space.yml
@@ -12,6 +12,7 @@ data:
   column_mapping:
     text_column: chosen
     rejected_text_column: rejected
+    prompt_text_column: prompt
 
 params:
   trainer: orpo
diff --git a/configs/llm_finetuning/llama3-8b-orpo.yml b/configs/llm_finetuning/llama3-8b-orpo.yml
index 6ed153cb25..98ca5b42e3 100644
--- a/configs/llm_finetuning/llama3-8b-orpo.yml
+++ b/configs/llm_finetuning/llama3-8b-orpo.yml
@@ -12,6 +12,7 @@ data:
   column_mapping:
     text_column: chosen
     rejected_text_column: rejected
+    prompt_text_column: prompt
 
 params:
   trainer: orpo
diff --git a/docs/source/col_map.mdx b/docs/source/col_map.mdx
index d681ae2c57..2a26b61b65 100644
--- a/docs/source/col_map.mdx
+++ b/docs/source/col_map.mdx
@@ -67,7 +67,7 @@ should use `chat_template` parameter. Read more about it in LLM Parameters Secti
 
 `text`: The column in your dataset that contains the text data.
 
-### Reward / ORPO Trainer
+### Reward Trainer
 
 ```
 {"text": "text", "rejected_text": "rejected_text"}
@@ -77,7 +77,7 @@ should use `chat_template` parameter. Read more about it in LLM Parameters Secti
 
 `rejected_text`: The column in your dataset that contains the rejected text data.
 
-### DPO Trainer
+### DPO / ORPO Trainer
 
 ```
 {"prompt": "prompt", "text": "text", "rejected_text": "rejected_text"}
diff --git a/src/autotrain/__init__.py b/src/autotrain/__init__.py
index 87a33b419f..e3afdc7a43 100644
--- a/src/autotrain/__init__.py
+++ b/src/autotrain/__init__.py
@@ -41,4 +41,4 @@
 
 logger = Logger().get_logger()
 
-__version__ = "0.7.84.dev0"
+__version__ = "0.7.84"
diff --git a/src/autotrain/cli/run_llm.py b/src/autotrain/cli/run_llm.py
index a28b3d64e0..68892d58b4 100644
--- a/src/autotrain/cli/run_llm.py
+++ b/src/autotrain/cli/run_llm.py
@@ -33,7 +33,7 @@ def register_subcommand(parser: ArgumentParser):
                 "alias": ["--rejected-text-column"],
             },
             {
-                "arg": "--prompt-text-column",
+                "arg": "--prompt_text_column",
                 "help": "Identify the column that contains prompt text for tasks requiring contextual inputs, such as conversation or completion generation. Default is 'prompt'. Used only for dpo trainer",
                 "required": False,
                 "type": str,
@@ -45,6 +45,7 @@ def register_subcommand(parser: ArgumentParser):
                 "help": "Reference model to use for DPO when not using PEFT",
                 "required": False,
                 "type": str,
+                "alias": ["--model-ref"],
             },
             {
                 "arg": "--warmup_ratio",
diff --git a/src/autotrain/templates/index.html b/src/autotrain/templates/index.html
index 4cd6a520d4..9b7b14bf8b 100644
--- a/src/autotrain/templates/index.html
+++ b/src/autotrain/templates/index.html
@@ -735,7 +735,7 @@
                 document.getElementById("valid_split").disabled = true;
                 break;
             case 'llm:orpo':
-                placeholderText = '{"text": "text", "rejected_text": "rejected_text"}';
+                placeholderText = '{"prompt": "prompt", "text": "text", "rejected_text": "rejected_text"}';
                 document.getElementById("hub-dataset-radio").disabled = false;
                 document.getElementById("valid_split").disabled = true;
                 break;
diff --git a/src/autotrain/trainers/clm/utils.py b/src/autotrain/trainers/clm/utils.py
index 56f8d58162..66880c4053 100644
--- a/src/autotrain/trainers/clm/utils.py
+++ b/src/autotrain/trainers/clm/utils.py
@@ -224,7 +224,7 @@ def apply_chat_template(
             messages, tokenize=False, add_generation_prompt=False
         )
 
-    elif config.trainer in ("reward", "orpo"):
+    elif config.trainer == "reward":
         if all(k in example.keys() for k in ("chosen", "rejected")):
             chosen_messages = example["chosen"]
             rejected_messages = example["rejected"]
@@ -232,13 +232,19 @@
                 chosen_messages = ast.literal_eval(chosen_messages)
             if isinstance(rejected_messages, str):
                 rejected_messages = ast.literal_eval(rejected_messages)
+
+            if config.chat_template == "zephyr" and chosen_messages[0]["role"] != "system":
+                chosen_messages.insert(0, {"role": "system", "content": ""})
+            if config.chat_template == "zephyr" and rejected_messages[0]["role"] != "system":
+                rejected_messages.insert(0, {"role": "system", "content": ""})
+
             example["chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
             example["rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
         else:
             raise ValueError(
                 f"Could not format example as dialogue for `rm/orpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
             )
-    elif config.trainer == "dpo":
+    elif config.trainer in ("dpo", "orpo"):
         if all(k in example.keys() for k in ("chosen", "rejected")):
             # For DPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
             # We therefore need to extract the N-1 turns to form the prompt
@@ -247,6 +253,8 @@
             if isinstance(example["rejected"], str):
                 example["rejected"] = ast.literal_eval(example["rejected"])
             prompt_messages = example["chosen"][:-1]
+            if config.chat_template == "zephyr" and example["chosen"][0]["role"] != "system":
+                prompt_messages.insert(0, {"role": "system", "content": ""})
             chosen_messages = example["chosen"][-1:]
             rejected_messages = example["rejected"][-1:]
             example["chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
@@ -322,7 +330,7 @@ def process_input_data(config):
             train_data = train_data.rename_column(config.text_column, "chosen")
         if not (config.rejected_text_column == "rejected" and config.rejected_text_column in train_data.column_names):
            train_data = train_data.rename_column(config.rejected_text_column, "rejected")
-        if config.trainer == "dpo":
+        if config.trainer in ("dpo", "orpo"):
             if not (config.prompt_text_column == "prompt" and config.prompt_text_column in train_data.column_names):
                 train_data = train_data.rename_column(config.prompt_text_column, "prompt")
 
@@ -343,7 +351,7 @@
                 config.rejected_text_column == "rejected" and config.rejected_text_column in valid_data.column_names
             ):
                 valid_data = valid_data.rename_column(config.rejected_text_column, "rejected")
-            if config.trainer == "dpo":
+            if config.trainer in ("dpo", "orpo"):
                 if not (config.prompt_text_column == "prompt" and config.prompt_text_column in valid_data.column_names):
                     valid_data = valid_data.rename_column(config.prompt_text_column, "prompt")
         else:
@@ -401,6 +409,8 @@ def get_tokenizer(config):
 def process_data_with_chat_template(config, tokenizer, train_data, valid_data):
     valid_data = None
     if config.chat_template in ("chatml", "zephyr", "tokenizer"):
+        logger.info("Applying chat template")
+        logger.info("For ORPO/DPO, `prompt` will be extracted from chosen messages")
         train_data = train_data.map(
             apply_chat_template,
             fn_kwargs={
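
After this patch, ORPO rows flow through the same prompt/chosen/rejected path as DPO: the prompt is the conversation minus its final turn, and `chosen`/`rejected` keep only the two candidate final turns. Below is a minimal standalone sketch of what that extraction in `apply_chat_template` does to one preference row; the toy `example` dict and the hard-coded `chat_template` value are assumptions for illustration, not code from the patch.

```python
# Toy preference row in the chat-message format the DPO/ORPO path expects
# (in practice rows come from the mapped prompt/chosen/rejected columns).
example = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5"},
    ],
}
chat_template = "zephyr"  # assumption for illustration

# Prompt = every turn except the last; chosen/rejected = the final turn only.
prompt_messages = example["chosen"][:-1]

# Mirrors the zephyr special case above: prepend an empty system turn
# when the conversation does not already start with one.
if chat_template == "zephyr" and example["chosen"][0]["role"] != "system":
    prompt_messages.insert(0, {"role": "system", "content": ""})

chosen_messages = example["chosen"][-1:]
rejected_messages = example["rejected"][-1:]

print(prompt_messages)    # [{'role': 'system', ...}, {'role': 'user', ...}]
print(chosen_messages)    # [{'role': 'assistant', 'content': '4'}]
print(rejected_messages)  # [{'role': 'assistant', 'content': '5'}]
```

The new config should then be launchable through the usual config-file entry point, e.g. `autotrain --config configs/llm_finetuning/llama3-8b-dpo-qlora.yml`, assuming `HF_USERNAME` and `HF_TOKEN` are exported in the environment.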