Commit
Merge pull request #200 from noiji/patch-2
[LLMLingua2] Save model in huggingface style
QianhuiWu authored Nov 17, 2024
2 parents 2dbdbd3 + a307f36 commit ae9dfbd
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion experiments/llmlingua2/model_training/train.sh
@@ -2,4 +2,4 @@
 # Licensed under The MIT License [see LICENSE for details]
 
 python train_roberta.py --data_path ../../../results/meetingbank/gpt-4-32k_comp/annotation_kept_cs512_meetingbank_train_formated.pt \
-    --save_path ../../../results/models/xlm_roberta_large_meetingbank_only.pth
+    --save_path ../../../results/models/xlm_roberta_large_meetingbank_only
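Note that --save_path now names an output directory rather than a single .pth file. A minimal sketch of a preparation step a caller might add (not part of this commit; torch.save does not create missing parent directories), assuming the default path above:

import os

# Hypothetical pre-step, not part of this commit: make sure the output
# directory exists before training writes state_dict.pth into it.
save_path = "../../../results/models/xlm_roberta_large_meetingbank_only"
os.makedirs(save_path, exist_ok=True)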
7 changes: 5 additions & 2 deletions experiments/llmlingua2/model_training/train_roberta.py
@@ -40,7 +40,7 @@
 parser.add_argument(
     "--save_path",
     help="save path",
-    default="../../../results/models/xlm_roberta_large_meetingbank_only.pth",
+    default="../../../results/models/xlm_roberta_large_meetingbank_only",
 )
 parser.add_argument("--lr", help="learning rate", default=1e-5, type=float)
 parser.add_argument(
@@ -218,10 +218,13 @@ def test(model, eval_dataloader):
 optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
 
 best_acc = 0
+
 for epoch in tqdm(range(args.num_epoch)):
     print(f"Training epoch: {epoch + 1}")
     train(epoch)
     acc = test(model, val_dataloader)
     if acc > best_acc:
         best_acc = acc
-        torch.save(model.state_dict(), args.save_path)
+        torch.save(model.state_dict(), f"{args.save_path}/state_dict.pth")
+        model.save_pretrained(args.save_path)
+        tokenizer.save_pretrained(args.save_path)
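
With the checkpoint now written via save_pretrained, it can be reloaded through the standard Hugging Face from_pretrained API instead of rebuilding the model and loading a raw state dict. A minimal sketch, assuming the script trains a token-classification model on XLM-RoBERTa (the exact model class is not shown in this diff):

from transformers import AutoModelForTokenClassification, AutoTokenizer

save_path = "../../../results/models/xlm_roberta_large_meetingbank_only"

# Load the config, weights, and tokenizer written by save_pretrained(...)
model = AutoModelForTokenClassification.from_pretrained(save_path)
tokenizer = AutoTokenizer.from_pretrained(save_path)

# The raw state dict is still saved alongside, if a plain torch.load is preferred:
# state_dict = torch.load(f"{save_path}/state_dict.pth")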
