From d42c50f002a0a54b07bdec83cbba6f964801e299 Mon Sep 17 00:00:00 2001
From: abhishek thakur <1183441+abhishekkrthakur@users.noreply.github.com>
Date: Thu, 12 Oct 2023 14:56:16 +0200
Subject: [PATCH] Reward modelling (#297)

---
 .gitignore                             |   1 +
 src/autotrain/cli/run_llm.py           |  39 +++++-
 src/autotrain/trainers/clm/__main__.py | 167 ++++++++++++++++++++-----
 src/autotrain/trainers/clm/params.py   |   2 +
 src/autotrain/trainers/clm/utils.py    |  19 +++
 5 files changed, 190 insertions(+), 38 deletions(-)

diff --git a/.gitignore b/.gitignore
index 9e10756026..1fc6dd0f60 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ test.py
 output/
 output2/
 logs/
+op_*/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/src/autotrain/cli/run_llm.py b/src/autotrain/cli/run_llm.py
index d76d28c0b7..a1306b9900 100644
--- a/src/autotrain/cli/run_llm.py
+++ b/src/autotrain/cli/run_llm.py
@@ -67,6 +67,14 @@ def register_subcommand(parser: ArgumentParser):
                 "default": "text",
                 "alias": ["--text-column"],
             },
+            {
+                "arg": "--rejected_text_column",
+                "help": "Rejected text column to use",
+                "required": False,
+                "type": str,
+                "default": "rejected",
+                "alias": ["--rejected-text-column"],
+            },
             {
                 "arg": "--model",
                 "help": "Model to use",
@@ -332,6 +340,13 @@ def register_subcommand(parser: ArgumentParser):
                 "action": "store_true",
                 "alias": ["--use-flash-attention-2", "--use-fa2"],
             },
+            {
+                "arg": "--disable_gradient_checkpointing",
+                "help": "Disable gradient checkpointing",
+                "required": False,
+                "action": "store_true",
+                "alias": ["--disable-gradient-checkpointing", "--disable-gc"],
+            },
         ]
         run_llm_parser = parser.add_parser("llm", description="✨ Run AutoTrain LLM")
         for arg in arg_list:
@@ -372,6 +387,7 @@ def __init__(self, args):
             "use_int4",
             "merge_adapter",
             "use_flash_attention_2",
+            "disable_gradient_checkpointing",
         ]
         for arg_name in store_true_arg_names:
            if getattr(self.args, arg_name) is None:
@@ -385,8 +401,12 @@ def __init__(self, args):
         if self.args.model is None:
             raise ValueError("Model must be specified")
         if self.args.push_to_hub:
-            if self.args.repo_id is None:
-                raise ValueError("Repo id must be specified for push to hub")
+            # must have project_name, username and token OR project_name, repo_id, token
+            if self.args.username is None and self.args.repo_id is None:
+                raise ValueError("Username or repo id must be specified for push to hub")
+            if self.args.token is None:
+                raise ValueError("Token must be specified for push to hub")
+
         if self.args.backend.startswith("spaces") or self.args.backend.startswith("ep-"):
             if not self.args.push_to_hub:
                 raise ValueError("Push to hub must be specified for spaces backend")
@@ -399,7 +419,9 @@ def __init__(self, args):
             from autotrain.infer.text_generation import TextGenerationInference
 
             tgi = TextGenerationInference(
-                self.args.project_name, use_int4=self.args.use_int4, use_int8=self.args.use_int8
+                self.args.project_name,
+                use_int4=self.args.use_int4,
+                use_int8=self.args.use_int8,
             )
             while True:
                 prompt = input("User: ")
@@ -466,6 +488,8 @@ def run(self):
             merge_adapter=self.args.merge_adapter,
             username=self.args.username,
             use_flash_attention_2=self.args.use_flash_attention_2,
+            rejected_text_column=self.args.rejected_text_column,
+            disable_gradient_checkpointing=self.args.disable_gradient_checkpointing,
         )
 
         # space training
@@ -494,7 +518,14 @@ def run(self):
             if self.num_gpus == 1:
                 train_llm(params)
             else:
-                cmd = ["accelerate", "launch", "--multi_gpu", "--num_machines", "1", "--num_processes"]
+                cmd = [
+                    "accelerate",
+                    "launch",
+                    "--multi_gpu",
"--num_machines", + "1", + "--num_processes", + ] cmd.append(str(self.num_gpus)) cmd.append("--mixed_precision") if self.args.fp16: diff --git a/src/autotrain/trainers/clm/__main__.py b/src/autotrain/trainers/clm/__main__.py index 177142e690..72379d337f 100644 --- a/src/autotrain/trainers/clm/__main__.py +++ b/src/autotrain/trainers/clm/__main__.py @@ -14,13 +14,14 @@ from transformers import ( AutoConfig, AutoModelForCausalLM, + AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig, Trainer, TrainingArguments, default_data_collator, ) -from trl import SFTTrainer +from trl import RewardConfig, RewardTrainer, SFTTrainer from autotrain import logger from autotrain.trainers.clm import utils @@ -44,10 +45,6 @@ def train(config): if config.repo_id is None and config.username is not None: config.repo_id = f"{config.username}/{config.project_name}" - # TODO: remove when SFT is fixed - # if config.trainer == "sft": - # config.trainer = "default" - # check if config.train_split.csv exists in config.data_path if config.train_split is not None: train_path = f"{config.data_path}/{config.train_split}.csv" @@ -61,6 +58,14 @@ def train(config): split=config.train_split, token=config.token, ) + # rename columns for reward trainer + if config.trainer == "reward": + if not (config.text_column == "chosen" and config.text_column in train_data.column_names): + train_data = train_data.rename_column(config.text_column, "chosen") + if not ( + config.rejected_text_column == "rejected" and config.rejected_text_column in train_data.column_names + ): + train_data = train_data.rename_column(config.rejected_text_column, "rejected") if config.valid_split is not None: valid_path = f"{config.data_path}/{config.valid_split}.csv" @@ -75,6 +80,14 @@ def train(config): token=config.token, ) + if config.trainer == "reward": + if not (config.text_column == "chosen" and config.text_column in valid_data.column_names): + valid_data = valid_data.rename_column(config.text_column, "chosen") + if not ( + config.rejected_text_column == "rejected" and config.rejected_text_column in valid_data.column_names + ): + valid_data = valid_data.rename_column(config.rejected_text_column, "rejected") + tokenizer = AutoTokenizer.from_pretrained( config.model, token=config.token, @@ -87,6 +100,9 @@ def train(config): if getattr(tokenizer, "pad_token", None) is None: tokenizer.pad_token = tokenizer.eos_token + if getattr(tokenizer, "pad_token_id", None) is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + if config.trainer == "default": train_data = utils.process_data( data=train_data, @@ -105,6 +121,10 @@ def train(config): token=config.token, trust_remote_code=True, ) + if config.trainer == "reward": + model_config.num_labels = 1 + model_config.pad_token_id = tokenizer.pad_token_id + model_config.pad_token = tokenizer.pad_token if config.use_peft: if config.use_int4: @@ -121,38 +141,72 @@ def train(config): else: bnb_config = None - model = AutoModelForCausalLM.from_pretrained( - config.model, - config=model_config, - token=config.token, - quantization_config=bnb_config, - torch_dtype=torch.float16 if config.fp16 else torch.float32, - device_map={"": Accelerator().process_index} if torch.cuda.is_available() else None, - trust_remote_code=True, - use_flash_attention_2=config.use_flash_attention_2, - ) + if config.trainer == "reward": + model = AutoModelForSequenceClassification.from_pretrained( + config.model, + config=model_config, + token=config.token, + quantization_config=bnb_config, + torch_dtype=torch.float16, + 
device_map={"": Accelerator().process_index} if torch.cuda.is_available() else None, + trust_remote_code=True, + use_flash_attention_2=config.use_flash_attention_2, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + config.model, + config=model_config, + token=config.token, + quantization_config=bnb_config, + torch_dtype=torch.float16, + device_map={"": Accelerator().process_index} if torch.cuda.is_available() else None, + trust_remote_code=True, + use_flash_attention_2=config.use_flash_attention_2, + ) else: - model = AutoModelForCausalLM.from_pretrained( - config.model, - config=model_config, - token=config.token, - trust_remote_code=True, - use_flash_attention_2=config.use_flash_attention_2, - ) + if config.trainer == "reward": + model = AutoModelForSequenceClassification.from_pretrained( + config.model, + trust_remote_code=True, + num_labels=1, + use_flash_attention_2=config.use_flash_attention_2, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + config.model, + config=model_config, + token=config.token, + trust_remote_code=True, + use_flash_attention_2=config.use_flash_attention_2, + ) model.resize_token_embeddings(len(tokenizer)) if config.use_peft: if config.use_int8 or config.use_int4: - model = prepare_model_for_kbit_training(model) - peft_config = LoraConfig( - r=config.lora_r, - lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout, - bias="none", - task_type="CAUSAL_LM", - target_modules=utils.get_target_modules(config), - ) + model = prepare_model_for_kbit_training( + model, + use_gradient_checkpointing=not config.disable_gradient_checkpointing, + ) + if config.trainer == "reward": + peft_config = LoraConfig( + r=config.lora_r, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + bias="none", + task_type="SEQ_CLS", + target_modules=utils.get_target_modules(config), + # modules_to_save=["scores"], + ) + else: + peft_config = LoraConfig( + r=config.lora_r, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + bias="none", + task_type="CAUSAL_LM", + target_modules=utils.get_target_modules(config), + ) model = get_peft_model(model, peft_config) if config.block_size == -1: @@ -176,6 +230,7 @@ def train(config): block_size = min(config.block_size, tokenizer.model_max_length) config.block_size = block_size + logger.info(f"Using block size {block_size}") if config.trainer == "default": tokenize_fn = partial(utils.tokenize, tokenizer=tokenizer, config=config) @@ -213,6 +268,30 @@ def train(config): desc=f"Grouping texts in chunks of {block_size}", ) + elif config.trainer == "reward": + reward_proc = partial(utils.preprocess_reward, tokenizer=tokenizer) + train_data = train_data.map( + reward_proc, + batched=True, + num_proc=4, + desc="Running tokenizer on train dataset", + ) + train_data = train_data.filter( + lambda x: len(x["input_ids_chosen"]) <= config.block_size + and len(x["input_ids_rejected"]) <= config.block_size + ) + if config.valid_split is not None: + valid_data = valid_data.map( + reward_proc, + batched=True, + num_proc=4, + desc="Running tokenizer on validation dataset", + ) + valid_data = valid_data.filter( + lambda x: len(x["input_ids_chosen"]) <= config.block_size + and len(x["input_ids_rejected"]) <= config.block_size + ) + logger.info("creating trainer") # trainer specific if config.logging_steps == -1: @@ -248,9 +327,14 @@ def train(config): push_to_hub=False, load_best_model_at_end=True if config.valid_split is not None else False, ddp_find_unused_parameters=False, + gradient_checkpointing=not 
+        gradient_checkpointing=not config.disable_gradient_checkpointing,
     )
 
-    args = TrainingArguments(**training_args)
+    if config.trainer == "reward":
+        training_args["max_length"] = config.block_size
+        args = RewardConfig(**training_args)
+    else:
+        args = TrainingArguments(**training_args)
 
     callbacks = []
     if config.use_peft:
@@ -283,6 +367,14 @@ def train(config):
             tokenizer=tokenizer,
             packing=True,
         )
+    elif config.trainer == "reward":
+        trainer = RewardTrainer(
+            **trainer_args,
+            train_dataset=train_data,
+            eval_dataset=valid_data if config.valid_split is not None else None,
+            peft_config=peft_config,
+            tokenizer=tokenizer,
+        )
     else:
         raise ValueError(f"trainer `{config.trainer}` not supported")
     model.config.use_cache = False
@@ -330,10 +422,17 @@ def train(config):
         if os.path.exists(f"{config.project_name}/training_params.json"):
             training_params = json.load(open(f"{config.project_name}/training_params.json"))
             training_params.pop("token")
-            json.dump(training_params, open(f"{config.project_name}/training_params.json", "w"))
+            json.dump(
+                training_params,
+                open(f"{config.project_name}/training_params.json", "w"),
+            )
         api = HfApi(token=config.token)
         api.create_repo(repo_id=config.repo_id, repo_type="model", private=True)
-        api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model")
+        api.upload_folder(
+            folder_path=config.project_name,
+            repo_id=config.repo_id,
+            repo_type="model",
+        )
 
     if PartialState().process_index == 0:
         if "SPACE_ID" in os.environ:
diff --git a/src/autotrain/trainers/clm/params.py b/src/autotrain/trainers/clm/params.py
index c3c8054831..8e61015266 100644
--- a/src/autotrain/trainers/clm/params.py
+++ b/src/autotrain/trainers/clm/params.py
@@ -12,6 +12,7 @@ class LLMTrainingParams(BaseModel):
     train_split: str = Field("train", title="Train data config")
     valid_split: str = Field(None, title="Validation data config")
     text_column: str = Field("text", title="Text column")
+    rejected_text_column: str = Field(None, title="Rejected text column")
     token: str = Field(None, title="Huggingface token")
     lr: float = Field(3e-5, title="Learning rate")
     epochs: int = Field(1, title="Number of training epochs")
@@ -45,6 +46,7 @@ class LLMTrainingParams(BaseModel):
     merge_adapter: bool = Field(False, title="Merge adapter")
     username: str = Field(None, title="Hugging Face Username")
     use_flash_attention_2: bool = Field(False, title="Use flash attention 2")
+    disable_gradient_checkpointing: bool = Field(False, title="Gradient checkpointing")
 
     def save(self, output_dir):
         os.makedirs(output_dir, exist_ok=True)
diff --git a/src/autotrain/trainers/clm/utils.py b/src/autotrain/trainers/clm/utils.py
index 7e99e8953d..9f7bbdc7bf 100644
--- a/src/autotrain/trainers/clm/utils.py
+++ b/src/autotrain/trainers/clm/utils.py
@@ -32,6 +32,25 @@
 """
 
 
+def preprocess_reward(examples, tokenizer):
+    new_examples = {
+        "input_ids_chosen": [],
+        "attention_mask_chosen": [],
+        "input_ids_rejected": [],
+        "attention_mask_rejected": [],
+    }
+    for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
+        tokenized_chosen = tokenizer(chosen, truncation=True)
+        tokenized_rejected = tokenizer(rejected, truncation=True)
+
+        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
+        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
+        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
+        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
+
+    return new_examples
+
+
 def get_target_modules(config):
     if config.target_modules is None:
         return TARGET_MODULES.get(config.model)
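
Note on the data contract introduced by this patch: the reward trainer expects a dataset whose preference pairs end up in "chosen"/"rejected" columns (other column names are renamed in __main__.py), and utils.preprocess_reward then produces the four token columns that the trainer consumes. Below is a minimal sketch of that preprocessing path, assuming this patch is installed so that autotrain.trainers.clm.utils.preprocess_reward exists; the "gpt2" model name, the tiny in-memory dataset, and the block_size value are placeholders, not part of the patch.

    from functools import partial

    from datasets import Dataset
    from transformers import AutoTokenizer

    from autotrain.trainers.clm import utils

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in base model
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # mirrors the pad-token fallback in __main__.py

    # toy preference pairs; columns must be (or be renamed to) "chosen" / "rejected"
    data = Dataset.from_dict(
        {
            "chosen": ["Q: 2+2?\nA: 4", "Q: capital of France?\nA: Paris"],
            "rejected": ["Q: 2+2?\nA: 5", "Q: capital of France?\nA: Lyon"],
        }
    )

    block_size = 512  # stands in for config.block_size
    data = data.map(partial(utils.preprocess_reward, tokenizer=tokenizer), batched=True)
    # same length filter that __main__.py applies before training
    data = data.filter(
        lambda x: len(x["input_ids_chosen"]) <= block_size
        and len(x["input_ids_rejected"]) <= block_size
    )
    print(data.column_names)
    # expected: chosen, rejected, input_ids_chosen, attention_mask_chosen,
    #           input_ids_rejected, attention_mask_rejected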
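And a rough, self-contained sketch of the trainer wiring the patch sets up, using only transformers and trl rather than the AutoTrain entry point. The model name, hyperparameters, and output directory are placeholders; the parts mirrored from the patch are the single-label sequence-classification head (num_labels=1), a defined pad_token_id, max_length on RewardConfig (set from config.block_size in __main__.py), and the input_ids_*/attention_mask_* columns that RewardTrainer expects.

    from datasets import Dataset
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from trl import RewardConfig, RewardTrainer

    model_name = "gpt2"  # placeholder base model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # one-logit head: the reward model scores each sequence with a single scalar
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
    model.config.pad_token_id = tokenizer.pad_token_id

    def tokenize_pairs(examples):
        # same output shape as utils.preprocess_reward in this patch
        chosen = tokenizer(examples["chosen"], truncation=True)
        rejected = tokenizer(examples["rejected"], truncation=True)
        return {
            "input_ids_chosen": chosen["input_ids"],
            "attention_mask_chosen": chosen["attention_mask"],
            "input_ids_rejected": rejected["input_ids"],
            "attention_mask_rejected": rejected["attention_mask"],
        }

    train_data = Dataset.from_dict(
        {
            "chosen": ["Question: 2+2?\nAnswer: 4"],
            "rejected": ["Question: 2+2?\nAnswer: 5"],
        }
    ).map(tokenize_pairs, batched=True)

    args = RewardConfig(
        output_dir="op_reward_demo",    # placeholder output/project dir
        max_length=512,                 # mirrors training_args["max_length"] = config.block_size
        per_device_train_batch_size=1,
        num_train_epochs=1,
    )
    trainer = RewardTrainer(model=model, args=args, tokenizer=tokenizer, train_dataset=train_data)
    trainer.train()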