Commit

linted
youssef62 committed Nov 23, 2024
1 parent 2e5972f commit 6f17d10
Showing 2 changed files with 90 additions and 84 deletions.
163 changes: 84 additions & 79 deletions test/test_optim.py
@@ -3,15 +3,13 @@
import math
import tempfile
import unittest

from copy import deepcopy
from typing import Any, Dict, Tuple
from unittest.mock import patch

from optim.test_lrscheduler import TestLRScheduler # noqa: F401
from optim.test_optim import TestDifferentiableOptimizer # noqa: F401
from optim.test_swa_utils import TestSWAUtils # noqa: F401
- from torch.profiler import profile, ProfilerActivity

import torch
from torch.nn import Parameter
@@ -21,6 +19,7 @@
register_optimizer_step_post_hook,
register_optimizer_step_pre_hook,
)
+ from torch.profiler import profile, ProfilerActivity
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
@@ -2166,7 +2165,10 @@ def test_non_empty_state(self, device, dtype, optim_info):
for state in optim.state.values():
self.assertGreater(len(state), 0)

- @optims([optim for optim in optim_db if optim.optim_cls.__name__ == "Adam"], dtypes=[torch.float32])
+ @optims(
+ [optim for optim in optim_db if optim.optim_cls.__name__ == "Adam"],
+ dtypes=[torch.float32],
+ )
def test_less_mem_beta1_zero_adam(self, device, dtype, optim_info):
# test that beta1=0.0 uses less memory than the default
model = torch.nn.Linear(5, 5)
@@ -2176,56 +2178,74 @@ def test_less_mem_beta1_zero_adam(self, device, dtype, optim_info):
optim_inputs = optim_info.optim_inputs_func(device=device)
case_to_mem_usage = {}
for optim_input in optim_inputs:
- beta1_zero = ("betas" in optim_input.kwargs and optim_input.kwargs["betas"][0] == 0.0)
- if beta1_zero or optim_input.desc == "default" :

- activities = [ProfilerActivity.CUDA] if device == "cuda" else [ProfilerActivity.CPU]
+ beta1_zero = (
+ "betas" in optim_input.kwargs and optim_input.kwargs["betas"][0] == 0.0
+ )
+ if beta1_zero or optim_input.desc == "default":
+ activities = (
+ [ProfilerActivity.CUDA]
+ if device == "cuda"
+ else [ProfilerActivity.CPU]
+ )

with profile(
- activities=activities,
- profile_memory=True,
- record_shapes=True
+ activities=activities, profile_memory=True, record_shapes=True
) as prof:
- optim = optim_info.optim_cls(model.parameters(), **optim_input.kwargs)
+ optim = optim_info.optim_cls(
+ model.parameters(), **optim_input.kwargs
+ )
optim.zero_grad()
output = model(inpt)
loss = output.sum()
loss.backward()

optim.step()

- case_to_mem_usage["beta1-zero" if beta1_zero else "default"] = sum([
- item.cuda_memory_usage if device == "cuda" else item.cpu_memory_usage
- for item in prof.key_averages()
- ])


- self.assertGreater(case_to_mem_usage["default"], case_to_mem_usage["beta1-zero"])

+ case_to_mem_usage["beta1-zero" if beta1_zero else "default"] = sum(
+ [
+ item.cuda_memory_usage
+ if device == "cuda"
+ else item.cpu_memory_usage
+ for item in prof.key_averages()
+ ]
+ )

- @optims([optim for optim in optim_db if optim.optim_cls.__name__ == "Adam"], dtypes=[torch.float32])
- def test_beta1zero_then_loadstate_beta1nonzero_adam(self, device, dtype, optim_info):
+ self.assertGreater(
+ case_to_mem_usage["default"], case_to_mem_usage["beta1-zero"]
+ )

+ @optims(
+ [optim for optim in optim_db if optim.optim_cls.__name__ == "Adam"],
+ dtypes=[torch.float32],
+ )
+ def test_beta1zero_then_loadstate_beta1nonzero_adam(
+ self, device, dtype, optim_info
+ ):
inpt = torch.ones((5,), dtype=dtype, device=device)
optim_inputs = optim_info.optim_inputs_func(device=device)


def init_model_and_optim(optim_input):
model = torch.nn.Linear(5, 5)
model.to(dtype=dtype, device=device)
optim = optim_info.optim_cls(model.parameters(), **optim_input.kwargs)
return model, optim


# model and optimizer corresponding to beta1 = 0.0
model_beta1_zero, optim_beta1_zero = init_model_and_optim(
- [optim_input for optim_input in optim_inputs
- if "betas" in optim_input.kwargs and
- optim_input.kwargs["betas"][0] == 0.0][0]
+ [
+ optim_input
+ for optim_input in optim_inputs
+ if "betas" in optim_input.kwargs
+ and optim_input.kwargs["betas"][0] == 0.0
+ ][0]
)
# model and optimizer corresponding to default params
model_default, optim_default = init_model_and_optim(
- [optim_input for optim_input in optim_inputs if optim_input.desc == "default"][0]
+ [
+ optim_input
+ for optim_input in optim_inputs
+ if optim_input.desc == "default"
+ ][0]
)

# we should receive the same output if we do the following to our models :
@@ -2244,12 +2264,7 @@ def iteration(model, optim, inpt, n_iters):
return output

# 1. train for n_iters
- iteration(
- model_beta1_zero,
- optim_beta1_zero,
- inpt,
- n_iters
- )
+ iteration(model_beta1_zero, optim_beta1_zero, inpt, n_iters)

# 2. load state
optim_beta1_zero.load_state_dict(optim_default.state_dict())
@@ -2258,28 +2273,25 @@ def iteration(model, optim, inpt, n_iters):
self.assertEqual(model_beta1_zero.state_dict(), model_default.state_dict())
self.assertEqual(optim_beta1_zero.state_dict(), optim_default.state_dict())


inpt = torch.rand((5,), dtype=dtype, device=device)
# 3. train for n_iters
output_beta1_zero_after_load = iteration(
- model_beta1_zero,
- optim_beta1_zero,
- inpt,
- n_iters
+ model_beta1_zero, optim_beta1_zero, inpt, n_iters
)
# 4. train for n_iters
- output_default = iteration(
- model_default,
- optim_default,
- inpt,
- n_iters
- )

- self.assertTrue(torch.allclose(output_beta1_zero_after_load, output_default, atol=0.001))
+ output_default = iteration(model_default, optim_default, inpt, n_iters)

- @optims([optim for optim in optim_db if optim.optim_cls.__name__ == "Adam"], dtypes=[torch.float32])
- def test_beta1nonzero_then_loadstate_beta1zero_adam(self, device, dtype, optim_info):
+ self.assertTrue(
+ torch.allclose(output_beta1_zero_after_load, output_default, atol=0.001)
+ )

+ @optims(
+ [optim for optim in optim_db if optim.optim_cls.__name__ == "Adam"],
+ dtypes=[torch.float32],
+ )
+ def test_beta1nonzero_then_loadstate_beta1zero_adam(
+ self, device, dtype, optim_info
+ ):
inpt = torch.ones((5,), dtype=dtype, device=device)
optim_inputs = optim_info.optim_inputs_func(device=device)

@@ -2289,19 +2301,24 @@ def init_model_and_optim(optim_input):
optim = optim_info.optim_cls(model.parameters(), **optim_input.kwargs)
return model, optim


# model and optimizer corresponding to beta1 = 0.0
model_beta1_zero, optim_beta1_zero = init_model_and_optim(
- [optim_input for optim_input in optim_inputs
- if "betas" in optim_input.kwargs
- and optim_input.kwargs["betas"][0] == 0.0][0]
+ [
+ optim_input
+ for optim_input in optim_inputs
+ if "betas" in optim_input.kwargs
+ and optim_input.kwargs["betas"][0] == 0.0
+ ][0]
)
# model and optimizer corresponding to default params
model_default, optim_default = init_model_and_optim(
- [optim_input for optim_input in optim_inputs if optim_input.desc == "default"][0]
+ [
+ optim_input
+ for optim_input in optim_inputs
+ if optim_input.desc == "default"
+ ][0]
)


# we should receive the same output if we do the following to our models :
# model_default: 1.train for n_iters, 2.load state (beta1=0.0), 3.train for n_iters
# model_beta1_zero: 4.train for n_iters
@@ -2319,12 +2336,7 @@ def iteration(model, optim, inpt, n_iters):
return output

# 1. train for n_iters
- iteration(
- model_default,
- optim_default,
- inpt,
- n_iters
- )
+ iteration(model_default, optim_default, inpt, n_iters)

# 2. load state
model_default.load_state_dict(model_beta1_zero.state_dict())
@@ -2333,30 +2345,23 @@ def iteration(model, optim, inpt, n_iters):
self.assertEqual(model_beta1_zero.state_dict(), model_default.state_dict())
self.assertEqual(optim_beta1_zero.state_dict(), optim_default.state_dict())


inpt = torch.rand((5,), dtype=dtype, device=device)
# 3. train for n_iters
- output_default = iteration(
- model_default,
- optim_default,
- inpt,
- n_iters
- )
+ output_default = iteration(model_default, optim_default, inpt, n_iters)

# 4. train for n_iters
output_beta1_zero_after_load = iteration(
- model_beta1_zero,
- optim_beta1_zero,
- inpt,
- n_iters
+ model_beta1_zero, optim_beta1_zero, inpt, n_iters
)

self.assertTrue(
torch.allclose(output_beta1_zero_after_load, output_default, atol=0.001)
)


- @optims([optim for optim in optim_db if optim.optim_cls.__name__ == "Adam"], dtypes=[torch.float32])
+ @optims(
+ [optim for optim in optim_db if optim.optim_cls.__name__ == "Adam"],
+ dtypes=[torch.float32],
+ )
def test_correct_beta1(self, device, dtype, optim_info):
# we test correctness of the optimizer with beta1 = 0.0 by comparing it with the optimizer model
# with beta1 = 1e-6
@@ -2378,7 +2383,8 @@ def step(model, optim):

inpt = torch.ones((5,), dtype=dtype, device=device)
optim_inputs = [
- optim_input for optim_input in optim_info.optim_inputs_func(device=device)
+ optim_input
+ for optim_input in optim_info.optim_inputs_func(device=device)
if ("betas" in optim_input.kwargs and optim_input.kwargs["betas"][0] == 0.0)
or optim_input.desc == "default"
]
@@ -2390,7 +2396,9 @@ def run_two_models(beta1_of_default):
ouputs = []
for optim_input in optim_inputs:
one_model_output = []
- optim_input.kwargs["lr"] = lr # need lr high enough to see the difference
+ optim_input.kwargs["lr"] = (
+ lr # need lr high enough to see the difference
+ )
if optim_input.desc == "default":
optim_input.kwargs["betas"] = (beta1_of_default, 0.999)
model, optim = init_model_and_optim(optim_input)
@@ -2399,9 +2407,7 @@ def run_two_models(beta1_of_default):
for i in range(n_iters):
step(model, optim)
one_model_output.append(model(inpt))
- ouputs.append(
- torch.cat(one_model_output)
- )
+ ouputs.append(torch.cat(one_model_output))
return ouputs

# our beta1=0 optimizer should have same performance as the default optimizer
@@ -2414,7 +2420,6 @@ def run_two_models(beta1_of_default):
self.assertFalse(torch.allclose(outputs[0], outputs[1]))



instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)


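The new memory test relies on torch.profiler's memory tracking to compare one Adam step under default betas against beta1 = 0.0. Below is a minimal standalone sketch of that pattern; the helper name adam_step_memory is hypothetical, and the saving only shows up once the exp_avg shortcut from torch/optim/adam.py below is in place (stock Adam allocates exp_avg regardless of beta1).

    # Sketch of the profiling pattern used in test_less_mem_beta1_zero_adam (CPU only).
    import torch
    from torch.profiler import profile, ProfilerActivity

    def adam_step_memory(betas):
        model = torch.nn.Linear(5, 5)
        inpt = torch.ones(5)
        with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
            optim = torch.optim.Adam(model.parameters(), betas=betas)
            optim.zero_grad()
            model(inpt).sum().backward()
            optim.step()
        # Sum the CPU memory attributed to every profiled op during the step.
        return sum(item.cpu_memory_usage for item in prof.key_averages())

    default_mem = adam_step_memory(betas=(0.9, 0.999))
    beta1_zero_mem = adam_step_memory(betas=(0.0, 0.999))
    print(default_mem, beta1_zero_mem)  # with the patch, the second total should be smaller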
11 changes: 6 additions & 5 deletions torch/optim/adam.py
@@ -174,9 +174,11 @@ def _init_group(
# Exponential moving average of gradient values
# case beta1 == 0, we don't need exp_avg

state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
) if beta1 > 0 else torch.zeros(0)
state["exp_avg"] = (
torch.zeros_like(p, memory_format=torch.preserve_format)
if beta1 > 0
else torch.zeros(0)
)

# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
@@ -187,7 +189,7 @@ def _init_group(
state["max_exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
- if beta1 > 0 :
+ if beta1 > 0:
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])

@@ -507,7 +509,6 @@ def _multi_tensor_adam(
capturable: bool,
differentiable: bool,
):

if len(params) == 0:
return

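For reference, a small sketch of the state-layout difference the _init_group change produces. It assumes this patch is applied; an unpatched torch.optim.Adam keeps a full-size exp_avg buffer for every value of beta1.

    # Inspect the optimizer state after one step for default betas vs beta1 == 0.0.
    import torch

    p = torch.nn.Parameter(torch.randn(5, 5))
    p.grad = torch.randn(5, 5)

    for betas in [(0.9, 0.999), (0.0, 0.999)]:
        optim = torch.optim.Adam([p], betas=betas)
        optim.step()
        # With the patch, beta1 == 0.0 stores a zero-element placeholder instead of
        # a buffer shaped like the parameter.
        print(betas, optim.state[p]["exp_avg"].shape)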