Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added provision for resuming training from a checkpoint in case the training is interrupted #1381

Merged
12 commits merged on Sep 6, 2023
8 changes (8 additions & 0 deletions) in benchmarks/imagenet/resnet50/main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,6 +39,7 @@
parser.add_argument("--accelerator", type=str, default="gpu")
parser.add_argument("--devices", type=int, default=1)
parser.add_argument("--precision", type=str, default="16-mixed")
parser.add_argument("--ckpt-path", type=Path, default=None)
parser.add_argument("--compile-model", action="store_true")
parser.add_argument("--methods", type=str, nargs="+")
parser.add_argument("--num-classes", type=int, default=1000)
Expand Down Expand Up @@ -76,6 +77,7 @@ def main(
skip_knn_eval: bool,
skip_linear_eval: bool,
skip_finetune_eval: bool,
ckpt_path: Union[Path, None],
) -> None:
torch.set_float32_matmul_precision("high")

Expand All @@ -96,6 +98,8 @@ def main(

if epochs <= 0:
print_rank_zero("Epochs <= 0, skipping pretraining.")
if ckpt_path is not None:
model.load_state_dict(torch.load(ckpt_path)["state_dict"])
else:
pretrain(
model=model,
Expand All @@ -109,6 +113,7 @@ def main(
accelerator=accelerator,
devices=devices,
precision=precision,
ckpt_path=ckpt_path,
)

if skip_knn_eval:
Expand Down Expand Up @@ -171,6 +176,7 @@ def pretrain(
accelerator: str,
devices: int,
precision: str,
ckpt_path: Union[Path, None],
) -> None:
print_rank_zero(f"Running pretraining for {method}...")

Expand Down Expand Up @@ -222,10 +228,12 @@ def pretrain(
strategy="ddp_find_unused_parameters_true",
sync_batchnorm=True,
)

trainer.fit(
model=model,
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
ckpt_path=ckpt_path,
)
for metric in ["val_online_cls_top1", "val_online_cls_top5"]:
print_rank_zero(f"max {metric}: {max(metric_callback.val_metrics[metric])}")
Expand Down
Loading