feat fused_adamw (#938)
* feat fused_adamw

* add test code

* add test code

* limit only on cuda

* limit only on cuda

* Simplify code

* Simplify code

* Simplify code

* Simplify code
zyf654321 authored Oct 11, 2024
1 parent 0caf4c8 commit cfec228
Showing 4 changed files with 243 additions and 2 deletions.
1 change: 1 addition & 0 deletions dipu/SupportedDiopiFunctions.txt
@@ -100,6 +100,7 @@ diopiForeachmulInpTensor
diopiForeachmulScalar
diopiForeachmulTensor
diopiForeachnormScalar
diopiFusedAdamW
diopiGather
diopiGe
diopiGeInp
10 changes: 10 additions & 0 deletions dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -1334,6 +1334,16 @@
    ::diopiConstTensorHandle_t self_dtype_diopi = dipu::diopi_helper::toDiopiTensorHandle(self_dtype);
  interface: diopiProd(ctx, out, self_dtype_diopi, nullptr)

- schema: "_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()"
  custom_code_at_the_beginning: |
    auto diopiTensorHandles_self = dipu::diopi_helper::toDiopiTensorHandleVector(self);
    auto diopiTensorHandles_grads = dipu::diopi_helper::toDiopiConstTensorHandleVector(grads);
    auto diopiTensorHandles_exp_avgs = dipu::diopi_helper::toDiopiTensorHandleVector(exp_avgs);
    auto diopiTensorHandles_exp_avg_sqs = dipu::diopi_helper::toDiopiTensorHandleVector(exp_avg_sqs);
    auto diopiTensorHandles_max_exp_avg_sqs = dipu::diopi_helper::toDiopiTensorHandleVector(max_exp_avg_sqs);
    auto diopiTensorHandles_state_steps = dipu::diopi_helper::toDiopiConstTensorHandleVector(state_steps);
  interface: diopiFusedAdamW(ctx, diopiTensorHandles_self.data(), diopiTensorHandles_grads.data(), diopiTensorHandles_exp_avgs.data(), diopiTensorHandles_exp_avg_sqs.data(), diopiTensorHandles_max_exp_avg_sqs.data(), diopiTensorHandles_state_steps.data(), static_cast<int64_t>(self.size()), lr, beta1, beta2, eps, weight_decay, amsgrad, maximize)

- schema: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
  custom_code_at_the_beginning: |
    const auto self_dtype = at::native::to(self, dtype);
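For context, the new test below compares the fused operator against PyTorch's single-tensor CPU path (`torch.optim._functional.adamw`). The following is a minimal reference sketch of one decoupled-weight-decay AdamW step with `maximize=False`; the function and argument names are illustrative and not part of this commit, and the fused kernel's numerical details (e.g. fp16 handling) may differ:

```python
import torch


def adamw_step_reference(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq,
                         *, lr, beta1, beta2, eps, weight_decay, step, amsgrad):
    # Decoupled weight decay is applied directly to the parameter.
    param.mul_(1 - lr * weight_decay)
    # Update biased first and second moment estimates.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    if amsgrad:
        # Track the running maximum of the second moment for the denominator.
        torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
        denom = (max_exp_avg_sq.sqrt() / bias_correction2 ** 0.5).add_(eps)
    else:
        denom = (exp_avg_sq.sqrt() / bias_correction2 ** 0.5).add_(eps)
    # Parameter update with the bias-corrected step size.
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
    return param
```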
213 changes: 213 additions & 0 deletions dipu/tests/python/unittests/test_adamw.py
@@ -0,0 +1,213 @@
import torch
import numpy as np
from torch_dipu.testing._internal.common_utils import (
    TestCase,
    run_tests,
    onlyOn,
    skipOn,
)


class TestFusedAdamW(TestCase):
    def setUp(self):
        self.weight_shape_list = [(), (16,), (4, 8), (12, 4, 8)]
        self.lr_list = [0.001, 0.01, 0.001, 0.01]
        self.beta1_list = [0.9, 0.9, 0.9, 0.9]
        self.beta2_list = [0.999, 0.999, 0.999, 0.999]
        self.eps_list = [1e-8, 1e-8, 1e-8, 1e-8]
        self.weight_decay_list = [1e-2, 1e-3, 1e-2, 1e-3]
        self.amsgrad_list = [False, False, True, True]
        self.step_list = [2, 3, 4, 5]

    def run_adamw_cpu(
        self,
        param,
        param_grad,
        exp_avg,
        exp_avg_sq,
        max_exp_avg_sq,
        lr,
        beta1,
        beta2,
        eps,
        step,
        weight_decay,
        amsgrad,
    ):
        # Single-tensor CPU reference implementation from torch.optim.
        torch.optim._functional.adamw(
            [param],
            [param_grad],
            [exp_avg],
            [exp_avg_sq],
            [max_exp_avg_sq],
            [torch.tensor(float(step))],
            amsgrad=amsgrad,
            beta1=beta1,
            beta2=beta2,
            lr=lr,
            weight_decay=weight_decay,
            eps=eps,
            maximize=False,
        )
        return param, exp_avg, exp_avg_sq, max_exp_avg_sq

    def run_adamw_dipu(
        self,
        param,
        param_grad,
        exp_avg,
        exp_avg_sq,
        max_exp_avg_sq,
        lr,
        beta1,
        beta2,
        eps,
        step,
        weight_decay,
        amsgrad,
    ):
        # Fused operator under test; the state step tensor is placed on the device.
        torch._fused_adamw_(
            [param],
            [param_grad],
            [exp_avg],
            [exp_avg_sq],
            [max_exp_avg_sq],
            [torch.tensor(float(step)).cuda()],
            amsgrad=amsgrad,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=False,
            grad_scale=None,
            found_inf=None,
        )
        return param, exp_avg, exp_avg_sq, max_exp_avg_sq

    def adamw_(self, dtype_):
        for i in range(len(self.weight_shape_list)):
            # Keep an fp32 CPU copy as the reference when testing fp16.
            weight = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
            weight_cpu = (
                weight.cpu().to(torch.float32)
                if dtype_ == torch.float16
                else weight.cpu()
            )
            grad = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
            grad_cpu = (
                grad.cpu().to(torch.float32) if dtype_ == torch.float16 else grad.cpu()
            )
            m = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
            m_cpu = m.cpu().to(torch.float32) if dtype_ == torch.float16 else m.cpu()
            v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
            v_cpu = v.cpu().to(torch.float32) if dtype_ == torch.float16 else v.cpu()
            max_v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
            max_v_cpu = (
                max_v.cpu().to(torch.float32)
                if dtype_ == torch.float16
                else max_v.cpu()
            )

            lr = self.lr_list[i]
            beta1 = self.beta1_list[i]
            beta2 = self.beta2_list[i]
            eps = self.eps_list[i]
            weight_decay = self.weight_decay_list[i]
            amsgrad = self.amsgrad_list[i]
            step = self.step_list[i]

            w_new_cpu, m_new_cpu, v_new_cpu, max_v_new_cpu = self.run_adamw_cpu(
                weight_cpu,
                grad_cpu,
                m_cpu,
                v_cpu,
                max_v_cpu,
                lr,
                beta1,
                beta2,
                eps,
                step,
                weight_decay,
                amsgrad,
            )
            w_new, m_new, v_new, max_v_new = self.run_adamw_dipu(
                weight,
                grad,
                m,
                v,
                max_v,
                lr,
                beta1,
                beta2,
                eps,
                step,
                weight_decay,
                amsgrad,
            )

            # Compare parameter and optimizer states against the CPU reference.
            # Each comparison is asserted separately; passing a second allclose
            # result as the assertTrue message would silently ignore it.
            self.assertTrue(
                torch.allclose(
                    w_new.cpu(),
                    (
                        w_new_cpu.to(torch.float16)
                        if dtype_ == torch.float16
                        else w_new_cpu
                    ),
                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
                    equal_nan=True,
                )
            )
            self.assertTrue(
                torch.allclose(
                    m_new.cpu(),
                    (
                        m_new_cpu.to(torch.float16)
                        if dtype_ == torch.float16
                        else m_new_cpu
                    ),
                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
                    equal_nan=True,
                )
            )
            self.assertTrue(
                torch.allclose(
                    v_new.cpu(),
                    (
                        v_new_cpu.to(torch.float16)
                        if dtype_ == torch.float16
                        else v_new_cpu
                    ),
                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
                    equal_nan=True,
                )
            )
            self.assertTrue(
                torch.allclose(
                    max_v_new.cpu(),
                    (
                        max_v_new_cpu.to(torch.float16)
                        if dtype_ == torch.float16
                        else max_v_new_cpu
                    ),
                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
                    equal_nan=True,
                )
            )

    @skipOn(
        ["MLU", "NPU", "MUXI", "GCU", "DROPLET", "SUPA", "KLX"],
        "the fused AdamW operator has not been wired up in DIPU for these vendors yet; "
        "remove a vendor from this list once its implementation lands",
    )
    def test_adamw_fp16_(self):
        self.adamw_(torch.float16)

    @skipOn(
        ["MLU", "NPU", "MUXI", "GCU", "DROPLET", "SUPA", "KLX"],
        "the fused AdamW operator has not been wired up in DIPU for these vendors yet; "
        "remove a vendor from this list once its implementation lands",
    )
    def test_adamw_fp32_(self):
        self.adamw_(torch.float32)


if __name__ == "__main__":
    run_tests()
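Since the module ends with `run_tests()`, it can presumably be executed directly, e.g. `python dipu/tests/python/unittests/test_adamw.py` on a machine with a DIPU/CUDA device; the exact invocation depends on the local test harness.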
21 changes: 19 additions & 2 deletions dipu/torch_dipu/testing/_internal/common_utils.py
@@ -65,8 +65,25 @@ def skipOnTorchVer(torchVer: str, reason: str = ""):
    return unittest.skipIf(torch_dipu.dipu.get_dipu_torch_version() == torchVer, reason)


def skipOn(vendor: str, reason: str):
    return unittest.skipIf(torch_dipu.dipu.vendor_type == vendor, reason)
@overload
def skipOn(vendor: str, reason: str): ...


@overload
def skipOn(vendor: List[str], reason: str): ...


def skipOn(vendor, reason: str):
    if isinstance(vendor, str):
        vendor_list = [vendor]
    else:
        vendor_list = vendor
    return unittest.skipIf(
        torch_dipu.dipu.vendor_type in vendor_list,
        "skip on {} because {}".format(
            vendor[0] if len(vendor) == 1 else vendor, reason
        ),
    )


def skipIfDevcieCountLessThan(number_of_devices_required):
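For clarity, a minimal sketch of how both `skipOn` forms are used as decorators; the class and test names are illustrative, and the imports mirror those in the new test file:

```python
from torch_dipu.testing._internal.common_utils import TestCase, run_tests, skipOn


class SkipOnExample(TestCase):
    @skipOn("MLU", "single-vendor form of skipOn")
    def test_single_vendor(self):
        self.assertTrue(True)

    @skipOn(["MLU", "NPU"], "list-of-vendors form of skipOn")
    def test_vendor_list(self):
        self.assertTrue(True)


if __name__ == "__main__":
    run_tests()
```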
