Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed May 23, 2024
1 parent d882a8f commit 5811279
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,16 +340,13 @@ def test_instantiate_bnb_optimizer_with_invalid_str(model_parameters):


def test_instantiate_torch_optimizer_with_str(model_parameters):
with mock.patch("litgpt.utils.instantiate_class") as mock_instantiate_class:
mock_instantiate_class.return_value = torch.optim.Adam(model_parameters, lr=0.01)
optimizer = instantiate_torch_optimizer("Adam", model_parameters, lr=0.01)
assert isinstance(optimizer, torch.optim.Adam)
assert optimizer.param_groups[0]["lr"] == 0.01
optimizer = instantiate_torch_optimizer("Adam", model_parameters, lr=0.01)
assert isinstance(optimizer, torch.optim.Adam)
assert optimizer.param_groups[0]["lr"] == 0.01


def test_instantiate_torch_optimizer_with_class(model_parameters):
with mock.patch("litgpt.utils.instantiate_class") as mock_instantiate_class:
mock_instantiate_class.return_value = torch.optim.Adam(model_parameters, lr=0.02)
optimizer = instantiate_torch_optimizer(torch.optim.Adam, model_parameters, lr=0.02)
assert isinstance(optimizer, torch.optim.Adam)
assert optimizer.param_groups[0]["lr"] == 0.02
optimizer = instantiate_torch_optimizer({"class_path": "torch.optim.Adam", "init_args": {"lr": 123}}, model_parameters, lr=0.02)
assert isinstance(optimizer, torch.optim.Adam)
# init args gets overridden
assert optimizer.param_groups[0]["lr"] == 0.02

0 comments on commit 5811279

Please sign in to comment.