Skip to content

Commit

Permalink
use sftconfig + update requirements (#729)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur authored Aug 13, 2024
1 parent 7587437 commit 79f3421
Showing 2 changed files with 9 additions and 10 deletions.
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -26,14 +26,14 @@ cryptography==42.0.5
nvitop==1.3.2
# latest versions
tensorboard==2.16.2
-peft==0.11.1
+peft==0.12.0
trl==0.9.6
tiktoken==0.6.0
-transformers==4.43.1
+transformers==4.44.0

-accelerate==0.32.0
+accelerate==0.33.0
diffusers==0.27.2
-bitsandbytes==0.43.2
+bitsandbytes==0.43.3
# extras
rouge_score==0.1.2
py7zr==0.21.0
11 changes: 5 additions & 6 deletions src/autotrain/trainers/clm/train_clm_sft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from peft import LoraConfig
-from transformers import TrainingArguments
from transformers.trainer_callback import PrinterCallback
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer

from autotrain import logger
from autotrain.trainers.clm import utils
@@ -20,7 +19,10 @@ def train(config):
training_args = utils.configure_training_args(config, logging_steps)
config = utils.configure_block_size(config, tokenizer)

-args = TrainingArguments(**training_args)
+training_args["dataset_text_field"] = config.text_column
+training_args["max_seq_length"] = config.block_size
+training_args["packing"] = True
+args = SFTConfig(**training_args)

model = utils.get_model(config, tokenizer)

@@ -46,10 +48,7 @@ def train(config):
train_dataset=train_data,
eval_dataset=valid_data if config.valid_split is not None else None,
peft_config=peft_config if config.peft else None,
-dataset_text_field=config.text_column,
-max_seq_length=config.block_size,
tokenizer=tokenizer,
-packing=True,
)

trainer.remove_callback(PrinterCallback)

0 comments on commit 79f3421

Please sign in to comment.