dpo training
abhishekkrthakur committed Oct 30, 2023
1 parent 6f61306 commit d66e4a6
Showing 3 changed files with 85 additions and 7 deletions.
37 changes: 35 additions & 2 deletions src/autotrain/cli/run_llm.py
@@ -75,12 +75,26 @@ def register_subcommand(parser: ArgumentParser):
"default": "rejected",
"alias": ["--rejected-text-column"],
},
{
"arg": "--prompt-text-column",
"help": "Prompt text column to use",
"required": False,
"type": str,
"default": "prompt",
"alias": ["--prompt-text-column"],
},
{
"arg": "--model",
"help": "Model to use",
"required": False,
"type": str,
},
{
"arg": "--model-ref",
"help": "Reference model to use for DPO when not using PEFT",
"required": False,
"type": str,
},
{
"arg": "--learning_rate",
"help": "Learning rate to use",
@@ -169,8 +183,8 @@ def register_subcommand(parser: ArgumentParser):
"arg": "--block_size",
"help": "Block size to use",
"required": False,
"type": int,
"default": -1,
"type": str,
"default": "-1",
"alias": ["--block-size"],
},
{
@@ -354,6 +368,14 @@ def register_subcommand(parser: ArgumentParser):
"action": "store_true",
"alias": ["--disable-gradient-checkpointing", "--disable-gc"],
},
{
"arg": "--dpo-beta",
"help": "Beta for DPO trainer",
"required": False,
"type": float,
"default": 0.1,
"alias": ["--dpo-beta"],
},
]
run_llm_parser = parser.add_parser("llm", description="✨ Run AutoTrain LLM")
for arg in arg_list:
@@ -400,6 +422,14 @@ def __init__(self, args):
if getattr(self.args, arg_name) is None:
setattr(self.args, arg_name, False)

block_size_split = self.args.block_size.strip().split(",")
if len(block_size_split) == 1:
self.args.block_size = int(block_size_split[0])
elif len(block_size_split) > 1:
self.args.block_size = [int(x.strip()) for x in block_size_split]
else:
raise ValueError("Invalid block size")

if self.args.train:
if self.args.project_name is None:
raise ValueError("Project name must be specified")
@@ -498,6 +528,9 @@ def run(self):
log=self.args.log,
rejected_text_column=self.args.rejected_text_column,
disable_gradient_checkpointing=self.args.disable_gradient_checkpointing,
model_ref=self.args.model_ref,
dpo_beta=self.args.dpo_beta,
prompt_text_column=self.args.prompt_text_column,
)

# space training
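For orientation, a hypothetical invocation that combines the new DPO-related flags might look like the sketch below. The autotrain llm entry point and the --trainer dpo selector are assumptions not shown in this hunk, and flag spellings outside this diff (for example --project-name) may differ, so treat this as an illustration rather than a verified command.

    autotrain llm --train \
      --model facebook/opt-350m \
      --model-ref facebook/opt-350m \
      --trainer dpo \
      --dpo-beta 0.1 \
      --prompt-text-column prompt \
      --rejected-text-column rejected \
      --block_size "1024,512" \
      --project-name my-dpo-run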
49 changes: 45 additions & 4 deletions src/autotrain/trainers/clm/__main__.py
@@ -21,7 +21,7 @@
TrainingArguments,
default_data_collator,
)
from trl import RewardConfig, RewardTrainer, SFTTrainer
from trl import DPOTrainer, RewardConfig, RewardTrainer, SFTTrainer

from autotrain import logger
from autotrain.trainers.clm import utils
@@ -59,14 +59,16 @@ def train(config):
token=config.token,
)
# rename columns for reward and dpo trainers
if config.trainer == "reward":
if config.trainer in ("dpo", "reward"):
if not (config.text_column == "chosen" and config.text_column in train_data.column_names):
train_data = train_data.rename_column(config.text_column, "chosen")
if not (
config.rejected_text_column == "rejected" and config.rejected_text_column in train_data.column_names
):
train_data = train_data.rename_column(config.rejected_text_column, "rejected")

if config.trainer == "dpo":
if not (config.prompt_text_column == "prompt" and config.prompt_text_column in train_data.column_names):
train_data = train_data.rename_column(config.prompt_text_column, "prompt")
if config.valid_split is not None:
valid_path = f"{config.data_path}/{config.valid_split}.csv"
if os.path.exists(valid_path):
@@ -80,13 +82,16 @@
token=config.token,
)

if config.trainer == "reward":
if config.trainer in ("dpo", "reward"):
if not (config.text_column == "chosen" and config.text_column in valid_data.column_names):
valid_data = valid_data.rename_column(config.text_column, "chosen")
if not (
config.rejected_text_column == "rejected" and config.rejected_text_column in valid_data.column_names
):
valid_data = valid_data.rename_column(config.rejected_text_column, "rejected")
if config.trainer == "dpo":
if not (config.prompt_text_column == "prompt" and config.prompt_text_column in valid_data.column_names):
valid_data = valid_data.rename_column(config.prompt_text_column, "prompt")
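As a minimal sketch of the column layout the DPO branch expects once these renames have run (illustrative toy data, not from the commit), a compatible dataset could be built like this:

    from datasets import Dataset

    # Toy preference data with the three columns the DPO path uses after renaming:
    # "prompt", "chosen" (preferred response) and "rejected" (dispreferred response).
    train_data = Dataset.from_dict(
        {
            "prompt": ["Summarize: The cat sat on the mat."],
            "chosen": ["A cat rested on a mat."],
            "rejected": ["Cats are animals."],
        }
    )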

tokenizer = AutoTokenizer.from_pretrained(
config.model,
@@ -163,6 +168,7 @@ def train(config):
trust_remote_code=True,
use_flash_attention_2=config.use_flash_attention_2,
)
model_ref = None
else:
if config.trainer == "reward":
model = AutoModelForSequenceClassification.from_pretrained(
@@ -179,6 +185,14 @@
trust_remote_code=True,
use_flash_attention_2=config.use_flash_attention_2,
)
if config.model_ref is not None:
model_ref = AutoModelForCausalLM.from_pretrained(
config.model_ref,
config=model_config,
token=config.token,
trust_remote_code=True,
use_flash_attention_2=config.use_flash_attention_2,
)

model.resize_token_embeddings(len(tokenizer))

@@ -377,6 +391,33 @@ def train(config):
peft_config=peft_config,
tokenizer=tokenizer,
)
elif config.trainer == "dpo":
if isinstance(config.block_size, int):
max_length = config.block_size
max_prompt_length = None
max_target_length = None
elif isinstance(config.block_size, list):
if len(config.block_size) == 3:
max_length, max_prompt_length, max_target_length = config.block_size
elif len(config.block_size) == 2:
max_length, max_prompt_length = config.block_size
max_target_length = None
else:
raise ValueError(f"block_size must be a list of length 2 or 3, got {config.block_size}")
else:
raise ValueError(f"block_size must be an int or a list, got {config.block_size}")
trainer = DPOTrainer(
**trainer_args,
ref_model=model_ref,
beta=config.dpo_beta,
train_dataset=train_data,
eval_dataset=valid_data if config.valid_split is not None else None,
tokenizer=tokenizer,
max_length=max_length,
max_prompt_length=max_prompt_length,
max_target_length=max_target_length,
peft_config=peft_config,
)
else:
raise ValueError(f"trainer `{config.trainer}` not supported")
model.config.use_cache = False
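To make the block_size plumbing concrete, the following rough sketch (values made up) mirrors how a comma-separated --block_size string, parsed in run_llm.py, ends up as the DPOTrainer length arguments set above; it illustrates the logic rather than reproducing the actual code path.

    # Mirror of the CLI parsing plus the mapping in this hunk.
    block_size = "1024,512"  # e.g. passed as --block_size "1024,512"
    parts = [int(x.strip()) for x in block_size.split(",")]

    if len(parts) == 1:
        max_length, max_prompt_length, max_target_length = parts[0], None, None
    elif len(parts) == 2:
        max_length, max_prompt_length, max_target_length = parts[0], parts[1], None
    elif len(parts) == 3:
        max_length, max_prompt_length, max_target_length = parts
    else:
        raise ValueError(f"block_size must have 1, 2 or 3 values, got {parts}")

    print(max_length, max_prompt_length, max_target_length)  # 1024 512 None

A single integer keeps the earlier behaviour (one overall max_length), while two or three values additionally cap the prompt length and, optionally, the target length.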
6 changes: 5 additions & 1 deletion src/autotrain/trainers/clm/params.py
@@ -1,4 +1,5 @@
import os
from typing import List, Union

from pydantic import BaseModel, Field

@@ -25,7 +26,7 @@ class LLMTrainingParams(BaseModel):
max_grad_norm: float = Field(1.0, title="Max gradient norm")
seed: int = Field(42, title="Seed")
add_eos_token: bool = Field(True, title="Add EOS token")
block_size: int = Field(-1, title="Block size")
block_size: Union[int, List[int]] = Field(-1, title="Block size")
use_peft: bool = Field(False, title="Use PEFT")
lora_r: int = Field(16, title="Lora r")
lora_alpha: int = Field(32, title="Lora alpha")
@@ -48,6 +49,9 @@ class LLMTrainingParams(BaseModel):
use_flash_attention_2: bool = Field(False, title="Use flash attention 2")
log: str = Field("none", title="Logging using experiment tracking")
disable_gradient_checkpointing: bool = Field(False, title="Gradient checkpointing")
model_ref: str = Field(None, title="Reference model for DPO trainer")
dpo_beta: float = Field(0.1, title="Beta for DPO trainer")
prompt_text_column: str = Field(None, title="Prompt text column")

def save(self, output_dir):
os.makedirs(output_dir, exist_ok=True)
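A minimal sketch of setting the new DPO-related fields directly on LLMTrainingParams (values are illustrative; fields such as model and trainer are not part of this diff and are assumed to exist elsewhere in the class with compatible defaults):

    from autotrain.trainers.clm.params import LLMTrainingParams

    params = LLMTrainingParams(
        model="facebook/opt-350m",      # assumed field; referenced as config.model in __main__.py
        trainer="dpo",                  # assumed field; referenced as config.trainer in __main__.py
        model_ref="facebook/opt-350m",  # reference model for DPO when not using PEFT
        dpo_beta=0.1,                   # beta for the DPO loss
        prompt_text_column="prompt",    # renamed to "prompt" during training if it differs
        block_size=[1024, 512],         # int or list of ints after this change
    )
    params.save("output")               # save() creates output_dir (shown truncated above)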
