Reward modelling (#297)
abhishekkrthakur authored Oct 12, 2023
1 parent 1931728 commit d42c50f
Showing 5 changed files with 190 additions and 38 deletions.
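
In short: this commit teaches AutoTrain's LLM trainer to train reward models. It threads a new --rejected_text_column flag and a --disable_gradient_checkpointing toggle through the CLI, normalizes datasets to the chosen/rejected column pair that TRL expects, loads AutoModelForSequenceClassification with a single-logit head for the reward path, and dispatches to trl.RewardTrainer via RewardConfig. A plausible invocation, assuming the pre-existing --trainer flag accepts the new "reward" value (flag names come from this diff; the model and data path are illustrative): autotrain llm --train --model gpt2 --data_path data/ --text_column chosen --rejected_text_column rejected --trainer reward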
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ test.py
output/
output2/
logs/
op_*/

# Byte-compiled / optimized / DLL files
__pycache__/
39 changes: 35 additions & 4 deletions src/autotrain/cli/run_llm.py
@@ -67,6 +67,14 @@ def register_subcommand(parser: ArgumentParser):
"default": "text",
"alias": ["--text-column"],
},
{
"arg": "--rejected_text_column",
"help": "Rejected text column to use",
"required": False,
"type": str,
"default": "rejected",
"alias": ["--rejected-text-column"],
},
{
"arg": "--model",
"help": "Model to use",
@@ -332,6 +340,13 @@ def register_subcommand(parser: ArgumentParser):
"action": "store_true",
"alias": ["--use-flash-attention-2", "--use-fa2"],
},
{
"arg": "--disable_gradient_checkpointing",
"help": "Disable gradient checkpointing",
"required": False,
"action": "store_true",
"alias": ["--disable-gradient-checkpointing", "--disable-gc"],
},
]
run_llm_parser = parser.add_parser("llm", description="✨ Run AutoTrain LLM")
for arg in arg_list:
@@ -372,6 +387,7 @@ def __init__(self, args):
"use_int4",
"merge_adapter",
"use_flash_attention_2",
"disable_gradient_checkpointing",
]
for arg_name in store_true_arg_names:
if getattr(self.args, arg_name) is None:
@@ -385,8 +401,12 @@ def __init__(self, args):
if self.args.model is None:
raise ValueError("Model must be specified")
if self.args.push_to_hub:
if self.args.repo_id is None:
raise ValueError("Repo id must be specified for push to hub")
# must have project_name, username and token OR project_name, repo_id, token
if self.args.username is None and self.args.repo_id is None:
raise ValueError("Username or repo id must be specified for push to hub")
if self.args.token is None:
raise ValueError("Token must be specified for push to hub")

if self.args.backend.startswith("spaces") or self.args.backend.startswith("ep-"):
if not self.args.push_to_hub:
raise ValueError("Push to hub must be specified for spaces backend")
@@ -399,7 +419,9 @@ def __init__(self, args):
from autotrain.infer.text_generation import TextGenerationInference

tgi = TextGenerationInference(
self.args.project_name, use_int4=self.args.use_int4, use_int8=self.args.use_int8
self.args.project_name,
use_int4=self.args.use_int4,
use_int8=self.args.use_int8,
)
while True:
prompt = input("User: ")
@@ -466,6 +488,8 @@ def run(self):
merge_adapter=self.args.merge_adapter,
username=self.args.username,
use_flash_attention_2=self.args.use_flash_attention_2,
rejected_text_column=self.args.rejected_text_column,
disable_gradient_checkpointing=self.args.disable_gradient_checkpointing,
)

# space training
@@ -494,7 +518,14 @@ def run(self):
if self.num_gpus == 1:
train_llm(params)
else:
cmd = ["accelerate", "launch", "--multi_gpu", "--num_machines", "1", "--num_processes"]
cmd = [
"accelerate",
"launch",
"--multi_gpu",
"--num_machines",
"1",
"--num_processes",
]
cmd.append(str(self.num_gpus))
cmd.append("--mixed_precision")
if self.args.fp16:
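
For reference, with num_gpus == 4 and --fp16 the list assembled above resolves to a command beginning accelerate launch --multi_gpu --num_machines 1 --num_processes 4 --mixed_precision fp16; the remaining trainer arguments are appended below this hunk's cutoff and are not shown here.
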
167 changes: 133 additions & 34 deletions src/autotrain/trainers/clm/__main__.py
@@ -14,13 +14,14 @@
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
BitsAndBytesConfig,
Trainer,
TrainingArguments,
default_data_collator,
)
from trl import SFTTrainer
from trl import RewardConfig, RewardTrainer, SFTTrainer

from autotrain import logger
from autotrain.trainers.clm import utils
@@ -44,10 +45,6 @@ def train(config):
if config.repo_id is None and config.username is not None:
config.repo_id = f"{config.username}/{config.project_name}"

# TODO: remove when SFT is fixed
# if config.trainer == "sft":
# config.trainer = "default"

# check if config.train_split.csv exists in config.data_path
if config.train_split is not None:
train_path = f"{config.data_path}/{config.train_split}.csv"
@@ -61,6 +58,14 @@
split=config.train_split,
token=config.token,
)
# rename columns for reward trainer
if config.trainer == "reward":
if not (config.text_column == "chosen" and config.text_column in train_data.column_names):
train_data = train_data.rename_column(config.text_column, "chosen")
if not (
config.rejected_text_column == "rejected" and config.rejected_text_column in train_data.column_names
):
train_data = train_data.rename_column(config.rejected_text_column, "rejected")

if config.valid_split is not None:
valid_path = f"{config.data_path}/{config.valid_split}.csv"
@@ -75,6 +80,14 @@
token=config.token,
)

if config.trainer == "reward":
if not (config.text_column == "chosen" and config.text_column in valid_data.column_names):
valid_data = valid_data.rename_column(config.text_column, "chosen")
if not (
config.rejected_text_column == "rejected" and config.rejected_text_column in valid_data.column_names
):
valid_data = valid_data.rename_column(config.rejected_text_column, "rejected")
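
The renames above normalize arbitrary column names to the chosen/rejected pair that TRL's reward pipeline expects; the guard skips the rename only when the column already has the right name and exists. A toy illustration with hypothetical column names:

    from datasets import Dataset

    ds = Dataset.from_dict({"good_reply": ["hi"], "bad_reply": ["bye"]})
    ds = ds.rename_column("good_reply", "chosen")   # config.text_column
    ds = ds.rename_column("bad_reply", "rejected")  # config.rejected_text_column
    # ds.column_names -> ['chosen', 'rejected']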

tokenizer = AutoTokenizer.from_pretrained(
config.model,
token=config.token,
@@ -87,6 +100,9 @@
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token

if getattr(tokenizer, "pad_token_id", None) is None:
tokenizer.pad_token_id = tokenizer.eos_token_id

if config.trainer == "default":
train_data = utils.process_data(
data=train_data,
@@ -105,6 +121,10 @@
token=config.token,
trust_remote_code=True,
)
if config.trainer == "reward":
model_config.num_labels = 1
model_config.pad_token_id = tokenizer.pad_token_id
model_config.pad_token = tokenizer.pad_token
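
Two details matter here for the reward path: num_labels = 1 turns the classification head into a scalar scorer, and the pad token must be set on the config because padded batches of chosen/rejected pairs flow through a sequence-classification model (many causal LMs ship without one). Scoring a text after training is then a single forward pass; a sketch, assuming model and tokenizer are the trained pair:

    inputs = tokenizer("some response text", return_tensors="pt")
    reward = model(**inputs).logits[0].item()  # scalar, since num_labels == 1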

if config.use_peft:
if config.use_int4:
@@ -121,38 +141,72 @@
else:
bnb_config = None

model = AutoModelForCausalLM.from_pretrained(
config.model,
config=model_config,
token=config.token,
quantization_config=bnb_config,
torch_dtype=torch.float16 if config.fp16 else torch.float32,
device_map={"": Accelerator().process_index} if torch.cuda.is_available() else None,
trust_remote_code=True,
use_flash_attention_2=config.use_flash_attention_2,
)
if config.trainer == "reward":
model = AutoModelForSequenceClassification.from_pretrained(
config.model,
config=model_config,
token=config.token,
quantization_config=bnb_config,
torch_dtype=torch.float16,
device_map={"": Accelerator().process_index} if torch.cuda.is_available() else None,
trust_remote_code=True,
use_flash_attention_2=config.use_flash_attention_2,
)
else:
model = AutoModelForCausalLM.from_pretrained(
config.model,
config=model_config,
token=config.token,
quantization_config=bnb_config,
torch_dtype=torch.float16,
device_map={"": Accelerator().process_index} if torch.cuda.is_available() else None,
trust_remote_code=True,
use_flash_attention_2=config.use_flash_attention_2,
)
else:
model = AutoModelForCausalLM.from_pretrained(
config.model,
config=model_config,
token=config.token,
trust_remote_code=True,
use_flash_attention_2=config.use_flash_attention_2,
)
if config.trainer == "reward":
model = AutoModelForSequenceClassification.from_pretrained(
config.model,
trust_remote_code=True,
num_labels=1,
use_flash_attention_2=config.use_flash_attention_2,
)
else:
model = AutoModelForCausalLM.from_pretrained(
config.model,
config=model_config,
token=config.token,
trust_remote_code=True,
use_flash_attention_2=config.use_flash_attention_2,
)

model.resize_token_embeddings(len(tokenizer))

if config.use_peft:
if config.use_int8 or config.use_int4:
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
r=config.lora_r,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=utils.get_target_modules(config),
)
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing=not config.disable_gradient_checkpointing,
)
if config.trainer == "reward":
peft_config = LoraConfig(
r=config.lora_r,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
bias="none",
task_type="SEQ_CLS",
target_modules=utils.get_target_modules(config),
# modules_to_save=["scores"],
)
else:
peft_config = LoraConfig(
r=config.lora_r,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=utils.get_target_modules(config),
)
model = get_peft_model(model, peft_config)
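
The two LoRA configs differ only in task type: SEQ_CLS points PEFT at the sequence-classification head, CAUSAL_LM at the language-model head. The commented-out modules_to_save would explicitly mark the fresh scoring head as trainable and saveable; PEFT's SEQ_CLS task type typically registers the classifier/score head on its own, which may be why the line is left disabled.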

if config.block_size == -1:
@@ -176,6 +230,7 @@
block_size = min(config.block_size, tokenizer.model_max_length)

config.block_size = block_size
logger.info(f"Using block size {block_size}")

if config.trainer == "default":
tokenize_fn = partial(utils.tokenize, tokenizer=tokenizer, config=config)
@@ -213,6 +268,30 @@
desc=f"Grouping texts in chunks of {block_size}",
)

elif config.trainer == "reward":
reward_proc = partial(utils.preprocess_reward, tokenizer=tokenizer)
train_data = train_data.map(
reward_proc,
batched=True,
num_proc=4,
desc="Running tokenizer on train dataset",
)
train_data = train_data.filter(
lambda x: len(x["input_ids_chosen"]) <= config.block_size
and len(x["input_ids_rejected"]) <= config.block_size
)
if config.valid_split is not None:
valid_data = valid_data.map(
reward_proc,
batched=True,
num_proc=4,
desc="Running tokenizer on validation dataset",
)
valid_data = valid_data.filter(
lambda x: len(x["input_ids_chosen"]) <= config.block_size
and len(x["input_ids_rejected"]) <= config.block_size
)
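
utils.preprocess_reward is not visible in this rendering (the fifth changed file failed to load). Judging from the keys the filters consume, it tokenizes the chosen and rejected columns independently, along these lines (a sketch, not necessarily the exact implementation):

    def preprocess_reward(examples, tokenizer):
        new_examples = {
            "input_ids_chosen": [],
            "attention_mask_chosen": [],
            "input_ids_rejected": [],
            "attention_mask_rejected": [],
        }
        for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
            tokenized_chosen = tokenizer(chosen, truncation=True)
            tokenized_rejected = tokenizer(rejected, truncation=True)
            new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
            new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
            new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
            new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
        return new_examples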

logger.info("creating trainer")
# trainer specific
if config.logging_steps == -1:
@@ -248,9 +327,14 @@
push_to_hub=False,
load_best_model_at_end=True if config.valid_split is not None else False,
ddp_find_unused_parameters=False,
gradient_checkpointing=not config.disable_gradient_checkpointing,
)

args = TrainingArguments(**training_args)
if config.trainer == "reward":
training_args["max_length"] = config.block_size
args = RewardConfig(**training_args)
else:
args = TrainingArguments(**training_args)
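
RewardConfig is TRL's subclass of TrainingArguments, so the shared dict passes through unchanged; the extra max_length field caps each sequence in the pair, mirroring the block_size filter applied during preprocessing.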

callbacks = []
if config.use_peft:
@@ -283,6 +367,14 @@
tokenizer=tokenizer,
packing=True,
)
elif config.trainer == "reward":
trainer = RewardTrainer(
**trainer_args,
train_dataset=train_data,
eval_dataset=valid_data if config.valid_split is not None else None,
peft_config=peft_config,
tokenizer=tokenizer,
)
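
RewardTrainer consumes the four input_ids_*/attention_mask_* columns produced above, scores both sequences with the shared single-logit head, and minimizes the pairwise ranking loss, roughly:

    # per batch, with r_chosen and r_rejected the scalar head outputs:
    # loss = -torch.nn.functional.logsigmoid(r_chosen - r_rejected).mean()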
else:
raise ValueError(f"trainer `{config.trainer}` not supported")
model.config.use_cache = False
@@ -330,10 +422,17 @@ def train(config):
if os.path.exists(f"{config.project_name}/training_params.json"):
training_params = json.load(open(f"{config.project_name}/training_params.json"))
training_params.pop("token")
json.dump(training_params, open(f"{config.project_name}/training_params.json", "w"))
json.dump(
training_params,
open(f"{config.project_name}/training_params.json", "w"),
)
api = HfApi(token=config.token)
api.create_repo(repo_id=config.repo_id, repo_type="model", private=True)
api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model")
api.upload_folder(
folder_path=config.project_name,
repo_id=config.repo_id,
repo_type="model",
)

if PartialState().process_index == 0:
if "SPACE_ID" in os.environ:
2 changes: 2 additions & 0 deletions src/autotrain/trainers/clm/params.py
@@ -12,6 +12,7 @@ class LLMTrainingParams(BaseModel):
train_split: str = Field("train", title="Train data config")
valid_split: str = Field(None, title="Validation data config")
text_column: str = Field("text", title="Text column")
rejected_text_column: str = Field(None, title="Rejected text column")
token: str = Field(None, title="Huggingface token")
lr: float = Field(3e-5, title="Learning rate")
epochs: int = Field(1, title="Number of training epochs")
@@ -45,6 +46,7 @@ class LLMTrainingParams(BaseModel):
merge_adapter: bool = Field(False, title="Merge adapter")
username: str = Field(None, title="Hugging Face Username")
use_flash_attention_2: bool = Field(False, title="Use flash attention 2")
disable_gradient_checkpointing: bool = Field(False, title="Disable gradient checkpointing")

def save(self, output_dir):
os.makedirs(output_dir, exist_ok=True)
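
For programmatic use the new fields slot in like any other; a hypothetical construction (the model, data_path, and trainer fields are assumed from the CLI mapping above, and all values are illustrative):

    params = LLMTrainingParams(
        model="meta-llama/Llama-2-7b-hf",  # hypothetical
        data_path="data/",                 # hypothetical
        text_column="chosen",
        rejected_text_column="rejected",
        trainer="reward",                  # assumed to route to RewardTrainer
        disable_gradient_checkpointing=False,
    )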