[DIPU]fix cuda generator set_state (#932)
* fix cuda generator set_state

* add test
caikun-pjlab authored Aug 20, 2024
1 parent dcda998 commit b638174
Showing 3 changed files with 33 additions and 6 deletions.
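
Background for the change (added note, not part of the commit): the blob returned by torch.cuda.get_rng_state and accepted by torch.cuda.set_rng_state ends with an 8-byte seed followed by an 8-byte Philox offset, preceded by a block of legacy state bytes (the states_size offset in the C++ diff below). The old DIPU code read the seed and offset from the start of the blob rather than after that legacy block, so restoring a saved state produced the wrong seed. A minimal decoding sketch of that assumed layout on a little-endian host (the helper name is illustrative, not a DIPU API):

```python
import torch
import torch_dipu  # noqa: F401  # registers the DIPU backend behind the "cuda" device

SEED_SIZE = 8    # uint64_t seed
OFFSET_SIZE = 8  # int64_t philox offset


def split_rng_state(state: torch.Tensor):
    """Return (legacy_bytes, seed, offset) from a CUDA RNG state blob -- sketch only."""
    legacy = state[: -(SEED_SIZE + OFFSET_SIZE)]
    seed = int.from_bytes(bytes(state[-16:-8].tolist()), "little")
    offset = int.from_bytes(bytes(state[-8:].tolist()), "little")
    return legacy, seed, offset


if __name__ == "__main__":
    _, seed, offset = split_rng_state(torch.cuda.get_rng_state(0))
    print(f"seed={seed} offset={offset}")
```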
27 changes: 24 additions & 3 deletions dipu/tests/python/unittests/test_generator.py
@@ -2,7 +2,11 @@
 import torch
 import torch_dipu
 from torch_dipu import diputype
-from torch_dipu.testing._internal.common_utils import TestCase, run_tests
+from torch_dipu.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+    onlyOn,
+)
 
 
 class TestGenerator(TestCase):
@@ -20,13 +24,13 @@ def test_python_api(self):
             torch.cuda.manual_seed(i)
 
         state = torch.cuda.get_rng_state(0)
-        new_state = torch.ones_like(state)
+        new_state = torch.ones_like(state) * 4
         torch.cuda.set_rng_state(new_state, 0)
         current_state = torch.cuda.get_rng_state(0)
         self.assertTrue(
             torch.allclose(
                 current_state,
-                torch.tensor(1, device=current_state.device, dtype=current_state.dtype),
+                torch.tensor(4, device=current_state.device, dtype=current_state.dtype),
             )
         )
 
@@ -194,6 +198,23 @@ def test_default_generators(self):
         torch.cuda.default_generators[0].manual_seed(1)
         self.assertEqual(torch.cuda.default_generators[0].initial_seed(), 1)
 
+    @onlyOn("CUDA")
+    def test_cuda_generator(self):
+        state = torch.cuda.get_rng_state(0)
+        state[-16] = 4
+        state[-15:-8] = 0
+        state[-8:] = 0
+        torch.cuda.set_rng_state(state)
+        self.assertEqual(torch.cuda.initial_seed(), 4)
+
+        # invalid offset, offset must be a multiple of 4
+        state[-8:] = 1
+        try:
+            torch.cuda.set_rng_state(state)
+            self.assertTrue(False, "should not go here")
+        except Exception as ex:
+            self.assertIn("offset must be a multiple of 4", ex.args[0])
+
 
 if __name__ == "__main__":
     run_tests()
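
Two notes on the test changes above (my reading, not stated in the commit). First, the filler value in test_python_api moves from 1 to 4 because the trailing eight bytes of the blob are now interpreted as the Philox offset and must pass the new multiple-of-4 check, which an all-4 byte pattern satisfies and an all-1 pattern does not. Second, the new test_cuda_generator writes 4 into state[-16] and zeroes state[-15:-8] to encode a little-endian seed of 4 (reported back by torch.cuda.initial_seed), then sets every offset byte to 1 to trigger the error message added by the vendor-side check below. A small stand-alone check of that byte arithmetic (pure Python, no device needed):

```python
# Little-endian decode of the byte patterns used by the tests.
assert int.from_bytes(bytes([4] + [0] * 7), "little") == 4  # seed written by test_cuda_generator
assert int.from_bytes(bytes([4] * 8), "little") % 4 == 0    # offset from ones_like(state) * 4: accepted
assert int.from_bytes(bytes([1] * 8), "little") % 4 != 0    # offset of all 1-bytes: rejected
```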
2 changes: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ DIPUGeneratorImpl::DIPUGeneratorImpl(at::DeviceIndex device_index)
  */
 void DIPUGeneratorImpl::set_current_seed(uint64_t seed) {
   seed_ = seed;
-  offset_ = 0;
+  set_offset(0);
   state_need_reset_ = true;
 }
 
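Routing the reset through the virtual set_offset(0) instead of assigning offset_ directly means vendor subclasses, such as the CUDA override added in the next file, also see the reset when the seed changes. Assuming get_state mirrors the trailing-offset layout sketched earlier, one observable effect is that re-seeding zeroes the offset bytes of the exported state; a small sketch:

```python
import torch
import torch_dipu  # noqa: F401

# Sketch only; expects a visible CUDA/DIPU device.
torch.cuda.manual_seed(1234)
state = torch.cuda.get_rng_state(0)
print(int.from_bytes(bytes(state[-8:].tolist()), "little"))  # expected: 0
```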
10 changes: 8 additions & 2 deletions dipu/torch_dipu/csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp
@@ -39,11 +39,12 @@ class CUDAGeneratorImpl : public dipu::DIPUGeneratorImpl {
 #else
     auto new_rng_state = state.data_dtype_initialized<uint8_t>();
 #endif
-    memcpy(&input_seed, new_rng_state, seed_size);
+    memcpy(&input_seed, new_rng_state + states_size, seed_size);
     this->set_current_seed(input_seed);
     int64_t philox_offset = 0;
     if (!no_philox_seed) {
-      memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
+      memcpy(&philox_offset, new_rng_state + states_size + seed_size,
             offset_size);
     }
     this->set_offset(static_cast<uint64_t>(philox_offset));
 
@@ -71,6 +72,11 @@ class CUDAGeneratorImpl : public dipu::DIPUGeneratorImpl {
       state_need_reset_ = false;
     }
   }
+
+  void set_offset(uint64_t offset) override {
+    TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
+    DIPUGeneratorImpl::set_offset(offset);
+  }
 };
 
 // NOLINTNEXTLINE(readability-const-return-type)
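The new override mirrors the multiple-of-4 check in PyTorch's native CUDA generator, whose error text the test matches; the restriction presumably exists because each Philox round produces four 32-bit values, so the counter is only meaningful at multiples of four. Code that assembles a state blob by hand can validate the offset field up front; a hypothetical helper (name and behavior are my own sketch, not a DIPU or PyTorch API):

```python
import torch
import torch_dipu  # noqa: F401


def write_philox_offset(state: torch.Tensor, offset: int) -> torch.Tensor:
    """Hypothetical helper: write `offset` into the trailing 8 bytes of an RNG state blob."""
    if offset % 4 != 0:
        raise ValueError("offset must be a multiple of 4")
    out = state.clone()
    out[-8:] = torch.tensor(list(offset.to_bytes(8, "little")), dtype=torch.uint8)
    return out


# Usage sketch (expects a visible CUDA/DIPU device): a valid offset round-trips.
state = torch.cuda.get_rng_state(0)
torch.cuda.set_rng_state(write_philox_offset(state, 8), 0)
```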
