diff --git a/extensions/thunder/pretrain.py b/extensions/thunder/pretrain.py index 02c1a4ed42..757a3ecb99 100644 --- a/extensions/thunder/pretrain.py +++ b/extensions/thunder/pretrain.py @@ -246,6 +246,7 @@ def fit( tokenizer_dir: Optional[Path], train: TrainArgs, eval: EvalArgs, + optimizer: Union[str, Dict], ) -> None: model = state["model"] optimizer = state["optimizer"]