diff --git a/litgpt/utils.py b/litgpt/utils.py index 69a12b21f3..3ee720a1aa 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -405,16 +405,17 @@ def CLI(*args: Any, **kwargs: Any) -> Any: return CLI(*args, **kwargs) -def capture_hparams(): +def capture_hparams() -> Dict[str, Any]: """Captures the local variables ('hyperparameters') from where this function gets called.""" caller_frame = inspect.currentframe().f_back locals_of_caller = caller_frame.f_locals hparams = {} for name, value in locals_of_caller.items(): - if isinstance(value, (int, float, str, bool, Path)): + if value is None or isinstance(value, (int, float, str, bool, Path)): hparams[name] = value if is_dataclass(value): hparams[name] = asdict(value) + return hparams def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None: