diff --git a/litgpt/utils.py b/litgpt/utils.py index 9225af8911..eeb73c1320 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -514,5 +514,7 @@ def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs): optimizer_cls = getattr(torch.optim, optimizer) optimizer = optimizer_cls(model_parameters, **kwargs) else: - optimizer = instantiate_class(model_parameters, optimizer, **kwargs) + optimizer = dict(optimizer) # copy + optimizer["init_args"].update(kwargs) + optimizer = instantiate_class(model_parameters, optimizer) return optimizer