Add a --ckpt-path option to the ImageNet ResNet-50 benchmark script.

When epochs <= 0, pretraining is skipped and, if a checkpoint path was given,
its "state_dict" entry is loaded into the model. Otherwise the path is
forwarded through pretrain() to trainer.fit(ckpt_path=...).

The checkpoint is read with map_location="cpu" so that its tensors are not
restored onto the device they were saved from; load_state_dict() then copies
the values into the model's existing parameters.

NOTE(review): this patch was rebuilt from a whitespace-collapsed copy. The
indentation of context lines and the position of blank context lines were
reconstructed from the hunk line counts -- confirm with `git apply --check`
before applying.

diff --git a/benchmarks/imagenet/resnet50/main.py b/benchmarks/imagenet/resnet50/main.py
index ee6739520..be4f08b89 100644
--- a/benchmarks/imagenet/resnet50/main.py
+++ b/benchmarks/imagenet/resnet50/main.py
@@ -39,6 +39,7 @@
 parser.add_argument("--accelerator", type=str, default="gpu")
 parser.add_argument("--devices", type=int, default=1)
 parser.add_argument("--precision", type=str, default="16-mixed")
+parser.add_argument("--ckpt-path", type=Path, default=None)
 parser.add_argument("--compile-model", action="store_true")
 parser.add_argument("--methods", type=str, nargs="+")
 parser.add_argument("--num-classes", type=int, default=1000)
@@ -76,6 +77,7 @@ def main(
     skip_knn_eval: bool,
     skip_linear_eval: bool,
     skip_finetune_eval: bool,
+    ckpt_path: Union[Path, None],
 ) -> None:
     torch.set_float32_matmul_precision("high")
 
@@ -96,6 +98,9 @@ def main(
 
         if epochs <= 0:
             print_rank_zero("Epochs <= 0, skipping pretraining.")
+            if ckpt_path is not None:
+                checkpoint = torch.load(ckpt_path, map_location="cpu")
+                model.load_state_dict(checkpoint["state_dict"])
         else:
             pretrain(
                 model=model,
@@ -109,6 +114,7 @@ def main(
                 accelerator=accelerator,
                 devices=devices,
                 precision=precision,
+                ckpt_path=ckpt_path,
             )
 
         if skip_knn_eval:
@@ -171,6 +177,7 @@ def pretrain(
     accelerator: str,
     devices: int,
     precision: str,
+    ckpt_path: Union[Path, None],
 ) -> None:
     print_rank_zero(f"Running pretraining for {method}...")
 
@@ -222,10 +229,12 @@ def pretrain(
         strategy="ddp_find_unused_parameters_true",
         sync_batchnorm=True,
     )
+
     trainer.fit(
         model=model,
         train_dataloaders=train_dataloader,
         val_dataloaders=val_dataloader,
+        ckpt_path=ckpt_path,
     )
     for metric in ["val_online_cls_top1", "val_online_cls_top5"]:
         print_rank_zero(f"max {metric}: {max(metric_callback.val_metrics[metric])}")