diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index ca4ca08cab08..455022b1ba44 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -17,7 +17,7 @@ import shutil from datetime import timedelta from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Literal, Optional, Union import lightning.pytorch import torch @@ -63,7 +63,7 @@ def __init__( self, monitor: Optional[str] = "val_loss", verbose: bool = True, - save_last: Optional[bool] = True, + save_last: Optional[Union[bool, Literal["link"]]] = True, save_top_k: int = 3, save_weights_only: bool = False, ## TODO: check support mode: str = "min",