diff --git a/dipu/tests/pytorch_config_mlu.py b/dipu/tests/pytorch_config_mlu.py index 7fe6016c2..0bb50b814 100644 --- a/dipu/tests/pytorch_config_mlu.py +++ b/dipu/tests/pytorch_config_mlu.py @@ -49,6 +49,10 @@ }, # test_testing.py 'TestTestParametrizationDeviceTypeDIPU': { + # when change dipu device type to 'cuda', 'test_ops_composition_names' fail, because parameter + # passed to testclass.device_type is 'dipu', different device seems have different case numbers. + # to do: change test device_type='cuda' + 'test_ops_composition_names', 'test_unparametrized_names', 'test_make_tensor_dipu', 'test_dtypes_composition_valid', diff --git a/dipu/tests/run_camb_tests.sh b/dipu/tests/run_camb_tests.sh index 99f094007..39ce7cca2 100644 --- a/dipu/tests/run_camb_tests.sh +++ b/dipu/tests/run_camb_tests.sh @@ -19,7 +19,13 @@ function run_dipu_tests { #run_test "${PYTORCH_DIR}/test/test_utils.py" "$@" -v run_test "${PYTORCH_DIR}/test/test_unary_ufuncs.py" "$@" -v -f TestUnaryUfuncsDIPU run_test "${PYTORCH_DIR}/test/test_binary_ufuncs.py" "$@" -v -f TestBinaryUfuncsDIPU + + # need fix: random func test not throw expected err msg as check_nondeterministic_alert() needed, + # when device type is xpu it just ignore this err (should_alert= false), but device type 'cuda' will expose errors + export DIPU_PYTHON_DEVICE_AS_CUDA=false run_test "${PYTORCH_DIR}/test/test_torch.py" "$@" -v -f TestTorchDeviceTypeDIPU #--subprocess + export DIPU_PYTHON_DEVICE_AS_CUDA=true + run_test "${PYTORCH_DIR}/test/test_indexing.py" "$@" -v -f TestIndexingDIPU run_test "${PYTORCH_DIR}/test/test_indexing.py" "$@" -v -f NumpyTestsDIPU run_test "${PYTORCH_DIR}/test/test_view_ops.py" "$@" -v -f TestViewOpsDIPU diff --git a/dipu/tests/test_ops/archived/test_generator.py b/dipu/tests/test_ops/archived/test_generator.py index 58152782a..466c902e5 100644 --- a/dipu/tests/test_ops/archived/test_generator.py +++ b/dipu/tests/test_ops/archived/test_generator.py @@ -1,6 +1,7 @@ # Copyright (c) 2023, DeepLink. import torch import torch_dipu +from torch_dipu import diputype from torch_dipu.testing._internal.common_utils import create_common_tensor, TestCase, run_tests @@ -35,13 +36,13 @@ def test_torch_generator(self): assert gen.device.type == 'cpu' gen = torch.Generator("cuda") - assert gen.device.type == 'xpu' + assert gen.device.type == diputype gen = torch.Generator("cuda:0") - assert gen.device == torch.device('xpu:0') + assert gen.device == torch.device(diputype + ':0') gen = torch.Generator("dipu") - assert gen.device.type == 'xpu' + assert gen.device.type == diputype gen.manual_seed(1) assert gen.initial_seed() == 1 diff --git a/dipu/tests/test_ops/archived/test_rt_tensor.py b/dipu/tests/test_ops/archived/test_rt_tensor.py index 4240dfa66..ce4732a2a 100644 --- a/dipu/tests/test_ops/archived/test_rt_tensor.py +++ b/dipu/tests/test_ops/archived/test_rt_tensor.py @@ -124,6 +124,14 @@ def testDeviceProperties(): print("device capability: ", torch.cuda.get_device_capability(0)) print("device name: ", torch.cuda.get_device_name(0)) +def test_mem_get_info(): + import torch_dipu + from torch import cuda + minfo = cuda.mem_get_info() + d1 = torch.ones((1024, 1024 * 30), device = "cuda") + minfo = cuda.mem_get_info() + print(minfo) + def test_type(): import torch_dipu dev1 = "cuda" @@ -172,16 +180,34 @@ def test_complex_type(): zr = torch.view_as_real(z2) print(zr.cpu) +# env DIPU_PYTHON_DEVICE_AS_CUDA is default true! +def test_dipu_as_cuda_type(): + import torch_dipu + d1 = torch.device("cuda", 0) + t1 = torch.ones((1024, 1), device = 0) + print(t1) + assert(d1.type == "cuda") + assert(t1.is_cuda == True) + assert(t1.device.type == "cuda") + s1 = t1.storage() + assert(s1.device.type == "cuda") + + gen = torch.Generator("dipu") + gen.manual_seed(1) + assert gen.device.type == "cuda" + if __name__ == '__main__': for i in range(1, 2): empty1() testdevice() testDeviceProperties() + test_mem_get_info() testStream() test_record_stream() testevent() test_type() test_complex_type() + test_dipu_as_cuda_type() # need more 2 device to run # testDevice1() diff --git a/dipu/tests/test_ops/archived/test_storage.py b/dipu/tests/test_ops/archived/test_storage.py index 08e9930cf..1f7d3e59f 100644 --- a/dipu/tests/test_ops/archived/test_storage.py +++ b/dipu/tests/test_ops/archived/test_storage.py @@ -5,7 +5,8 @@ def test_stor1(): PATH1 = "./test_stor1.pth" - + stor_shared1 = torch.UntypedStorage._new_shared(3, device="cpu") + print(stor_shared1) device = "cuda:0" # args is int8, args = [[1, 0, 0, 0, 4, 0, 0, 0, 12, 0, 0, 0]] diff --git a/dipu/tests/test_stor1.pth b/dipu/tests/test_stor1.pth new file mode 100644 index 000000000..70a9ed0bd Binary files /dev/null and b/dipu/tests/test_stor1.pth differ diff --git a/dipu/torch_dipu/__init__.py b/dipu/torch_dipu/__init__.py index b819387c6..d0e6fb19a 100644 --- a/dipu/torch_dipu/__init__.py +++ b/dipu/torch_dipu/__init__.py @@ -82,13 +82,20 @@ def apply_torch_function_patch(): torch.randn = GetDeviceStaticProxy(torch.randn) torch.randn_like = GetDeviceStaticProxy(torch.randn_like) torch.randperm = GetDeviceStaticProxy(torch.randperm) + + # todo: try to automaitc check & mock funcs + torch.linspace = GetDeviceStaticProxy(torch.linspace) + if mockcuda: for attr in dipu.__all__: if hasattr(torch.cuda, attr): setattr(torch.cuda, attr, getattr(dipu, attr)) - - if attr in torch.cuda.random.__all__ and hasattr(torch.cuda.random, attr): - setattr(torch.cuda.random, attr, getattr(dipu, attr)) + if attr in torch.cuda.random.__all__ and hasattr(dipu.random_dipu, attr): + setattr(torch.cuda.random, attr, getattr(dipu.random_dipu, attr)) + if attr in torch.cuda.memory.__all__ and hasattr(dipu.memory, attr): + setattr(torch.cuda.memory, attr, getattr(dipu.memory, attr)) + # special case dipu ans cuda use different name + torch.cuda.device = dipu.devicectx # temp solution, need redesign storage diff --git a/dipu/torch_dipu/csrc_dipu/CMakeLists.txt b/dipu/torch_dipu/csrc_dipu/CMakeLists.txt index 677450e4c..f98d9d428 100644 --- a/dipu/torch_dipu/csrc_dipu/CMakeLists.txt +++ b/dipu/torch_dipu/csrc_dipu/CMakeLists.txt @@ -16,7 +16,6 @@ file(GLOB RT_SRC_FILES runtime/core/guardimpl/*.cpp runtime/core/allocator/*.cpp runtime/core/DIPU*.cpp - runtime/core/device.cpp runtime/core/MemChecker.cpp runtime/distributed/*.cpp runtime/devproxy/*.cpp @@ -79,6 +78,7 @@ add_dependencies(${DIPU_LIB} copy_include) # --------build bind in python -------------- file(GLOB BIND_SRC_FILES binding/Export*.cpp + binding/patch*.cpp ) set(BIND_FILES ${BIND_SRC_FILES} diff --git a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp index a17f3fd04..2fe8dc941 100644 --- a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp +++ b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp @@ -20,9 +20,10 @@ namespace dipu { static constexpr size_t kMega = 1024 * 1024; using dipu::devapis::DIPUDeviceProperties; +using dipu::devapis::DIPUDeviceStatus; static void registerDIPUDeviceProperties(py::module& m) { - py::class_(m, "_DIPUDeviceProperties") + py::class_>(m, "_DIPUDeviceProperties") .def_readonly("name", &DIPUDeviceProperties::name) .def_readonly("major", &DIPUDeviceProperties::major) .def_readonly("minor", &DIPUDeviceProperties::minor) @@ -39,9 +40,23 @@ static void registerDIPUDeviceProperties(py::module& m) { }); } +static void registerDIPUDeviceStatus(py::module& m) { + py::class_>(m, "_DIPUDeviceStatus") + .def_readonly("free_memory", &DIPUDeviceStatus::freeGlobalMem) + .def("__repr__", [](const DIPUDeviceStatus& status) { + std::ostringstream stream; + stream << "DIPUDeviceStatus(used_memory=" << status.freeGlobalMem + << ")"; + return stream.str(); + }); +} + static void exportDevices(py::module& m) { + registerDIPUDeviceProperties(m); + registerDIPUDeviceStatus(m); // Device Management. m.attr("dipu_vendor") = dipu::VendorTypeToStr(VENDOR_TYPE); + m.attr("dipu_device_type") = DeviceTypeName(DIPU_DEVICE_TYPE, true); m.attr("dicl_backend") = DICL_BACKEND_NAME; m.def("_dipu_set_device", [](int idx) -> void { @@ -58,9 +73,19 @@ static void exportDevices(py::module& m) { devproxy::syncDevice(); return; }); - m.def("_dipu_getDeviceProperties", [](int device) -> DIPUDeviceProperties* { - return dipu::device::getDevicePropertiesFromCache(device); - }, py::return_value_policy::reference); + m.def("_dipu_getDeviceProperties", [](int device) -> std::shared_ptr { + return dipu::getDevicePropertiesFromCache(device); + }, py::arg("device")); + + /* + different with device properties, fill_status may cause creation of the device stub on the specified device, + the sub will occupy mem, so caller should always fill status after set device() + and only fill status of current device, otherwise you will create stub an other device. + */ + m.def("_dipu_getDeviceStatus", [](int device) -> std::shared_ptr { + return dipu::getDeviceStatus(device); + }, py::arg("device")); + } static void exportStream(py::module& m) { @@ -275,9 +300,11 @@ static void exportGenerator(py::module& m) { }); } +extern void patchTorchCsrcDevice(PyObject* module); + DIPU_API void exportDIPURuntime(PyObject* module) { auto m = py::handle(module).cast(); - registerDIPUDeviceProperties(m); + patchTorchCsrcDevice(module); exportDevices(m); exportStream(m); exportEvent(m); diff --git a/dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp b/dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp new file mode 100644 index 000000000..055cbe20d --- /dev/null +++ b/dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp @@ -0,0 +1,120 @@ +// Copyright (c) 2023, DeepLink. + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include "exportapi.h" + +namespace dipu { + +static bool PythonDeviceAsCuda = false; + +static at::DeviceType _get_dipu_python_type(const at::Device& device) { + if (device.type() == DIPU_DEVICE_TYPE && PythonDeviceAsCuda) { + return at::DeviceType::CUDA; + } + return device.type(); +} + +PyObject* _THPDevice_type(THPDevice* self, PyObject* noargs) { + HANDLE_TH_ERRORS + std::ostringstream oss; + oss << _get_dipu_python_type(self->device); + return THPUtils_packString(oss.str().c_str()); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* _THPDevice_index(THPDevice* self, PyObject* noargs) { + HANDLE_TH_ERRORS + if (self->device.has_index()) { + return THPUtils_packInt64(self->device.index()); + } else { + Py_RETURN_NONE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* DIPU_THPDevice_repr(THPDevice* self) { + std::ostringstream oss; + oss << "device(type=\'" << _get_dipu_python_type(self->device) << "\'"; + if (self->device.has_index()) { + // `self->device.index()` returns uint8_t which is treated as ascii while + // printing, hence casting it to uint16_t. + // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout + oss << ", index=" << static_cast(self->device.index()); + } + oss << ")"; + return THPUtils_packString(oss.str().c_str()); +} + + +PyObject* DIPU_THPDevice_str(THPDevice* self) { + std::ostringstream oss; + oss << _get_dipu_python_type(self->device); + return THPUtils_packString(oss.str().c_str()); +} + +static struct PyGetSetDef DIPU_THPDevice_properties[] = { + {"type", (getter)_THPDevice_type, nullptr, nullptr, nullptr}, + {"index", (getter)_THPDevice_index, nullptr, nullptr, nullptr}, + {nullptr}}; + + +/* +why use this method to patch csrc.Device: because +1. csrc.Device is a final cpython class which not support attributes mock in python layer. +2. rewrite a new DeviceType to replace THPDeviceType is not work because torch::PythonArgParser + will check the type of THPDeviceType when parse Device parameter(see csrc/utils/python_arg_parer.cpp + FunctionParameter::check() -> THPDevice_Check()) +so we replace some attributes of THPDeviceType class in c-python layer +*/ +void patchTorchCsrcDevice(PyObject* module) { + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_dict + THPDeviceType.tp_dict = nullptr; + // change Type properties + THPDeviceType.tp_getset = DIPU_THPDevice_properties; + THPDeviceType.tp_repr = (reprfunc)DIPU_THPDevice_repr; + THPDeviceType.tp_str = (reprfunc)DIPU_THPDevice_str; + + // change THPDeviceType as an overriable class need add some other prperties in PyTypeObject, + // It may cause problems and seem un-necessary, so we keep the THPDeviceType as immutable. + THPDeviceType.tp_flags = Py_TPFLAGS_DEFAULT; // | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + + if (PyType_Ready(&THPDeviceType) < 0) { + throw python_error(); + } + Py_INCREF(&THPDeviceType); + + auto m = py::handle(module).cast(); + + m.def("_get_python_device_as_cuda", []() -> bool { + return PythonDeviceAsCuda; + }); + + m.def ("_set_python_device_as_cuda", [](bool as_cuda) -> void { + PythonDeviceAsCuda = as_cuda; + }); + + // not really 'export' new type but change original THPDeviceType is enough + // if (PyModule_AddObject(module, "device", (PyObject*)&THPDeviceType) != 0) { + // throw python_error(); + // } +} +} // namespace dipu \ No newline at end of file diff --git a/dipu/torch_dipu/csrc_dipu/runtime/core/device.cpp b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUDeviceInfo.cpp similarity index 55% rename from dipu/torch_dipu/csrc_dipu/runtime/core/device.cpp rename to dipu/torch_dipu/csrc_dipu/runtime/core/DIPUDeviceInfo.cpp index 15c9cd6c7..420c9c41e 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/core/device.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUDeviceInfo.cpp @@ -5,19 +5,20 @@ #include #include -#include "./device.h" - +#include "./DIPUDeviceInfo.h" namespace dipu { -namespace device { +// anonymous ns +namespace { +using std::shared_ptr; using dipu::devapis::DIPUDeviceProperties; using c10::DeviceIndex; DeviceIndex num_gpus = -1; c10::once_flag init_flag; std::deque device_flags; -std::vector device_properties; +std::vector> device_properties; static void initDIPUContextVectors() { num_gpus = dipu::devproxy::getDeviceCount(); @@ -27,19 +28,31 @@ static void initDIPUContextVectors() { static void initDeviceProperty(DeviceIndex device_index) { DIPUDeviceProperties device_prop = dipu::devproxy::getDeviceProperties(device_index); - device_properties[device_index] = device_prop; + device_properties[device_index] = std::make_shared(device_prop); } -DIPUDeviceProperties* getDevicePropertiesFromCache(int32_t device_index) { +static inline void checkDevice(int32_t device_index) { c10::call_once(init_flag, initDIPUContextVectors); if (device_index == -1) { device_index = dipu::devproxy::current_device(); } AT_ASSERT(device_index >= 0 && device_index < num_gpus); +} + +} // end anonymous +shared_ptr getDevicePropertiesFromCache(int32_t device_index) { + checkDevice(device_index); c10::call_once(device_flags[device_index], initDeviceProperty, device_index); - return &device_properties[device_index]; + return device_properties[device_index]; +} + +shared_ptr getDeviceStatus(int32_t device_index) { + checkDevice(device_index); + + // never cache status + DIPUDeviceStatus device_prop = dipu::devproxy::getDeviceStatus(device_index); + return std::make_shared(device_prop); } -} // namespace device } // namespace dipu \ No newline at end of file diff --git a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUDeviceInfo.h b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUDeviceInfo.h new file mode 100644 index 000000000..2e6142538 --- /dev/null +++ b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUDeviceInfo.h @@ -0,0 +1,11 @@ +// Copyright (c) 2023, DeepLink. +#include + +namespace dipu { +using dipu::devapis::DIPUDeviceProperties; +using dipu::devapis::DIPUDeviceStatus; + +DIPU_API std::shared_ptr getDevicePropertiesFromCache(int32_t device_index); +DIPU_API std::shared_ptr getDeviceStatus(int32_t device_index); + +} // namespace dipu \ No newline at end of file diff --git a/dipu/torch_dipu/csrc_dipu/runtime/core/device.h b/dipu/torch_dipu/csrc_dipu/runtime/core/device.h deleted file mode 100644 index 3ceb990b3..000000000 --- a/dipu/torch_dipu/csrc_dipu/runtime/core/device.h +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) 2023, DeepLink. -#include - -namespace dipu { - -namespace device { - -DIPU_API dipu::devapis::DIPUDeviceProperties* getDevicePropertiesFromCache(int32_t device_index); - -} // namespace device -} // namespace dipu \ No newline at end of file diff --git a/dipu/torch_dipu/csrc_dipu/runtime/device/basedef.h b/dipu/torch_dipu/csrc_dipu/runtime/device/basedef.h index d73dea87c..ee3146b5f 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/device/basedef.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/device/basedef.h @@ -77,6 +77,10 @@ typedef enum { } diclResult_t; +struct DIPUDeviceStatus { + size_t freeGlobalMem = 0; +}; + struct DIPUDeviceProperties { std::string name; size_t totalGlobalMem = 0; diff --git a/dipu/torch_dipu/csrc_dipu/runtime/device/deviceapis.h b/dipu/torch_dipu/csrc_dipu/runtime/device/deviceapis.h index 6bcb30589..5ae2e871f 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/device/deviceapis.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/device/deviceapis.h @@ -11,13 +11,14 @@ namespace dipu { extern devapis::VendorDeviceType VENDOR_TYPE; namespace devapis { -DIPU_API void initializeVendor(); +DIPU_WEAK void initializeVendor(); -DIPU_API void finalizeVendor(); +DIPU_WEAK void finalizeVendor(); DIPU_API deviceId_t current_device(); DIPU_API DIPUDeviceProperties getDeviceProperties(int32_t device_index); +DIPU_WEAK DIPUDeviceStatus getDeviceStatus(int32_t device_index); // set current device given device according to id DIPU_API void setDevice(deviceId_t devId); diff --git a/dipu/torch_dipu/csrc_dipu/runtime/devproxy/deviceproxy.cpp b/dipu/torch_dipu/csrc_dipu/runtime/devproxy/deviceproxy.cpp index ec819a5d6..9c9c29fd0 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/devproxy/deviceproxy.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/devproxy/deviceproxy.cpp @@ -3,26 +3,18 @@ #include "../core/DIPUEventPool.h" namespace dipu { -namespace devapis{ - -__attribute__((weak)) void initializeVendor() { - -} - -__attribute__((weak)) void finalizeVendor() { - -} - -} // namespace devapis - namespace devproxy { void initializeVendor() { - devapis::initializeVendor(); + if (devapis::initializeVendor) { + devapis::initializeVendor(); + } } void finalizeVendor() { - devapis::finalizeVendor(); + if (devapis::finalizeVendor) { + devapis::finalizeVendor(); + } } deviceId_t current_device() { @@ -33,6 +25,13 @@ DIPUDeviceProperties getDeviceProperties(int32_t device_index) { return devapis::getDeviceProperties(device_index); } +DIPUDeviceStatus getDeviceStatus(int32_t device_index) { + if (devapis::getDeviceStatus) { + return devapis::getDeviceStatus(device_index); + } + return DIPUDeviceStatus(); +} + // set current device given device according to id void setDevice(deviceId_t devId) { return devapis::setDevice(devId); diff --git a/dipu/torch_dipu/csrc_dipu/runtime/devproxy/deviceproxy.h b/dipu/torch_dipu/csrc_dipu/runtime/devproxy/deviceproxy.h index d0fb561e1..2a69300b9 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/devproxy/deviceproxy.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/devproxy/deviceproxy.h @@ -3,15 +3,16 @@ #include "../device/deviceapis.h" +namespace dipu { + +namespace devproxy { + using dipu::devapis::deviceId_t; using dipu::devapis::DIPUDeviceProperties; +using dipu::devapis::DIPUDeviceStatus; using dipu::devapis::EventStatus; using dipu::devapis::OpStatus; -namespace dipu { - -namespace devproxy { - DIPU_API void initializeVendor(); DIPU_API void finalizeVendor(); @@ -19,6 +20,7 @@ DIPU_API void finalizeVendor(); DIPU_API deviceId_t current_device(); DIPU_API DIPUDeviceProperties getDeviceProperties(int32_t device_index); +DIPU_API DIPUDeviceStatus getDeviceStatus(int32_t device_index); // set current device given device according to id DIPU_API void setDevice(deviceId_t devId); diff --git a/dipu/torch_dipu/csrc_dipu/runtime/rthelper.h b/dipu/torch_dipu/csrc_dipu/runtime/rthelper.h index 46659027a..c681ba039 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/rthelper.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/rthelper.h @@ -2,7 +2,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/dipu/torch_dipu/csrc_dipu/vendor/camb/cnrt_6.x/deviceimpl.cpp b/dipu/torch_dipu/csrc_dipu/vendor/camb/cnrt_6.x/deviceimpl.cpp index 5d45414ca..f216fb834 100644 --- a/dipu/torch_dipu/csrc_dipu/vendor/camb/cnrt_6.x/deviceimpl.cpp +++ b/dipu/torch_dipu/csrc_dipu/vendor/camb/cnrt_6.x/deviceimpl.cpp @@ -38,13 +38,20 @@ DIPUDeviceProperties getDeviceProperties(int32_t device_index) { DIPUDeviceProperties prop; prop.name = device_prop.name; - prop.totalGlobalMem = mem_info.physicalMemoryTotal << 20;; + prop.totalGlobalMem = mem_info.physicalMemoryTotal << 20; prop.major = major; prop.minor = minor; prop.multiProcessorCount = multi_processor_cnt; return prop; } +/* + both cndevMemoryInfo_t.physicalMemoryUsed from cndevGetMemoryUsage and cndevProcessInfo_t from cndevGetProcessInfo seems not correct, + value always zero, need further investigation. +DIPUDeviceStatus getDeviceStatus(int32_t device_index) { +} +*/ + // check last launch succ or not, throw if fail void checkLastError() { DIPU_CALLCNRT(::cnrtGetLastError()) diff --git a/dipu/torch_dipu/csrc_dipu/vendor/cuda/deviceimpl.cpp b/dipu/torch_dipu/csrc_dipu/vendor/cuda/deviceimpl.cpp index 8a89ffca9..8e72bc1bd 100644 --- a/dipu/torch_dipu/csrc_dipu/vendor/cuda/deviceimpl.cpp +++ b/dipu/torch_dipu/csrc_dipu/vendor/cuda/deviceimpl.cpp @@ -42,6 +42,12 @@ DIPUDeviceProperties getDeviceProperties(int32_t device_index) { return prop; } +DIPUDeviceStatus getDeviceStatus(int32_t device_index) { + DIPUDeviceStatus status; + cudaMemGetInfo(&status.freeGlobalMem, nullptr); + return status; +} + // in cuda_runtime_api.h // set current device given device according to id void setDevice(deviceId_t devId) { diff --git a/dipu/torch_dipu/dipu/__init__.py b/dipu/torch_dipu/dipu/__init__.py index 0acf58412..47f9455b1 100644 --- a/dipu/torch_dipu/dipu/__init__.py +++ b/dipu/torch_dipu/dipu/__init__.py @@ -2,6 +2,7 @@ from .utils import is_initialized from .device import __diputype__ as diputype from .device import __vendor__ as vendor_type +from .device import devicectx from .device import * from .random_dipu import * from .memory import * @@ -24,7 +25,7 @@ 'LongTensor', 'IntTensor', 'ShortTensor', 'ByteTensor', 'CharTensor', 'BoolTensor', # device - "can_device_access_peer", "current_device", "device", "device_count", "device_of", "synchronize", + "can_device_access_peer", "current_device", "devicectx", "device_count", "device_of", "synchronize", "get_device_name", "get_device_properties", "get_device_capability", "is_available", "set_device", "GetDeviceProxy", "GetDeviceStaticProxy", "diputype", "vendor_type", @@ -38,7 +39,7 @@ # # mem manage "reset_peak_memory_stats", "empty_cache", "memory_allocated", "memory_reserved", "max_memory_allocated", "max_memory_reserved", - # "caching_allocator_alloc", "caching_allocator_delete", "memory_summary", "memory_stats" + "mem_get_info", # "caching_allocator_alloc", "caching_allocator_delete", "memory_summary", "memory_stats" # not support mock cuda_graph now ] diff --git a/dipu/torch_dipu/dipu/device.py b/dipu/torch_dipu/dipu/device.py index c4d282810..8d85e1efd 100644 --- a/dipu/torch_dipu/dipu/device.py +++ b/dipu/torch_dipu/dipu/device.py @@ -6,50 +6,67 @@ from torch_dipu import mockcuda from torch_dipu import _C +import os __dipu__ = 'dipu' -__diputype__ = 'xpu' +__dipu_device_type__ = _C.dipu_device_type +__diputype__ = __dipu_device_type__ + +def init_dipu_device_type(forceUnset: bool = False): + global __diputype__ + _C._set_python_device_as_cuda(os.environ.get("DIPU_PYTHON_DEVICE_AS_CUDA", 'True').lower()=='true' and mockcuda and not forceUnset) + __diputype__ = "cuda" if _C._get_python_device_as_cuda() else __dipu_device_type__ + if __diputype__ == "cuda": + print("dipu device will show as cuda device. if it's not expected behavior, please set env DIPU_PYTHON_DEVICE_AS_CUDA=false") + torch._C._set_cudnn_enabled(False) + +init_dipu_device_type() + __vendor__ = _C.dipu_vendor # need update when compile _device_t = Union[torch.device, str, int, None] _C.init_resource() class _MetaDeviceType(type): - device_ = torch.device - def __instancecheck__(cls, instance): - if isinstance(instance, _MetaDeviceType.device_): - return True - return False + _torch_device = torch.device + def __instancecheck__(cls, inst): + if isinstance(inst, cls._torch_device): + return True + return False + -# torch.Device is a final class. cannot inherit # csrc/Device.cpp THPDevice_pynew: # "Device(Device device)" Device type can be Device, Long, String # "Device(c10::string_view type, int64_t? index=-1)" class _DIPUDevice(metaclass=_MetaDeviceType): @staticmethod - def __doreplace(arg): + def __replacedipu(arg): if (__dipu__ in arg): - arg = arg.replace(__dipu__, __diputype__) + arg = arg.replace(__dipu__, __dipu_device_type__) if (mockcuda and "cuda" in arg): - arg = arg.replace("cuda", __diputype__) + arg = arg.replace("cuda", __dipu_device_type__) return arg def __new__(cls, *args, **kwargs): if len(args) == 1 and isinstance(args[0], int) and mockcuda: # modify default int device type only when "mock cuda". - dev_name = __diputype__ + ":" + str(args[0]) - return _MetaDeviceType.device_(dev_name) + dev_name = __dipu_device_type__ + ":" + str(args[0]) + _device = _MetaDeviceType._torch_device(dev_name) + return _device # handle device as str if len(args) >= 1 and isinstance(args[0], str): argList = list(args) - argList[0] = cls.__doreplace(args[0]) + argList[0] = cls.__replacedipu(args[0]) args = tuple(argList) - # handle device in type key, not support int type but str and device + # handle parameter type: str, not support int type but str and device deviceValue = kwargs.get("type", None) if deviceValue != None and isinstance(deviceValue, str): - kwargs["type"] = cls.__doreplace(deviceValue) - return _MetaDeviceType.device_(*args, **kwargs) + kwargs["type"] = cls.__replacedipu(deviceValue) + _device = _MetaDeviceType._torch_device(*args, **kwargs) + return _device +# always patch torch.device = _DIPUDevice +# todo: use device_ctx & torch_function to reduce processing logic? # wrap device related func def GetDeviceProxy(rawfunc, pos = 0, name = "device", caller = "obj"): def _replaceDevice(args, kwargs): @@ -76,7 +93,7 @@ def _proxyFuncStatic(*args, **kwargs): # class __new__ always pass cls parameter to args def _proxyNewClass(cls, *args, **kwargs): args, kwargs = _replaceDevice(args, kwargs) - return rawfunc(*args, **kwargs) + return rawfunc(cls, *args, **kwargs) if caller == "static": return _proxyFuncStatic @@ -211,10 +228,13 @@ def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int] prop = get_device_properties(device) return prop.major, prop.minor - def get_device_properties(device: _device_t) -> _C._DIPUDeviceProperties: _lazy_init() device_id = _get_device_index(device, optional=True) - if device_id < 0 or device_id >= device_count(): - raise AssertionError("Invalid device id") return _C._dipu_getDeviceProperties(device_id) + + +def get_device_status(device: _device_t) -> _C._DIPUDeviceStatus: + _lazy_init() + device_id = _get_device_index(device, optional=True) + return _C._dipu_getDeviceStatus(device_id) diff --git a/dipu/torch_dipu/dipu/generator.py b/dipu/torch_dipu/dipu/generator.py index 99001e002..cf04594ed 100644 --- a/dipu/torch_dipu/dipu/generator.py +++ b/dipu/torch_dipu/dipu/generator.py @@ -7,52 +7,22 @@ class Generator: - generator: Optional[torch._C.Generator] = None - dipu_generator: Optional[torch._C.Generator] = None - device: _device - def __new__(self, device: Union[_device, str, None] = None) -> None: + def __new__(cls, device: Union[_device, str, None] = None) -> None: if device is None: - self.device = torch.device('cpu') - self.generator = torch._C.Generator(device) - return self.generator - - self.device = torch.device(device) - if self.device.type == 'cpu': - self.generator = torch._C.Generator(device) - return self.generator - elif self.device.type == diputype: - self.dipu_generator = _C._create_dipu_generator(_get_device_index(self.device.index, True)) - return self.dipu_generator + device = torch.device('cpu') + generator = torch._C.Generator(device) + return generator + + device = torch.device(device) + if device.type == 'cpu': + generator = torch._C.Generator(device) + return generator + elif device.type == diputype: + dipu_generator = _C._create_dipu_generator(_get_device_index(device.index, True)) + return dipu_generator else: - raise Exception(f"unsupport device type {self.device.type}") - - def get_state(self): - if self.generator is not None: - return self.generator.get_state() - return self.dipu_generator.get_state() - - def set_state(self, new_state): - if self.generator is not None: - self.generator.set_state(new_state) - return - self.dipu_generator.set_state(new_state) - - def manual_seed(self, seed): - if self.generator is not None: - return self.generator.manual_seed(seed) - return self.dipu_generator.manual_seed(seed) - - def seed(self): - if self.generator is not None: - return self.generator.seed() - return self.dipu_generator.seed() - - def initial_seed(self): - if self.generator is not None: - return self.generator.initial_seed() - return self.dipu_generator.initial_seed() - + raise Exception(f"unsupport device type {device.type}") def apply_generator_patch(): torch.Generator = Generator \ No newline at end of file diff --git a/dipu/torch_dipu/dipu/memory.py b/dipu/torch_dipu/dipu/memory.py index 62422cca3..340fb6f98 100644 --- a/dipu/torch_dipu/dipu/memory.py +++ b/dipu/torch_dipu/dipu/memory.py @@ -1,14 +1,15 @@ # Copyright (c) 2023, DeepLink. import collections - +from typing import Union, Tuple from torch_dipu import _C -from .device import current_device, _get_device_index, devicectx, __dipu__ +from .device import current_device, _get_device_index, devicectx, __dipu__, \ + get_device_properties, get_device_status from .utils import is_initialized from .streams import current_stream, Stream import torch - +from torch.types import Device def caching_allocator_alloc(size, device=None, stream=None): r"""Performs a memory allocation using the dipu memory allocator. @@ -111,6 +112,27 @@ def max_memory_allocated(device = None): device = torch.device(__dipu__ + ":" + str(device)) return _C.max_memory_allocated(device) + +def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]: + r"""Returns the global free and total DIPU memory occupied for a given + device + Args: + device (torch.device or int, optional): selected device. + .. note:: + See :ref:`cuda-memory-management` for more + details about GPU memory management. + """ + if device is None: + device = current_device() + total = get_device_properties(device).total_memory + free = get_device_status(device).free_memory + if free == 0: + estimated_buffer = 1800 * 1024 * 1024 + print("warnning!! seems _DIPUDeviceProperties not contain valid free mem size," + "we try to estimate free size which may be not an accurate value!") + free = total - memory_allocated(device) - estimated_buffer + return (free, total) + ## just an empty shell now def memory_stats(device=None): result = [] diff --git a/dipu/torch_dipu/dipu/storages.py b/dipu/torch_dipu/dipu/storages.py index c4ad87035..d7dfd3cff 100644 --- a/dipu/torch_dipu/dipu/storages.py +++ b/dipu/torch_dipu/dipu/storages.py @@ -35,7 +35,7 @@ def _dipu_deserialize(obj, location): else: return obj.dipu(device_idx) -# this obj is storage, so it's device.type is real diputype (xpu), not 'cuda' or 'dipu' +# this obj is storage, so it's device.type is real diputype (may be cuda or xpu depend on if set DIPU_PYTHON_DEVICE_AS_CUDA) def _dipu_tag(obj): if obj.device.type == __diputype__: return __diputype__ + str(obj.device.index) @@ -66,9 +66,12 @@ def _is_dipu_storage(self): return self.device.type == __diputype__ setattr(UntypedStorage, 'is_dipu', _is_dipu_storage) - UntypedStorage.dipu = _dipu_storage +if mockcuda: + UntypedStorage.is_cuda = UntypedStorage.is_dipu + UntypedStorage.cuda = UntypedStorage.dipu + _raw_storage_resize = UntypedStorage.resize_ def _resize(self, size: int): if self.device.type != __diputype__: @@ -78,23 +81,7 @@ def _resize(self, size: int): UntypedStorage.resize_ = _resize +_raw_untyped_storage_new = GetDeviceProxy(torch.UntypedStorage.__new__, pos = -1, caller="class_new") +UntypedStorage.__new__ = lambda cls, *args, **kwargs : \ + _raw_untyped_storage_new(cls, *args, **kwargs) -# should we mock cuda/dipu in storage api which is a low level api or should we ensure all -# upper layer cuda api are mocked so low level are always dipu storage? - -# _StorageBase.cuda/is_cuda = - -# class _MetaDIPUStor(type): -# _raw_storage = torch.UntypedStorage -# _do_stor = GetDeviceProxy(torch.UntypedStorage, pos = -1, caller="class_new") - -# def __instancecheck__(cls, inst): -# if isinstance(inst, _MetaDIPUStor._raw_storage): -# return True -# return False - -# class DIPUUntypedStorage(metaclass=_MetaDIPUStor): -# def __new__(cls, *args, **kwargs): -# return DIPUUntypedStorage._do_stor(*args, **kwargs) - -# torch.UntypedStorage = DIPUUntypedStorage