From 481bc8e8c3cff48b5e57f8ad4c99d833f0c19548 Mon Sep 17 00:00:00 2001
From: HuayiL <442488254@qq.com>
Date: Fri, 1 Nov 2024 15:36:57 +0800
Subject: [PATCH] Update test_adamw.py

---
 dipu/tests/python/unittests/test_adamw.py | 119 +++++++---------------
 1 file changed, 39 insertions(+), 80 deletions(-)

diff --git a/dipu/tests/python/unittests/test_adamw.py b/dipu/tests/python/unittests/test_adamw.py
index 69ea9d495..bef31c2ea 100644
--- a/dipu/tests/python/unittests/test_adamw.py
+++ b/dipu/tests/python/unittests/test_adamw.py
@@ -6,7 +6,8 @@
     onlyOn,
     skipOn,
 )
-
+# `fused=True` requires all the params to be floating point Tensors of supported devices: ['cuda', 'xpu', 'privateuseone'].
+# So we run non-fused AdamW on the CPU as the reference and compare it against the fused torch_dipu results.
 
 class TestFusedAdamW(TestCase):
     def setUp(self):
@@ -17,73 +18,56 @@ def setUp(self):
         self.eps_list = [1e-8, 1e-8, 1e-8, 1e-8]
         self.weight_decay_list = [1e-2, 1e-3, 1e-2, 1e-3]
         self.amsgrad_list = [False, False, True, True]
-        self.step_list = [2, 3, 4, 5]
 
     def run_adamw_cpu(
         self,
         param,
         param_grad,
-        exp_avg,
-        exp_avg_sq,
-        max_exp_avg_sq,
         lr,
         beta1,
         beta2,
         eps,
-        step,
         weight_decay,
         amsgrad,
     ):
-        torch.optim._functional.adamw(
-            [param],
-            [param_grad],
-            [exp_avg],
-            [exp_avg_sq],
-            [max_exp_avg_sq],
-            [torch.tensor(float(step))],
-            amsgrad=amsgrad,
-            beta1=beta1,
-            beta2=beta2,
-            lr=lr,
-            weight_decay=weight_decay,
-            eps=eps,
-            maximize=False,
-        )
-        return param, exp_avg, exp_avg_sq, max_exp_avg_sq
-
+        param.grad = param_grad
+        optimizer = torch.optim.AdamW(params=[param],
+                                      lr=lr,
+                                      betas=(beta1, beta2),
+                                      eps=eps,
+                                      weight_decay=weight_decay,
+                                      amsgrad=amsgrad,
+                                      fused=False)
+        optimizer.step()
+        state_index = 0
+        exp_avg = optimizer.state_dict()["state"][state_index]["exp_avg"]
+        exp_avg_sq = optimizer.state_dict()["state"][state_index]["exp_avg_sq"]
+        return param, exp_avg, exp_avg_sq
+
     def run_adamw_dipu(
         self,
         param,
         param_grad,
-        exp_avg,
-        exp_avg_sq,
-        max_exp_avg_sq,
         lr,
         beta1,
         beta2,
         eps,
-        step,
         weight_decay,
         amsgrad,
     ):
-        torch._fused_adamw_(
-            [param],
-            [param_grad],
-            [exp_avg],
-            [exp_avg_sq],
-            [max_exp_avg_sq],
-            [torch.tensor(float(step)).cuda()],
-            amsgrad=amsgrad,
-            lr=lr,
-            beta1=beta1,
-            beta2=beta2,
-            weight_decay=weight_decay,
-            eps=eps,
-            maximize=False,
-            grad_scale=None,
-            found_inf=None,
-        )
-        return param, exp_avg, exp_avg_sq, max_exp_avg_sq
+        param.grad = param_grad
+        optimizer = torch.optim.AdamW(params=[param],
+                                      lr=lr,
+                                      betas=(beta1, beta2),
+                                      eps=eps,
+                                      weight_decay=weight_decay,
+                                      amsgrad=amsgrad,
+                                      fused=True)
+        optimizer.step()
+        state_index = 0
+        exp_avg = optimizer.state_dict()["state"][state_index]["exp_avg"]
+        exp_avg_sq = optimizer.state_dict()["state"][state_index]["exp_avg_sq"]
+        return param, exp_avg, exp_avg_sq
 
     def adamw_(self, dtype_):
         for i in range(len(self.weight_shape_list)):
@@ -93,20 +77,12 @@ def adamw_(self, dtype_):
                 if dtype_ == torch.float16
                 else weight.cpu()
             )
+            weight_fused_cpu = weight_cpu.clone().detach()
             grad = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
             grad_cpu = (
                 grad.cpu().to(torch.float32) if dtype_ == torch.float16 else grad.cpu()
             )
-            m = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
-            m_cpu = m.cpu().to(torch.float32) if dtype_ == torch.float16 else m.cpu()
-            v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
-            v_cpu = v.cpu().to(torch.float32) if dtype_ == torch.float16 else v.cpu()
-            max_v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
-            max_v_cpu = (
-                max_v.cpu().to(torch.float32)
-                if dtype_ == torch.float16
-                else max_v.cpu()
-            )
+            grad_fused_cpu = grad_cpu.clone().detach()
 
             lr = self.lr_list[i]
             beta1 = self.beta1_list[i]
@@ -114,33 +90,24 @@ def adamw_(self, dtype_):
             eps = self.eps_list[i]
             weight_decay = self.weight_decay_list[i]
             amsgrad = self.amsgrad_list[i]
-            step = self.step_list[i]
 
-            w_new_cpu, m_new_cpu, v_new_cpu, max_v_new_cpu = self.run_adamw_cpu(
+            w_new_cpu, m_new_cpu, v_new_cpu = self.run_adamw_cpu(
                 weight_cpu,
                 grad_cpu,
-                m_cpu,
-                v_cpu,
-                max_v_cpu,
                 lr,
                 beta1,
                 beta2,
                 eps,
-                step,
                 weight_decay,
                 amsgrad,
             )
-            w_new, m_new, v_new, max_v_new = self.run_adamw_dipu(
+            w_new, m_new, v_new = self.run_adamw_dipu(
                 weight,
                 grad,
-                m,
-                v,
-                max_v,
                 lr,
                 beta1,
                 beta2,
                 eps,
-                step,
                 weight_decay,
                 amsgrad,
             )
@@ -155,8 +122,11 @@ def adamw_(self, dtype_):
                 ),
                 atol=2e-2 if dtype_ == torch.float16 else 1e-2,
                 rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
-                equal_nan=True,
+                equal_nan=False,
             ),
+        )
+
+        self.assertTrue(
             torch.allclose(
                 m_new.cpu(),
                 (
@@ -166,7 +136,7 @@ def adamw_(self, dtype_):
                 ),
                 atol=2e-2 if dtype_ == torch.float16 else 1e-2,
                 rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
-                equal_nan=True,
+                equal_nan=False,
             ),
         )
         self.assertTrue(
@@ -179,18 +149,7 @@ def adamw_(self, dtype_):
                 ),
                 atol=2e-2 if dtype_ == torch.float16 else 1e-2,
                 rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
-                equal_nan=True,
-            ),
-            torch.allclose(
-                max_v_new.cpu(),
-                (
-                    max_v_new_cpu.to(torch.float16)
-                    if dtype_ == torch.float16
-                    else max_v_new_cpu
-                ),
-                atol=2e-2 if dtype_ == torch.float16 else 1e-2,
-                rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
-                equal_nan=True,
+                equal_nan=False,
             ),
         )
 
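For context on the comparison this patch introduces: instead of calling the low-level
torch.optim._functional.adamw and torch._fused_adamw_ kernels with externally managed
state tensors, the rewritten test drives the public torch.optim.AdamW API and compares
one non-fused step on the CPU against one fused step on the device. The snippet below
is a minimal standalone sketch of that pattern, not code from the patch: the helper
name one_adamw_step, the tensor shape, the hyperparameters, and the tolerances are
illustrative assumptions, and `.cuda()` is assumed to map to the DIPU device when
torch_dipu is installed (on stock PyTorch it needs a CUDA GPU, since fused AdamW is
unavailable on CPU).

import torch

def one_adamw_step(param, grad, fused):
    # Run a single AdamW step through the public optimizer API and return
    # the updated parameter plus the first/second moment state.
    # Hyperparameters here are arbitrary illustrative choices.
    param.grad = grad
    optimizer = torch.optim.AdamW(
        [param],
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        fused=fused,
    )
    optimizer.step()
    state = optimizer.state[param]
    return param, state["exp_avg"], state["exp_avg_sq"]

if torch.cuda.is_available():
    weight = torch.randn(4, 8).cuda()
    grad = torch.randn(4, 8).cuda()

    # Reference: non-fused AdamW on CPU copies of the same tensors,
    # taken before the in-place fused update below mutates `weight`.
    w_ref, m_ref, v_ref = one_adamw_step(weight.cpu(), grad.cpu(), fused=False)
    # Under test: fused AdamW on the device.
    w_dev, m_dev, v_dev = one_adamw_step(weight, grad, fused=True)

    for dev, ref in ((w_dev, w_ref), (m_dev, m_ref), (v_dev, v_ref)):
        assert torch.allclose(dev.cpu(), ref, atol=1e-2, rtol=2e-3)

Going through the optimizer API keeps the step counter and AMSGrad bookkeeping inside
the optimizer, which is why the rewritten helpers no longer take exp_avg / exp_avg_sq /
max_exp_avg_sq / step arguments and instead read the moments back from the optimizer
state after the step.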