
Update test_adamw.py
Wrench-Git committed Nov 1, 2024
1 parent 9dd3808 commit 481bc8e
Showing 1 changed file with 39 additions and 80 deletions.
119 changes: 39 additions & 80 deletions dipu/tests/python/unittests/test_adamw.py
@@ -6,7 +6,8 @@
onlyOn,
skipOn,
)

# `fused=True` requires all the params to be floating point Tensors of supported devices: ['cuda', 'xpu', 'privateuseone'].
# So we use fused=False (CPU) reference results to compare with the fused torch_dipu results.
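# A minimal sketch of the two configurations being compared (assumes a
# CUDA-like device is visible to torch; `p_cpu`/`p_dev` are illustrative names):
#
#     p_cpu = torch.randn(4, requires_grad=True)                 # non-fused CPU reference
#     p_dev = p_cpu.detach().cuda().requires_grad_()             # fused device path
#     ref_opt = torch.optim.AdamW([p_cpu], lr=1e-3, fused=False)
#     dev_opt = torch.optim.AdamW([p_dev], lr=1e-3, fused=True)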

class TestFusedAdamW(TestCase):
def setUp(self):
@@ -17,73 +18,56 @@ def setUp(self):
self.eps_list = [1e-8, 1e-8, 1e-8, 1e-8]
self.weight_decay_list = [1e-2, 1e-3, 1e-2, 1e-3]
self.amsgrad_list = [False, False, True, True]
self.step_list = [2, 3, 4, 5]

def run_adamw_cpu(
self,
param,
param_grad,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
):
torch.optim._functional.adamw(
[param],
[param_grad],
[exp_avg],
[exp_avg_sq],
[max_exp_avg_sq],
[torch.tensor(float(step))],
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=False,
)
return param, exp_avg, exp_avg_sq, max_exp_avg_sq

param.grad = param_grad
# Non-fused reference path: plain AdamW on CPU.
optimizer = torch.optim.AdamW(
    params=[param],
    lr=lr,
    betas=(beta1, beta2),
    eps=eps,
    weight_decay=weight_decay,
    amsgrad=amsgrad,
    fused=False,
)
optimizer.step()
# Read the first/second moment estimates back out of the optimizer state.
state_index = 0
exp_avg = optimizer.state_dict()["state"][state_index]["exp_avg"]
exp_avg_sq = optimizer.state_dict()["state"][state_index]["exp_avg_sq"]
return param, exp_avg, exp_avg_sq
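# Note: the step count no longer needs to be passed in (the old step_list was
# dropped); a minimal sketch of where AdamW keeps it, assuming one param group:
#     step_t = optimizer.state_dict()["state"][0]["step"]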

def run_adamw_dipu(
self,
param,
param_grad,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
):
torch._fused_adamw_(
[param],
[param_grad],
[exp_avg],
[exp_avg_sq],
[max_exp_avg_sq],
[torch.tensor(float(step)).cuda()],
amsgrad=amsgrad,
lr=lr,
beta1=beta1,
beta2=beta2,
weight_decay=weight_decay,
eps=eps,
maximize=False,
grad_scale=None,
found_inf=None,
)
return param, exp_avg, exp_avg_sq, max_exp_avg_sq
param.grad = param_grad
# Fused path under test: AdamW with fused=True on the torch_dipu device.
optimizer = torch.optim.AdamW(
    params=[param],
    lr=lr,
    betas=(beta1, beta2),
    eps=eps,
    weight_decay=weight_decay,
    amsgrad=amsgrad,
    fused=True,
)
optimizer.step()
state_index = 0
exp_avg = optimizer.state_dict()["state"][state_index]["exp_avg"]
exp_avg_sq = optimizer.state_dict()["state"][state_index]["exp_avg_sq"]
return param, exp_avg, exp_avg_sq
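# With amsgrad=True the state would also hold the running max of the second
# moments; a minimal sketch (returns None when amsgrad=False):
#     max_v = optimizer.state_dict()["state"][0].get("max_exp_avg_sq")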

def adamw_(self, dtype_):
for i in range(len(self.weight_shape_list)):
@@ -93,54 +77,37 @@ def adamw_(self, dtype_):
if dtype_ == torch.float16
else weight.cpu()
)
weight_fused_cpu = weight_cpu.clone().detach()
grad = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
grad_cpu = (
grad.cpu().to(torch.float32) if dtype_ == torch.float16 else grad.cpu()
)
m = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
m_cpu = m.cpu().to(torch.float32) if dtype_ == torch.float16 else m.cpu()
v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
v_cpu = v.cpu().to(torch.float32) if dtype_ == torch.float16 else v.cpu()
max_v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
max_v_cpu = (
max_v.cpu().to(torch.float32)
if dtype_ == torch.float16
else max_v.cpu()
)
grad_fused_cpu = grad_cpu.clone().detach()

lr = self.lr_list[i]
beta1 = self.beta1_list[i]
beta2 = self.beta2_list[i]
eps = self.eps_list[i]
weight_decay = self.weight_decay_list[i]
amsgrad = self.amsgrad_list[i]
step = self.step_list[i]

w_new_cpu, m_new_cpu, v_new_cpu, max_v_new_cpu = self.run_adamw_cpu(
w_new_cpu, m_new_cpu, v_new_cpu = self.run_adamw_cpu(
weight_cpu,
grad_cpu,
m_cpu,
v_cpu,
max_v_cpu,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
)
w_new, m_new, v_new, max_v_new = self.run_adamw_dipu(
w_new, m_new, v_new = self.run_adamw_dipu(
weight,
grad,
m,
v,
max_v,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
)
@@ -155,8 +122,11 @@ def adamw_(self, dtype_):
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
equal_nan=False,
),
)

self.assertTrue(
torch.allclose(
m_new.cpu(),
(
Expand All @@ -166,7 +136,7 @@ def adamw_(self, dtype_):
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
equal_nan=False,
),
)
self.assertTrue(
@@ -179,18 +149,7 @@
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
),
torch.allclose(
max_v_new.cpu(),
(
max_v_new_cpu.to(torch.float16)
if dtype_ == torch.float16
else max_v_new_cpu
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
equal_nan=False,
),
)
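# For float16, the fp32 CPU reference is cast back to fp16 before comparison,
# hence the looser tolerances (atol=2e-2, rtol=4e-3) than for fp32 (1e-2, 2e-3).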

