Simplify code
zyf654321 committed Oct 10, 2024
1 parent 1f93b60 commit 9a8cf7d
Showing 2 changed files with 7 additions and 33 deletions.
38 changes: 6 additions & 32 deletions dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -1327,38 +1327,12 @@

- schema: "_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()"
custom_code_at_the_beginning: |
- std::vector<diopiTensorHandle_t> diopiTensorHandles_self(self.size());
- for(size_t i=0; i < self.size(); ++i){
-   diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(self.at(i));
-   diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
-   diopiTensorHandles_self[i] = handle;
- }
- std::vector<diopiConstTensorHandle_t> diopiTensorHandles_grads(grads.size());
- for(size_t i=0; i < grads.size(); ++i){
-   diopiTensorHandles_grads[i] = dipu::diopi_helper::toDiopiTensorHandle(grads.at(i));
- }
- std::vector<diopiTensorHandle_t> diopiTensorHandles_exp_avgs(exp_avgs.size());
- for(size_t i=0; i < exp_avgs.size(); ++i){
-   diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avgs.at(i));
-   diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
-   diopiTensorHandles_exp_avgs[i] = handle;
- }
- std::vector<diopiTensorHandle_t> diopiTensorHandles_exp_avg_sqs(exp_avg_sqs.size());
- for(size_t i=0; i < exp_avg_sqs.size(); ++i){
-   diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avg_sqs.at(i));
-   diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
-   diopiTensorHandles_exp_avg_sqs[i] = handle;
- }
- std::vector<diopiTensorHandle_t> diopiTensorHandles_max_exp_avg_sqs(max_exp_avg_sqs.size());
- for(size_t i=0; i < max_exp_avg_sqs.size(); ++i){
-   diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(max_exp_avg_sqs.at(i));
-   diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
-   diopiTensorHandles_max_exp_avg_sqs[i] = handle;
- }
- std::vector<diopiConstTensorHandle_t> diopiTensorHandles_state_steps(state_steps.size(), nullptr);
- for(size_t i=0; i < state_steps.size(); ++i){
-   diopiTensorHandles_state_steps[i] = dipu::diopi_helper::toDiopiTensorHandle(state_steps.at(i));
- }
+ auto diopiTensorHandles_self = dipu::diopi_helper::toDiopiTensorHandleVector(self);
+ auto diopiTensorHandles_grads = dipu::diopi_helper::toDiopiConstTensorHandleVector(grads);
+ auto diopiTensorHandles_exp_avgs = dipu::diopi_helper::toDiopiTensorHandleVector(exp_avgs);
+ auto diopiTensorHandles_exp_avg_sqs = dipu::diopi_helper::toDiopiTensorHandleVector(exp_avg_sqs);
+ auto diopiTensorHandles_max_exp_avg_sqs = dipu::diopi_helper::toDiopiTensorHandleVector(max_exp_avg_sqs);
+ auto diopiTensorHandles_state_steps = dipu::diopi_helper::toDiopiConstTensorHandleVector(state_steps);
interface: diopiFusedAdamW(ctx, diopiTensorHandles_self.data(), diopiTensorHandles_grads.data(), diopiTensorHandles_exp_avgs.data(), diopiTensorHandles_exp_avg_sqs.data(), diopiTensorHandles_max_exp_avg_sqs.data(), diopiTensorHandles_state_steps.data(), static_cast<int64_t>(self.size()), lr, beta1, beta2, eps, weight_decay, amsgrad, maximize)

- schema: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
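
For context, the new custom_code_at_the_beginning relies on two vector-conversion helpers in dipu::diopi_helper. The sketch below is hypothetical: the helper names are taken from the commit, but the signatures, the at::TensorList parameter type, and the include paths are assumptions inferred from the per-element loops they replace, not the repository's real implementation.

// Hypothetical sketch (not the actual DIPU code): what toDiopiTensorHandleVector /
// toDiopiConstTensorHandleVector could look like, inferred from the removed loops.
#include <vector>
#include <ATen/ATen.h>        // assumed include for at::Tensor / at::TensorList
#include <diopi/diopirt.h>    // assumed include for the DIOPI handle typedefs

namespace dipu {
namespace diopi_helper {

// Assumed to exist already (it is called in the removed loops): converts a single
// at::Tensor to a const DIOPI tensor handle.
diopiConstTensorHandle_t toDiopiTensorHandle(const at::Tensor& tensor);

// Mutable-handle variant: mirrors the removed loops, including the per-element
// const_cast used for in-place outputs such as self, exp_avgs and exp_avg_sqs.
inline std::vector<diopiTensorHandle_t> toDiopiTensorHandleVector(at::TensorList tensors) {
  std::vector<diopiTensorHandle_t> handles(tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    handles[i] = const_cast<diopiTensorHandle_t>(toDiopiTensorHandle(tensors[i]));
  }
  return handles;
}

// Const-handle variant for read-only inputs such as grads and state_steps.
inline std::vector<diopiConstTensorHandle_t> toDiopiConstTensorHandleVector(at::TensorList tensors) {
  std::vector<diopiConstTensorHandle_t> handles(tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    handles[i] = toDiopiTensorHandle(tensors[i]);
  }
  return handles;
}

}  // namespace diopi_helper
}  // namespace dipu

Whatever the exact signatures, the effect of the change is that each handle vector in the YAML shrinks to a single helper call, and the diopiFusedAdamW interface line is left unchanged.
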
2 changes: 1 addition & 1 deletion dipu/tests/python/unittests/test_adamw.py
@@ -1,6 +1,6 @@
import torch
import numpy as np
- from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn
+ from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn, skipOn


class TestFusedAdamW(TestCase):
