
Commit 2c4d7b5

save
awaelchli committed Apr 4, 2024
1 parent c634a28 commit 2c4d7b5
Showing 1 changed file with 17 additions and 9 deletions.
litgpt/pretrain.py: 26 changes (17 additions, 9 deletions)
@@ -199,6 +199,10 @@ def main(
 
     train_time = time.perf_counter()
     fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
+
+    # Save final checkpoint
+    save_checkpoint(fabric, state, tokenizer_dir, out_dir / "final" / "lit_model.pth")
+
    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
@@ -325,15 +329,7 @@ def fit(
         fabric.barrier()
 
         if train.save_interval is not None and not is_accumulating and state["step_count"] % train.save_interval == 0:
-            checkpoint_file = out_dir / f"step-{state['step_count']:08d}" / "lit_model.pth"
-            checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
-            fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
-            fabric.save(checkpoint_file, state)
-            if fabric.global_rank == 0:
-                save_hyperparameters(setup, checkpoint_file.parent)
-                if tokenizer_dir is not None:
-                    copy_config_files(tokenizer_dir, checkpoint_file.parent)
-                save_config(model.config, checkpoint_file.parent)
+            save_checkpoint(fabric, state, tokenizer_dir, out_dir / f"step-{state['step_count']:08d}" / "lit_model.pth")
 
 
 @torch.no_grad()
@@ -413,6 +409,18 @@ def init_out_dir(out_dir: Path) -> Path:
     return out_dir
 
 
+def save_checkpoint(fabric, state, tokenizer_dir, checkpoint_file):
+    model = state["model"]
+    checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
+    fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
+    fabric.save(checkpoint_file, state)
+    if fabric.global_rank == 0:
+        save_hyperparameters(setup, checkpoint_file.parent)
+        if tokenizer_dir is not None:
+            copy_config_files(tokenizer_dir, checkpoint_file.parent)
+        save_config(model.config, checkpoint_file.parent)
+
+
 def validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resume) -> None:
     issues = []
     unsupported = [(train, ["max_steps", "epochs"]), (eval, ["max_new_tokens"])]
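
One thing the extracted helper makes explicit: fabric.save() must be entered by
every rank (it is a collective call in distributed runs), while the
hyperparameter, config, and tokenizer sidecar files are written by rank 0 only.
A minimal single-process sketch of that split, assuming only torch; _Fabric is
a hypothetical stand-in for lightning.fabric.Fabric, and the YAML content is
made up for illustration:

    from pathlib import Path

    import torch

    class _Fabric:
        # Hypothetical single-process stand-in for lightning.fabric.Fabric.
        global_rank = 0

        def print(self, msg):
            print(msg)

        def save(self, path, state):
            # In a real multi-GPU run this is a collective: every rank calls it.
            torch.save(state, path)

    def save_checkpoint(fabric, state, checkpoint_file: Path):
        checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
        fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
        fabric.save(checkpoint_file, state)  # all ranks participate
        if fabric.global_rank == 0:          # sidecar files: rank 0 only
            (checkpoint_file.parent / "model_config.yaml").write_text("name: demo\n")

    save_checkpoint(_Fabric(), {"step_count": 0}, Path("out/step-00000000/lit_model.pth"))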
Expand Down
