From 79f3421de9761b405f87299edd60ed23b8ad112f Mon Sep 17 00:00:00 2001 From: abhishek thakur <1183441+abhishekkrthakur@users.noreply.github.com> Date: Tue, 13 Aug 2024 14:29:20 +0200 Subject: [PATCH] use sftconfig + update requirements (#729) --- requirements.txt | 8 ++++---- src/autotrain/trainers/clm/train_clm_sft.py | 11 +++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/requirements.txt b/requirements.txt index dff4537d94..c909eac288 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,14 +26,14 @@ cryptography==42.0.5 nvitop==1.3.2 # latest versions tensorboard==2.16.2 -peft==0.11.1 +peft==0.12.0 trl==0.9.6 tiktoken==0.6.0 -transformers==4.43.1 +transformers==4.44.0 -accelerate==0.32.0 +accelerate==0.33.0 diffusers==0.27.2 -bitsandbytes==0.43.2 +bitsandbytes==0.43.3 # extras rouge_score==0.1.2 py7zr==0.21.0 diff --git a/src/autotrain/trainers/clm/train_clm_sft.py b/src/autotrain/trainers/clm/train_clm_sft.py index 35e8937a45..d00eb6cbec 100644 --- a/src/autotrain/trainers/clm/train_clm_sft.py +++ b/src/autotrain/trainers/clm/train_clm_sft.py @@ -1,7 +1,6 @@ from peft import LoraConfig -from transformers import TrainingArguments from transformers.trainer_callback import PrinterCallback -from trl import SFTTrainer +from trl import SFTConfig, SFTTrainer from autotrain import logger from autotrain.trainers.clm import utils @@ -20,7 +19,10 @@ def train(config): training_args = utils.configure_training_args(config, logging_steps) config = utils.configure_block_size(config, tokenizer) - args = TrainingArguments(**training_args) + training_args["dataset_text_field"] = config.text_column + training_args["max_seq_length"] = config.block_size + training_args["packing"] = True + args = SFTConfig(**training_args) model = utils.get_model(config, tokenizer) @@ -46,10 +48,7 @@ def train(config): train_dataset=train_data, eval_dataset=valid_data if config.valid_split is not None else None, peft_config=peft_config if config.peft else None, - dataset_text_field=config.text_column, - max_seq_length=config.block_size, tokenizer=tokenizer, - packing=True, ) trainer.remove_callback(PrinterCallback)