Align diopiStd to device with torch version>=2.0 (#921)
* update diopi function std yaml, add test

* fix format, add >= torch version check in diopi functions generate

* remove std test reduction

* support diopi std in torch2.0 and >=torch2.1

* update DIOPI

* refactor autogen torch version check & update diopi

* merge diopi to diopi_main

* update test_std_correction with skip on MLU

* fix cuda generator

* fix compile

* update DIOPI

---------

Co-authored-by: caikun-pjlab <[email protected]>
DoorKickers and caikun-pjlab authored Aug 2, 2024
1 parent 1bffe33 commit ae11430
Showing 9 changed files with 85 additions and 26 deletions.
21 changes: 18 additions & 3 deletions dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py
@@ -1203,10 +1203,25 @@ def main():
             continue
 
         # filter torch version
-        in_torch_vers = merged_fun_config.get("torch_ver", None)
+        supported_torch_ver_list = merged_fun_config.get("torch_ver", None)
         cur_torch_ver = merged_fun_config.get("current_torch_ver", None)
-        if in_torch_vers is not None and cur_torch_ver not in in_torch_vers:
-            continue
+
+        if supported_torch_ver_list != None:
+            exclude_torch_ver_list = []
+            include_torch_ver_list = []
+            all_include = False
+            for supported_torch_ver in supported_torch_ver_list:
+                if supported_torch_ver.startswith("-"):
+                    exclude_torch_ver_list.append(supported_torch_ver[1:])
+                elif supported_torch_ver == "all":
+                    all_include = True
+                else:
+                    include_torch_ver_list.append(supported_torch_ver)
+
+            if (cur_torch_ver in exclude_torch_ver_list) or (
+                all_include == False and (cur_torch_ver not in include_torch_ver_list)
+            ):
+                continue
 
         fun_code, register_code = functions_code_gen(merged_fun_config)
 
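For context, the new torch_ver filter accepts a list that can mix plain version strings ("20000"), the wildcard "all", and exclusions written with a leading "-" ("-20000"): an entry is generated when the current version is not excluded and is either listed explicitly or covered by "all". A minimal standalone sketch of that behavior (the helper name torch_ver_matches is hypothetical and not part of the generator):

def torch_ver_matches(cur_torch_ver, supported_torch_ver_list):
    """Return True if a config entry should be generated for cur_torch_ver."""
    if supported_torch_ver_list is None:
        return True  # no torch_ver key means no version constraint
    excludes = [v[1:] for v in supported_torch_ver_list if v.startswith("-")]
    includes = [v for v in supported_torch_ver_list
                if not v.startswith("-") and v != "all"]
    all_include = "all" in supported_torch_ver_list
    if cur_torch_ver in excludes:
        return False
    return all_include or cur_torch_ver in includes

# e.g. torch_ver_matches("20000", ["all", "-20000"]) -> False
#      torch_ver_matches("20100", ["all", "-20000"]) -> True
#      torch_ver_matches("20000", ["20000"])         -> True
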
28 changes: 18 additions & 10 deletions dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -838,32 +838,40 @@
   custom_code_at_the_beginning: |
     c10::DimVector output_shape = infer_reduce_op_shape(self.sizes(), dim.value_or(c10::DimVector()), keepdim);
     auto out = nodispatch::empty(output_shape, self.options());
-    bool unbiased = correction.value_or(1) == 1;
+    ::diopiScalar_t correctionDiopiScalar;
+    const ::diopiScalar_t* correctionDiopiScalarPtr = nullptr;
+    if (correction.has_value()) {
+      correctionDiopiScalar = dipu::diopi_helper::toDiopiScalar(correction.value());
+      correctionDiopiScalarPtr = &correctionDiopiScalar;
+    }
     ::diopiSize_t diopi_size = toDiopiSize(dim);
-  interface: diopiStd(ctx, out, self, diopi_size, unbiased);
+  interface: diopiStd(ctx, out, self, diopi_size, correctionDiopiScalarPtr);
 
 - schema: "std.correction_out(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)"
   torch_ver: ["20000",]
   custom_code_at_the_beginning: |
     ::diopiSize_t diopi_size = toDiopiSize(dim);
-    bool unbiased = correction.value_or(1) == 1;
-  interface: diopiStd(ctx, out, self, diopi_size, unbiased);
+    ::diopiScalar_t correctionDiopiScalar;
+    const ::diopiScalar_t* correctionDiopiScalarPtr = nullptr;
+    if (correction.has_value()) {
+      correctionDiopiScalar = dipu::diopi_helper::toDiopiScalar(correction.value());
+      correctionDiopiScalarPtr = &correctionDiopiScalar;
+    }
+  interface: diopiStd(ctx, out, self, diopi_size, correctionDiopiScalarPtr);
 
 - schema: "std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor"
-  torch_ver: ["20100", "20101", "20202"]
+  torch_ver: [all, "-20000"]
   custom_code_at_the_beginning: |
     c10::DimVector output_shape = infer_reduce_op_shape(self.sizes(), dim.value_or(c10::DimVector()), keepdim);
     auto out = nodispatch::empty(output_shape, self.options());
-    bool unbiased = correction.value_or(1).toLong() == 1;
     ::diopiSize_t diopi_size = toDiopiSize(dim);
-  interface: diopiStd(ctx, out, self, diopi_size, unbiased);
+  interface: diopiStd(ctx, out, self, diopi_size, correctionDiopiScalarPtr);
 
 - schema: "std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)"
-  torch_ver: ["20100", "20101", "20202"]
+  torch_ver: [all, "-20000"]
   custom_code_at_the_beginning: |
     ::diopiSize_t diopi_size = toDiopiSize(dim);
-    bool unbiased = correction.value_or(1).toLong() == 1;
-  interface: diopiStd(ctx, out, self, diopi_size, unbiased);
+  interface: diopiStd(ctx, out, self, diopi_size, correctionDiopiScalarPtr)
 
 - schema: "linear_backward(Tensor input, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)"
   device: [all, -cuda, -muxi, -ascend]
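The yaml change above swaps the boolean unbiased argument for a correction scalar, matching the ATen schema where correction is the amount subtracted from the sample count in the divisor rather than a yes/no flag: correction=1 is the old unbiased=True, correction=0 the old unbiased=False, and a value such as 20 has no boolean equivalent. A hedged CPU-side illustration of those semantics, independent of DIOPI:

import torch

x = torch.randn(100)
# correction=1 reproduces the legacy unbiased=True, correction=0 the biased variant.
assert torch.allclose(torch.std(x, correction=1), torch.std(x, unbiased=True))
assert torch.allclose(torch.std(x, correction=0), torch.std(x, unbiased=False))
# An arbitrary correction shifts the divisor to (N - correction), which is why the
# DIOPI call now receives a scalar pointer instead of a bool.
manual = torch.sqrt(((x - x.mean()) ** 2).sum() / (x.numel() - 20))
assert torch.allclose(torch.std(x, correction=20), manual)
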
13 changes: 12 additions & 1 deletion dipu/tests/python/unittests/test_mean_std.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2023, DeepLink.
 import torch
 import torch_dipu
-from torch_dipu.testing._internal.common_utils import TestCase, run_tests
+from torch_dipu.testing._internal.common_utils import TestCase, run_tests, skipOn
 
 
 class TestMeanStd(TestCase):
@@ -61,6 +61,17 @@ def test_std(self):
             )
         )
 
+    @skipOn("MLU", "camb does not support this type")
+    def test_std_correction(self):
+        self.assertTrue(
+            torch.allclose(
+                torch.std(self.a, dim=-1, correction=20).cpu(),
+                torch.std(self.a.cpu(), dim=-1, correction=20),
+                atol=1e-3,
+                rtol=1e-3,
+            )
+        )
+
 
 if __name__ == "__main__":
     run_tests()
2 changes: 0 additions & 2 deletions dipu/tests/pytorch_config_cuda.py
@@ -56,8 +56,6 @@
     "TestReductionsDIPU": {
         "test_ref_large_input_1D",
         "test_ref_large_input_64bit_indexing",
-        # will fail because diopiStd not align with torch.std, will fix later
-        "test_warn_invalid_degrees_of_freedom",
     },
 }
 
2 changes: 1 addition & 1 deletion dipu/torch_dipu/csrc_dipu/diopirt/diopirt_impl.cpp
@@ -216,8 +216,8 @@ DIOPI_RT_API diopiError_t diopiGeneratorSetSeedAndOffset(
     diopiGeneratorHandle_t th, uint64_t seed, uint64_t offset) {
   auto generator = reinterpret_cast<at::Generator*>(th);
   auto gen_impl = at::check_generator<dipu::DIPUGeneratorImpl>(*generator);
-  gen_impl->set_offset(offset);
   gen_impl->set_current_seed(seed);
+  gen_impl->set_offset(offset);
   return diopiSuccess;
 }
 
6 changes: 3 additions & 3 deletions dipu/torch_dipu/csrc_dipu/runtime/core/DIPUGeneratorImpl.cpp
@@ -80,9 +80,7 @@ at::Generator createDIPUGenerator(at::DeviceIndex device_index) {
  */
 DIPUGeneratorImpl::DIPUGeneratorImpl(at::DeviceIndex device_index)
     : c10::GeneratorImpl{at::Device(dipu::DIPU_DEVICE_TYPE, device_index),
-                         at::DispatchKeySet(dipu::DIPU_DISPATCH_KEY)},
-      offset_(0),
-      state_need_reset_(true) {}
+                         at::DispatchKeySet(dipu::DIPU_DISPATCH_KEY)} {}
 
 /**
  * Sets the seed to be used by MTGP
@@ -91,6 +89,7 @@ DIPUGeneratorImpl::DIPUGeneratorImpl(at::DeviceIndex device_index)
  */
 void DIPUGeneratorImpl::set_current_seed(uint64_t seed) {
   seed_ = seed;
+  offset_ = 0;
   state_need_reset_ = true;
 }
 
@@ -137,6 +136,7 @@ DIPUGeneratorImpl* DIPUGeneratorImpl::clone_impl() const {
       createDIPUGenerator(this->device().index()).unsafeReleaseGeneratorImpl());
   TORCH_CHECK(gen != nullptr);
   gen->set_current_seed(this->seed_);
+  gen->set_offset(offset_);
   auto state = this->state_;
   const auto& state_clone = state.clone();
   gen->set_state(*state_clone.getIntrusivePtr());
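Because set_current_seed now zeroes the philox offset, any caller restoring both values must seed first and set the offset afterwards, which is why diopiGeneratorSetSeedAndOffset above was reordered and why clone_impl re-applies the offset after the seed. A hedged Python-level illustration of the intended behavior, assuming a torch >= 2.1 build whose Generator exposes set_offset/get_offset and a torch_dipu install that maps the device:

import torch
import torch_dipu  # assumed to register the DIPU backend behind "cuda"

gen = torch.Generator(device="cuda")
gen.set_offset(128)
gen.manual_seed(1234)          # seeding is expected to reset the offset to 0
assert gen.get_offset() == 0

gen.manual_seed(1234)          # correct order: seed first ...
gen.set_offset(128)            # ... then restore the philox offset
assert gen.get_offset() == 128
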
14 changes: 10 additions & 4 deletions dipu/torch_dipu/csrc_dipu/runtime/core/DIPUGeneratorImpl.h
@@ -26,11 +26,17 @@ class DIPUGeneratorImpl : public c10::GeneratorImpl {
   // not match the order elsewhere. we will change to keep the order from
   // oldest-compatiable to latest vesion.
 #if DIPU_TORCH_VERSION == 20100 || DIPU_TORCH_VERSION == 20101
-  void set_offset(uint64_t offset) override { offset_ = offset; }
+  void set_offset(uint64_t offset) override {
+    offset_ = offset;
+    state_need_reset_ = true;
+  }
   uint64_t get_offset() const override { return offset_; }
 
 #else // # temp solution, default use torch2.0.0
-  virtual void set_offset(uint64_t offset) { offset_ = offset; }
+  virtual void set_offset(uint64_t offset) {
+    offset_ = offset;
+    state_need_reset_ = true;
+  }
   virtual uint64_t get_offset() const { return offset_; }
 
 #endif
@@ -40,10 +46,10 @@ class DIPUGeneratorImpl : public c10::GeneratorImpl {
   virtual void update_state() const = 0;
 
   DIPUGeneratorImpl* clone_impl() const override;
-  volatile uint64_t offset_;
+  volatile uint64_t offset_ = 0;
   uint64_t seed_ = c10::default_rng_seed_val;
   mutable at::Tensor state_;
-  mutable bool state_need_reset_;
+  mutable bool state_need_reset_ = true;
 };
 
 at::Generator& getDefaultDIPUGenerator(at::DeviceIndex device_index = -1);
23 changes: 22 additions & 1 deletion dipu/torch_dipu/csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp
@@ -27,6 +27,27 @@ class CUDAGeneratorImpl : public dipu::DIPUGeneratorImpl {
         state_size == total_size || state_size == total_size - offset_size,
         "RNG state is wrong size");
 
+    // 1. set seed and offset
+    bool no_philox_seed = false;
+    if (state_size == total_size - offset_size) {
+      no_philox_seed = true;
+    }
+
+    uint64_t input_seed = 0;
+#if DIPU_TORCH_VERSION == 20000
+    auto new_rng_state = state.data<uint8_t>();
+#else
+    auto new_rng_state = state.data_dtype_initialized<uint8_t>();
+#endif
+    memcpy(&input_seed, new_rng_state, seed_size);
+    this->set_current_seed(input_seed);
+    int64_t philox_offset = 0;
+    if (!no_philox_seed) {
+      memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
+    }
+    this->set_offset(static_cast<uint64_t>(philox_offset));
+
+    // 2. set state
     at::Tensor state_tmp(
         state.shallow_copy_and_detach(state.version_counter(), true));
     state_ = state_tmp;
@@ -44,7 +65,7 @@ class CUDAGeneratorImpl : public dipu::DIPUGeneratorImpl {
     // THCGenerator struct was an array of curandStateMtgp32s.
     memset(rng_state, -1, states_size);
     uint64_t current_seed = this->current_seed();
-    int64_t offset = 0;
+    int64_t offset = this->get_offset();
     memcpy(rng_state + states_size, &current_seed, seed_size);
     memcpy(rng_state + states_size + seed_size, &offset, offset_size);
     state_need_reset_ = false;
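With set_state now parsing the seed and philox offset back out of the raw state blob, and update_state writing the live offset into it, a saved RNG state should round-trip through get/set and reproduce the same random stream. A hedged sanity check of that behavior, illustrative only and not part of the tests added by this commit:

import torch
import torch_dipu  # assumed to route torch.cuda RNG calls to the DIPU device

state = torch.cuda.get_rng_state()      # blob encoding seed + philox offset
a = torch.rand(4, device="cuda")
torch.cuda.set_rng_state(state)         # re-applies the encoded seed and offset
b = torch.rand(4, device="cuda")
assert torch.equal(a.cpu(), b.cpu())    # restored state yields the same stream
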
