Skip to content

Commit

Permalink
add full FT save
Browse files: browse the repository at this point in the history
  • Loading branch information
markus583 committed Mar 27, 2024
1 parent d59964c commit 3d12df4
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions wtpsplit/train/train_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,16 +567,20 @@ def compute_metrics(trainer):
  **kwargs,
  )
  trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
- with training_args.main_process_first():
+ logger.warning(f"Finished training for {lang} {dataset_name}.")
+ if training_args.local_rank == 0:
  if not os.path.exists(os.path.join(training_args.output_dir, dataset_name, lang)):
  os.makedirs(os.path.join(training_args.output_dir, dataset_name, lang))
  save_model = copy.deepcopy(model.backbone)
  save_model = save_model.to("cpu")
- save_model.to("cpu").save_adapter(
- adapter_name="text",
- save_directory=os.path.join(training_args.output_dir, dataset_name, lang),
- with_head=True,
- )
+ if adapter_args.train_adapter:
+ save_model.to("cpu").save_adapter(
+ adapter_name="text",
+ save_directory=os.path.join(training_args.output_dir, dataset_name, lang),
+ with_head=True,
+ )
+ else:
+ save_model.save_pretrained(os.path.join(training_args.output_dir, dataset_name, lang))
  if training_args.local_rank == 0:
  # eval here within 1 go
  cmd = ""
Expand Down

0 comments on commit 3d12df4

Please sign in to comment.