From d882a8f3b7c84742bbb9949f544c2c902b84fd76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 23 May 2024 13:04:40 -0400 Subject: [PATCH] Fix init --- litgpt/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litgpt/utils.py b/litgpt/utils.py index 9225af8911..eeb73c1320 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -514,5 +514,7 @@ def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs): optimizer_cls = getattr(torch.optim, optimizer) optimizer = optimizer_cls(model_parameters, **kwargs) else: - optimizer = instantiate_class(model_parameters, optimizer, **kwargs) + optimizer = dict(optimizer) # copy + optimizer["init_args"].update(kwargs) + optimizer = instantiate_class(model_parameters, optimizer) return optimizer