diff --git a/litgpt/pretrain.py b/litgpt/pretrain.py
index eef48bdc0c..4ab31b414e 100644
--- a/litgpt/pretrain.py
+++ b/litgpt/pretrain.py
@@ -199,6 +199,10 @@ def main(
 
     train_time = time.perf_counter()
     fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
+
+    # Save final checkpoint
+    save_checkpoint(fabric, state, tokenizer_dir, out_dir / "final" / "lit_model.pth")
+
     fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
     if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
@@ -325,15 +329,7 @@ def fit(
             fabric.barrier()
 
         if train.save_interval is not None and not is_accumulating and state["step_count"] % train.save_interval == 0:
-            checkpoint_file = out_dir / f"step-{state['step_count']:08d}" / "lit_model.pth"
-            checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
-            fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
-            fabric.save(checkpoint_file, state)
-            if fabric.global_rank == 0:
-                save_hyperparameters(setup, checkpoint_file.parent)
-                if tokenizer_dir is not None:
-                    copy_config_files(tokenizer_dir, checkpoint_file.parent)
-                save_config(model.config, checkpoint_file.parent)
+            save_checkpoint(fabric, state, tokenizer_dir, out_dir / f"step-{state['step_count']:08d}" / "lit_model.pth")
 
 
 @torch.no_grad()
@@ -413,6 +409,18 @@ def init_out_dir(out_dir: Path) -> Path:
     return out_dir
 
 
+def save_checkpoint(fabric, state, tokenizer_dir, checkpoint_file):
+    model = state["model"]
+    checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
+    fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
+    fabric.save(checkpoint_file, state)
+    if fabric.global_rank == 0:
+        save_hyperparameters(setup, checkpoint_file.parent)
+        if tokenizer_dir is not None:
+            copy_config_files(tokenizer_dir, checkpoint_file.parent)
+        save_config(model.config, checkpoint_file.parent)
+
+
 def validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resume) -> None:
     issues = []
     unsupported = [(train, ["max_steps", "epochs"]), (eval, ["max_new_tokens"])]