From cfec2283f7559c6cb5f5c94c8478f5e6ce69af7e Mon Sep 17 00:00:00 2001
From: zyf654321 <88707387+zyf654321@users.noreply.github.com>
Date: Fri, 11 Oct 2024 15:05:50 +0800
Subject: [PATCH] feat fused_adamw (#938)

* feat fused_adamw
* add test code
* add test code
* limit only on cuda
* limit only on cuda
* Simplify code
* Simplify code
* Simplify code
* Simplify code
---
 dipu/SupportedDiopiFunctions.txt          |   1 +
 .../diopi_functions.yaml                  |  10 +
 dipu/tests/python/unittests/test_adamw.py | 213 ++++++++++++++++++
 .../testing/_internal/common_utils.py     |  21 +-
 4 files changed, 243 insertions(+), 2 deletions(-)
 create mode 100644 dipu/tests/python/unittests/test_adamw.py

diff --git a/dipu/SupportedDiopiFunctions.txt b/dipu/SupportedDiopiFunctions.txt
index 4d75be61b..63ebcbf4d 100644
--- a/dipu/SupportedDiopiFunctions.txt
+++ b/dipu/SupportedDiopiFunctions.txt
@@ -100,6 +100,7 @@ diopiForeachmulInpTensor
 diopiForeachmulScalar
 diopiForeachmulTensor
 diopiForeachnormScalar
+diopiFusedAdamW
 diopiGather
 diopiGe
 diopiGeInp
diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
index 7ec22f5eb..242a953a6 100755
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -1334,6 +1334,16 @@
     ::diopiConstTensorHandle_t self_dtype_diopi = dipu::diopi_helper::toDiopiTensorHandle(self_dtype);
   interface: diopiProd(ctx, out, self_dtype_diopi, nullptr)
 
+- schema: "_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()"
+  custom_code_at_the_beginning: |
+    auto diopiTensorHandles_self = dipu::diopi_helper::toDiopiTensorHandleVector(self);
+    auto diopiTensorHandles_grads = dipu::diopi_helper::toDiopiConstTensorHandleVector(grads);
+    auto diopiTensorHandles_exp_avgs = dipu::diopi_helper::toDiopiTensorHandleVector(exp_avgs);
+    auto diopiTensorHandles_exp_avg_sqs = dipu::diopi_helper::toDiopiTensorHandleVector(exp_avg_sqs);
+    auto diopiTensorHandles_max_exp_avg_sqs = dipu::diopi_helper::toDiopiTensorHandleVector(max_exp_avg_sqs);
+    auto diopiTensorHandles_state_steps = dipu::diopi_helper::toDiopiConstTensorHandleVector(state_steps);
+  interface: diopiFusedAdamW(ctx, diopiTensorHandles_self.data(), diopiTensorHandles_grads.data(), diopiTensorHandles_exp_avgs.data(), diopiTensorHandles_exp_avg_sqs.data(), diopiTensorHandles_max_exp_avg_sqs.data(), diopiTensorHandles_state_steps.data(), static_cast<int64_t>(self.size()), lr, beta1, beta2, eps, weight_decay, amsgrad, maximize)
+
 - schema: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
   custom_code_at_the_beginning: |
     const auto self_dtype = at::native::to(self, dtype);
diff --git a/dipu/tests/python/unittests/test_adamw.py b/dipu/tests/python/unittests/test_adamw.py
new file mode 100644
index 000000000..69ea9d495
--- /dev/null
+++ b/dipu/tests/python/unittests/test_adamw.py
@@ -0,0 +1,213 @@
+import torch
+import numpy as np
+from torch_dipu.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+    onlyOn,
+    skipOn,
+)
+
+
+class TestFusedAdamW(TestCase):
+    def setUp(self):
+        self.weight_shape_list = [(), (16,), (4, 8), (12, 4, 8)]
+        self.lr_list = [0.001, 0.01, 0.001, 0.01]
+        self.beta1_list = [0.9, 0.9, 0.9, 0.9]
+        self.beta2_list = [0.999, 0.999, 0.999, 0.999]
+        self.eps_list = [1e-8, 1e-8, 1e-8, 1e-8]
+        self.weight_decay_list = [1e-2, 1e-3, 1e-2, 1e-3]
+        self.amsgrad_list = [False, False, True, True]
+        self.step_list = [2, 3, 4, 5]
+
+    def run_adamw_cpu(
+        self,
+        param,
+        param_grad,
+        exp_avg,
+        exp_avg_sq,
+        max_exp_avg_sq,
+        lr,
+        beta1,
+        beta2,
+        eps,
+        step,
+        weight_decay,
+        amsgrad,
+    ):
+        torch.optim._functional.adamw(
+            [param],
+            [param_grad],
+            [exp_avg],
+            [exp_avg_sq],
+            [max_exp_avg_sq],
+            [torch.tensor(float(step))],
+            amsgrad=amsgrad,
+            beta1=beta1,
+            beta2=beta2,
+            lr=lr,
+            weight_decay=weight_decay,
+            eps=eps,
+            maximize=False,
+        )
+        return param, exp_avg, exp_avg_sq, max_exp_avg_sq
+
+    def run_adamw_dipu(
+        self,
+        param,
+        param_grad,
+        exp_avg,
+        exp_avg_sq,
+        max_exp_avg_sq,
+        lr,
+        beta1,
+        beta2,
+        eps,
+        step,
+        weight_decay,
+        amsgrad,
+    ):
+        torch._fused_adamw_(
+            [param],
+            [param_grad],
+            [exp_avg],
+            [exp_avg_sq],
+            [max_exp_avg_sq],
+            [torch.tensor(float(step)).cuda()],
+            amsgrad=amsgrad,
+            lr=lr,
+            beta1=beta1,
+            beta2=beta2,
+            weight_decay=weight_decay,
+            eps=eps,
+            maximize=False,
+            grad_scale=None,
+            found_inf=None,
+        )
+        return param, exp_avg, exp_avg_sq, max_exp_avg_sq
+
+    def adamw_(self, dtype_):
+        for i in range(len(self.weight_shape_list)):
+            weight = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
+            weight_cpu = (
+                weight.cpu().to(torch.float32)
+                if dtype_ == torch.float16
+                else weight.cpu()
+            )
+            grad = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
+            grad_cpu = (
+                grad.cpu().to(torch.float32) if dtype_ == torch.float16 else grad.cpu()
+            )
+            m = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
+            m_cpu = m.cpu().to(torch.float32) if dtype_ == torch.float16 else m.cpu()
+            v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
+            v_cpu = v.cpu().to(torch.float32) if dtype_ == torch.float16 else v.cpu()
+            max_v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
+            max_v_cpu = (
+                max_v.cpu().to(torch.float32)
+                if dtype_ == torch.float16
+                else max_v.cpu()
+            )
+
+            lr = self.lr_list[i]
+            beta1 = self.beta1_list[i]
+            beta2 = self.beta2_list[i]
+            eps = self.eps_list[i]
+            weight_decay = self.weight_decay_list[i]
+            amsgrad = self.amsgrad_list[i]
+            step = self.step_list[i]
+
+            w_new_cpu, m_new_cpu, v_new_cpu, max_v_new_cpu = self.run_adamw_cpu(
+                weight_cpu,
+                grad_cpu,
+                m_cpu,
+                v_cpu,
+                max_v_cpu,
+                lr,
+                beta1,
+                beta2,
+                eps,
+                step,
+                weight_decay,
+                amsgrad,
+            )
+            w_new, m_new, v_new, max_v_new = self.run_adamw_dipu(
+                weight,
+                grad,
+                m,
+                v,
+                max_v,
+                lr,
+                beta1,
+                beta2,
+                eps,
+                step,
+                weight_decay,
+                amsgrad,
+            )
+
+            self.assertTrue(
+                torch.allclose(
+                    w_new.cpu(),
+                    (
+                        w_new_cpu.to(torch.float16)
+                        if dtype_ == torch.float16
+                        else w_new_cpu
+                    ),
+                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
+                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
+                    equal_nan=True,
+                )
+                and torch.allclose(
+                    m_new.cpu(),
+                    (
+                        m_new_cpu.to(torch.float16)
+                        if dtype_ == torch.float16
+                        else m_new_cpu
+                    ),
+                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
+                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
+                    equal_nan=True,
+                ),
+            )
+            self.assertTrue(
+                torch.allclose(
+                    v_new.cpu(),
+                    (
+                        v_new_cpu.to(torch.float16)
+                        if dtype_ == torch.float16
+                        else v_new_cpu
+                    ),
+                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
+                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
+                    equal_nan=True,
+                )
+                and torch.allclose(
+                    max_v_new.cpu(),
+                    (
+                        max_v_new_cpu.to(torch.float16)
+                        if dtype_ == torch.float16
+                        else max_v_new_cpu
+                    ),
+                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
+                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
+                    equal_nan=True,
+                ),
+            )
+
+    @skipOn(
+        ["MLU", "NPU", "MUXI", "GCU", "DROPLET", "SUPA", "KLX"],
+        "The fused AdamW operator has not been adapted for these vendors' DIPU backends yet; remove a vendor from this list once its adaptation is added",
+    )
+    def test_adamw_fp16_(self):
+        self.adamw_(torch.float16)
+
+    @skipOn(
+        ["MLU", "NPU", "MUXI", "GCU", "DROPLET", "SUPA", "KLX"],
+        "The fused AdamW operator has not been adapted for these vendors' DIPU backends yet; remove a vendor from this list once its adaptation is added",
+    )
+    def test_adamw_fp32_(self):
+        self.adamw_(torch.float32)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/dipu/torch_dipu/testing/_internal/common_utils.py b/dipu/torch_dipu/testing/_internal/common_utils.py
index 4838b1416..7e0a5fe2b 100644
--- a/dipu/torch_dipu/testing/_internal/common_utils.py
+++ b/dipu/torch_dipu/testing/_internal/common_utils.py
@@ -65,8 +65,25 @@ def skipOnTorchVer(torchVer: str, reason: str = ""):
     return unittest.skipIf(torch_dipu.dipu.get_dipu_torch_version() == torchVer, reason)
 
 
-def skipOn(vendor: str, reason: str):
-    return unittest.skipIf(torch_dipu.dipu.vendor_type == vendor, reason)
+@overload
+def skipOn(vendor: str, reason: str): ...
+
+
+@overload
+def skipOn(vendor: List[str], reason: str): ...
+
+
+def skipOn(vendor, reason: str):
+    if isinstance(vendor, str):
+        vendor_list = [vendor]
+    else:
+        vendor_list = vendor
+    return unittest.skipIf(
+        torch_dipu.dipu.vendor_type in vendor_list,
+        "skip on {} because {}".format(
+            vendor_list[0] if len(vendor_list) == 1 else vendor_list, reason
+        ),
+    )
 
 
 def skipIfDevcieCountLessThan(number_of_devices_required):
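
Usage sketch (not part of the patch): with the _fused_adamw_ wrapper registered, the fused kernel is reachable through the stock optimizer API. This is a minimal illustration, assuming a DIPU build whose vendor implements diopiFusedAdamW and the usual torch_dipu mapping of the "cuda" device to the vendor device; fused=True is the upstream PyTorch switch that makes torch.optim.AdamW dispatch to torch._fused_adamw_.

    import torch
    import torch_dipu  # noqa: F401  # maps the "cuda" device to the DIPU vendor device

    model = torch.nn.Linear(8, 4).cuda()
    # fused=True routes optimizer.step() through torch._fused_adamw_, which the
    # generated DIPU wrapper now forwards to diopiFusedAdamW on the device.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

    loss = model(torch.randn(2, 8).cuda()).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()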
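
A second sketch (also not part of the patch, with illustrative test names) of the extended skipOn helper from common_utils.py: a single vendor string still works as before, and a list now skips when the current vendor is any of its members; the generated reason reads "skip on <vendor(s)> because <reason>".

    from torch_dipu.testing._internal.common_utils import TestCase, run_tests, skipOn


    class TestSkipOnExample(TestCase):
        @skipOn("MLU", "operator not adapted on this vendor yet")
        def test_single_vendor(self):
            self.assertTrue(True)

        @skipOn(["MLU", "NPU"], "operator not adapted on these vendors yet")
        def test_vendor_list(self):
            self.assertTrue(True)


    if __name__ == "__main__":
        run_tests()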