diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
index 357fac475..7fec5f4da 100755
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -2148,75 +2148,37 @@
     return out;
   interface: diopiNorm(ctx, out, self, p, dimDiopiSize);
 
-- schema: "to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"
+# wrap_diopi_cast_dtype has no corresponding aten op and is not registered; it is just a diopi func wrapper.
+# We use this trick to support calling multiple diopi ops from one aten op.
+- schema: "wrap_diopi_cast_dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)"
   register_op: False
   custom_code_at_the_beginning: |
     auto out = at::empty_like(self, self.options().dtype(dtype));
   interface: diopiCastDtype(ctx, out, self);
-  custom_code_before_return: |
-    if (memory_format.has_value()) {
-      auto out1 = at::empty_like(out, out.options(), memory_format.value());
-      at::copy(out1, out, non_blocking);
-      out = out1;
-    }
-    if (!non_blocking) {
-      dipu::getCurrentDIPUStream().synchronize();
-    }
 
-- schema: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
+# a diopi func wrapper.
+- schema: wrap_diopi_copy_inp(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
+  register_op: False
   no_device_check_args: [self, src]
-  device: [not_for_any_now] #todo
-  ins: [srcTemp]
+  interface: diopiCopyInp(ctx, src, self)
+
+# this copy_ aten op may use both diopiCastDtype and diopiCopyInp; it is a proxy/composite op
+- schema: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
+  dummy_call_diopi: True
   custom_fallback: True
+  device: [cuda, camb, ascend, droplet, supa]
   custom_code_at_the_beginning: |
-    dipu::DIPUGuard guard(self.is_cpu() ? src.device() : self.device());
-    auto stream = dipu::getCurrentDIPUStream();
-    auto srcTemp = self.dtype() == src.dtype() ? src : src.to(self.dtype());
-    srcTemp = (srcTemp.numel() == self.numel()) ? srcTemp : srcTemp.expand(self.sizes());
-    if (non_blocking) {
-      const bool is_default_stream = dipu::getDefaultDIPUStream() == stream;
-      if (self.is_cpu()) {
-        if (self.options().pinned_memory()) {
-          self.record_stream(stream);
-        }
-      } else if (!is_default_stream){
-        self.record_stream(stream);
-      }
-      if (srcTemp.is_cpu()) {
-        if (srcTemp.options().pinned_memory()) {
-          srcTemp.record_stream(stream);
-        }
-      } else if (!is_default_stream) {
-        srcTemp.record_stream(stream);
-      }
-    }
-    if (self.device().type() != srcTemp.device().type()) {
-      srcTemp = srcTemp.is_contiguous(self.suggest_memory_format()) ? srcTemp : srcTemp.contiguous(self.suggest_memory_format());
-      if (srcTemp.is_cpu() && (!self.is_cpu())) {
-        // c2d
-        dipu::devproxy::memCopyH2DAsync(stream.rawstream(), self.nbytes(), self.data_ptr(), srcTemp.data_ptr());
-      } else if ((!srcTemp.is_cpu()) && self.is_cpu()) {
-        // d2c
-        dipu::devproxy::memCopyD2HAsync(stream.rawstream(), self.nbytes(), self.data_ptr(), srcTemp.data_ptr());
-      }
-      if (!non_blocking) {
-        dipu::getCurrentDIPUStream().synchronize();
-      }
-
-      return self;
-    }
+    dipu::getDipuCopyInstance()->run(self, src, non_blocking);
+    return self;
+    // TODO: add a [composite] attr? the generated code after this early return is unused.
   interface: diopiCopyInp(ctx, srcTemp, self)
-  custom_code_before_return: |
-    if (!non_blocking) {
-      dipu::getCurrentDIPUStream().synchronize();
-    }
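Note: a sketch of the composition the two wrapper schemas above enable. The two wrapper functions are the ones the autogen emits for these schemas (they are declared in DIPUCopy.hpp later in this diff); composite_copy_sketch itself is a hypothetical name, not part of the change.

// Illustrative only: one aten-level copy chaining two diopi ops.
at::Tensor& composite_copy_sketch(at::Tensor& dst, const at::Tensor& src,
                                  bool non_blocking) {
  at::Tensor tmp = src;
  if (dst.scalar_type() != src.scalar_type()) {
    // first diopi op: cast src to dst's dtype (diopiCastDtype)
    tmp = dipu::native::dipu_wrap_diopi_cast_dtype(src, dst.scalar_type());
  }
  // second diopi op: same-dtype inplace copy (diopiCopyInp)
  return dipu::native::dipu_wrap_diopi_copy_inp(dst, tmp, non_blocking);
}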
+# for vendors that have no fully implemented diopi copy nor a proper fallback DIPUCopy subclass
 - schema: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
-  no_device_check_args: [self, src]
-  custom_fallback: True
   dummy_call_diopi: True
   custom_code_at_the_beginning: |
     return custom_fallback_dipu_copy_(self, src, non_blocking);
+  device: [topsrider]
   interface: diopiCopyInp(ctx, src, self)
 
 - schema: _amp_foreach_non_finite_check_and_unscale_(at::TensorList self, Tensor(b!) found_inf, Tensor inv_scale) -> void
diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py b/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py
index 3a6cf76f6..7eda79b15 100644
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py
@@ -15,6 +15,7 @@
 #include "csrc_dipu/profiler/profiler.h"
 #include
 #include "CustomFallbackFunctions.hpp"
+#include "csrc_dipu/aten/ops/DIPUCopy.hpp"
 
 $header_include_code
diff --git a/dipu/tests/python/unittests/test_copy.py b/dipu/tests/python/unittests/test_copy.py
index 1d77897f4..472dfa314 100644
--- a/dipu/tests/python/unittests/test_copy.py
+++ b/dipu/tests/python/unittests/test_copy.py
@@ -54,6 +54,14 @@ def test_copy_(self):
         # print(f"dst_dipu = {dst_dipu}")
         # assert False, "copy_ test fail"
         self.assertEqual(dst_cpu, dst_dipu.cpu())
+
+    def test_hollow_device_copy_(self):
+        device = "cuda"
+        t1 = torch.rand((6, 4), device=device)
+        dst1 = t1.as_strided((2, 2), (4, 1))
+        src = torch.rand((2, 2), device=device)
+        dst1.copy_(src)
+        self.assertEqual(dst1.cpu(), src.cpu())
 
 if __name__ == "__main__":
diff --git a/dipu/torch_dipu/csrc_dipu/CMakeLists.txt b/dipu/torch_dipu/csrc_dipu/CMakeLists.txt
index a88c94bea..ae1d1e97b 100644
--- a/dipu/torch_dipu/csrc_dipu/CMakeLists.txt
+++ b/dipu/torch_dipu/csrc_dipu/CMakeLists.txt
@@ -20,6 +20,8 @@
 set(DIPU_AUTOGEN_DIOPI_WRAPPER_SCRIPT
     "${DIPU_AUTOGEN_DIOPI_WRAPPER_SOURCE_DIR}/autogen_diopi_wrapper.py")
 set(DIPU_AUTOGEN_DIOPI_WRAPPER_CONFIG
     "${DIPU_AUTOGEN_DIOPI_WRAPPER_SOURCE_DIR}/diopi_functions.yaml")
+set(DIPU_AUTOGEN_DIOPI_WRAPPER_TEMPLATE
+    "${DIPU_AUTOGEN_DIOPI_WRAPPER_SOURCE_DIR}/diopi_wrapper_template.py")
 set(DIPU_AUTOGENED_KERNELS_CPP
     "${CMAKE_CURRENT_SOURCE_DIR}/aten/ops/AutoGenedKernels.cpp")
 add_custom_command(
@@ -31,7 +33,8 @@ add_custom_command(
     --print_op_arg True
     --fun_config_dict '{\"current_device\": \"${UsedVendor}\"}'
   DEPENDS ${DIPU_AUTOGEN_DIOPI_WRAPPER_SCRIPT}
-          ${DIPU_AUTOGEN_DIOPI_WRAPPER_CONFIG})
+          ${DIPU_AUTOGEN_DIOPI_WRAPPER_CONFIG}
+          ${DIPU_AUTOGEN_DIOPI_WRAPPER_TEMPLATE})
 add_custom_target(autogen_diopi_kernels_cpp DEPENDS ${DIPU_AUTOGENED_KERNELS_CPP})
 add_dependencies(${DIPU_AUTOGENED_KERNELS} autogen_diopi_kernels_cpp)
diff --git a/dipu/torch_dipu/csrc_dipu/aten/DIPUATenFunctions.h b/dipu/torch_dipu/csrc_dipu/aten/DIPUATenFunctions.h
index 1e4e28d4e..0398ea5f7 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/DIPUATenFunctions.h
+++ b/dipu/torch_dipu/csrc_dipu/aten/DIPUATenFunctions.h
@@ -33,9 +33,6 @@ struct DIPUATenFunctions {
       c10::optional device_opt,
       c10::optional pin_memory_opt);
 
-  static at::Tensor& copy_(at::Tensor& self, const at::Tensor& src,
-                           bool non_blocking);
-
   static const at::Tensor& resize_(
       const at::Tensor& self, at::IntArrayRef size,
       c10::optional memory_format);
diff --git a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp
index 70dc839bd..6898e83b0 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp
@@ -13,7 +13,6 @@
 #include
 #include
 #include
-#include
 
 using dnative = dipu::native::DIPUATenFunctions;
@@ -228,7 +227,7 @@ at::Tensor wrapper_DIPU___copy_from_and_resize(const at::Tensor& self,
 const at::Tensor& wrapper_resize_(
     const at::Tensor& self, at::IntArrayRef size,
     c10::optional memory_format) {
-  // add guard for device switch.
+  // DeviceGuard omitted because resize_ has a guard within itself.
   dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
   return dnative::resize_(self, size, memory_format);
 }
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/CopyKernel.cpp b/dipu/torch_dipu/csrc_dipu/aten/ops/CopyKernel.cpp
deleted file mode 100644
index 5731706cd..000000000
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/CopyKernel.cpp
+++ /dev/null
@@ -1,203 +0,0 @@
-// Copyright (c) 2023, DeepLink.
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-#include
-#include
-#include
-
-using at::Layout;
-using c10::device_or_default;
-using c10::IntArrayRef;
-using c10::layout_or_default;
-using c10::StorageImpl;
-using c10::TensorImpl;
-using dipu::devapis::deviceId_t;
-
-namespace dipu::native {
-
-// need abstract cast strategy before copy, some device(eg camb) not support all
-// types,
-inline at::Tensor cast2CompatibleDeviceTensor(const at::Tensor& hostTensor) {
-  return hostTensor;
-}
-inline int64_t getCopyBytes(const at::Tensor& dst, const at::Tensor& src) {
-  if (dst.nbytes() !=
-      src.nbytes()) {  // outer bytes must same. different type is unsuported
-    TORCH_CHECK(false, "dipu copy with different size is not allowed");
-  }
-  int64_t dstBytes = dst.unsafeGetTensorImpl()->unsafe_storage().nbytes();
-  int64_t srcBytes = src.unsafeGetTensorImpl()->unsafe_storage().nbytes();
-  // a view one + a real stor one is supported
-  return srcBytes < dstBytes ? srcBytes : dstBytes;
-}
-
-static void copy_H2D(const at::Tensor& dst, const at::Tensor& src,
-                     bool non_blocking) {
-  int64_t nbytes = getCopyBytes(dst, src);
-  dipu::DIPUStream stream = dipu::getCurrentDIPUStream();
-
-  auto src_cast = cast2CompatibleDeviceTensor(src);
-  void* src_ptr = src_cast.data_ptr();
-  void* dst_ptr = dst.data_ptr();
-
-  MemChecker::instance().check(dst);
-  dipu::devproxy::memCopyH2DAsync(stream.rawstream(), nbytes, dst_ptr, src_ptr);
-  if (!non_blocking) {
-    dipu::devproxy::syncStream(stream.rawstream());
-  }
-}
-
-static void copy_D2H(const at::Tensor& dst, const at::Tensor& src,
-                     bool non_blocking) {
-  int64_t nbytes = getCopyBytes(dst, src);
-  dipu::DIPUStream stream = dipu::getCurrentDIPUStream();
-
-  void* src_ptr = src.data_ptr();
-  void* dst_ptr = dst.data_ptr();
-
-  MemChecker::instance().check(src);
-  dipu::devproxy::memCopyD2HAsync(stream.rawstream(), nbytes, dst_ptr, src_ptr);
-  if (!non_blocking) {
-    dipu::devproxy::syncStream(stream.rawstream());
-  }
-}
-
-inline bool isDiffStrides(const IntArrayRef stride1,
-                          const IntArrayRef stride2) {
-  if (stride1.size() != stride2.size()) {
-    return true;
-  }
-  for (auto i = 0; i < stride1.size(); i++) {
-    if (stride1[i] != stride2[i]) {
-      return true;
-    }
-  }
-  return false;
-}
-
-// 1. expand, 2. patial view. 3. type cast.
-inline bool canDirectCopy(const at::Tensor& dst, const at::Tensor& src) {
-  // assume layout always = not suppport Sparse layout
-  TORCH_CHECK(dst.options().layout() == c10::Layout::Strided,
-              "only Strided layout is supported");
-
-  int64_t srcBytes = src.unsafeGetTensorImpl()->unsafe_storage().nbytes();
-  int64_t dstBytes = dst.unsafeGetTensorImpl()->unsafe_storage().nbytes();
-  if (srcBytes != dstBytes || dst.numel() != src.numel() ||
-      dst.options().dtype() != src.options().dtype()) {
-    return false;
-  }
-  if (isDiffStrides(dst.strides(), src.strides())) {
-    return false;
-  }
-  // view(with no-zero offset) direct copy may cause err(not sure how long real
-  // stor data should be copyed) not supported
-  if (dst.storage_offset() != 0 || src.storage_offset() != 0) {
-    return false;
-  }
-  // even tensors have zero offset and same stride/type cannot do simple safe
-  // direct copy because we cannot simply decide how much data will be copyed
-  // from raw stor (unless check stride). so we always return false now. need
-  // enhance in future, because always copy with the help of cpu is toooo0 slow.
-  // **** check if copy safely using tensor.nbytes() when is_contiguous() =
-  // true.
-  return false;
-}
-
-static void copy_D2D(const at::Tensor& dst, const at::Tensor& src,
-                     bool non_blocking) {
-  int64_t nbytes = getCopyBytes(dst, src);
-  dipu::DIPUStream stream = dipu::getCurrentDIPUStream();
-
-  void* src_ptr = src.data_ptr();
-  void* dst_ptr = dst.data_ptr();
-
-  MemChecker::instance().check(src);
-  MemChecker::instance().check(dst);
-  dipu::devproxy::memCopyD2DAsync(stream.rawstream(), nbytes,
-                                  dst.device().index(), dst_ptr,
-                                  src.device().index(), src_ptr);
-  if (!non_blocking) {
-    dipu::devproxy::syncStream(stream.rawstream());
-  }
-}
-
-inline void doRealCp(at::Tensor& self, const at::Tensor& src,
-                     bool non_blocking) {
-  if (dipu::isDeviceTensor(self) && !dipu::isDeviceTensor(src)) {
-    // src is cpu.
-    copy_H2D(self, src, non_blocking);
-  } else if (!dipu::isDeviceTensor(self) && dipu::isDeviceTensor(src)) {
-    // self is cpu.
-    copy_D2H(self, src, non_blocking);
-  } else {  // device to device
-    copy_D2D(self, src, non_blocking);
-  }
-}
-
-// self is dest
-// not handle storage offset, need?
-at::Tensor& DIPUATenFunctions::copy_(at::Tensor& self, const at::Tensor& src,
-                                     bool non_blocking) {
-  if (self.numel() == 0) {
-    return self;
-  }
-  // save tensor dim name
-  c10::optional names = src.opt_names();
-  if (names.has_value()) {
-    internal_set_names_inplace(self, names);
-  }
-  if (!canDirectCopy(self, src)) {
-    at::Tensor src_cpu = src;
-    // src to cpu
-    if (dipu::isDeviceTensor(src)) {
-      src_cpu = at::empty_strided(src.sizes(), src.strides(),
-                                  src.options().device(c10::DeviceType::CPU));
-      // src storage size may bigger than src_cpu's if src is a partial view.
-      // but not smaller. because src_cpu use same stride as src.
-      // src -> src_cpu
-      doRealCp(src_cpu, src, non_blocking);
-    }
-
-    if (dipu::isDeviceTensor(self)) {
-      at::Tensor dst_cpu =
-          at::empty_strided(self.sizes(), self.strides(),
-                            self.options().device(c10::DeviceType::CPU));
-      doRealCp(dst_cpu, self, non_blocking);
-      // proxy to cpu to handle different type/view problem
-      dst_cpu.copy_(src_cpu);
-
-      doRealCp(self, dst_cpu, non_blocking);
-    } else {  // self is cpu
-      self.copy_(src_cpu);
-    }
-  } else {
-    doRealCp(self, src, non_blocking);
-  }
-  return self;
-}
-
-at::Scalar DIPUATenFunctions::_local_scalar_dense_dipu(const at::Tensor& self) {
-  at::Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND2(
-      at::kHalf, at::kBool, self.scalar_type(), "_local_scalar_dense_dipu",
-      [&] {
-        scalar_t value;
-        dipu::DIPUStream stream = dipu::getCurrentDIPUStream();
-        MemChecker::instance().check(self);
-        dipu::devproxy::memCopyD2HAsync(stream.rawstream(), sizeof(scalar_t),
-                                        &value, self.data_ptr());
-        dipu::devproxy::syncStream(stream.rawstream());
-        r = at::Scalar(value);
-      });
-  return r;
-}
-}  // namespace dipu::native
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctions.hpp b/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctions.hpp
index f38d2b140..7de896f58 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctions.hpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctions.hpp
@@ -1,7 +1,7 @@
+// Copyright (c) 2023, DeepLink.
 #pragma once
 
 #include "csrc_dipu/aten/RegisterDIPU.hpp"
-#include
 
 #include "OpUtils.hpp"
@@ -80,7 +80,7 @@ static at::Tensor& custom_fallback_dipu__index_put_impl_(
   return self;
 }
 
-static ::std::tuple
+static ::std::tuple
 custom_fallback_dipu_native_batch_norm_out(
     const at::Tensor& input, const c10::optional& weight_opt,
     const c10::optional& bias_opt,
@@ -115,7 +115,7 @@ custom_fallback_dipu_native_batch_norm_out(
     running_var_opt.value().copy_(running_var_cpu.value());
   }
 
-  return std::tie(out, save_mean, save_invstd);
+  return {out, save_mean, save_invstd};
 }
 
 static at::Tensor custom_fallback_dipu_convolution_overrideable(
@@ -313,40 +313,11 @@ custom_fallback_dipu_native_batch_norm_backward(
   grad_weight.copy_(std::get<1>(at_out));
   grad_bias.copy_(std::get<2>(at_out));
 
-  return std::tie(grad_input, grad_weight, grad_bias);
+  return {grad_input, grad_weight, grad_bias};
 }
 
-static at::Tensor& custom_fallback_dipu_copy_(at::Tensor& self,
-                                              const at::Tensor& src,
-                                              bool non_blocking) {
-  DIPU_OP_LOG_WARNING_ONCE("custom fallback to cpu, name=copy_" << std::endl);
-  dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
-  static bool use_slow_copy = (std::getenv("DIPU_USE_SLOW_COPY") != nullptr);
-  dipu::DIPUGuard guard(self.is_cpu() ? src.device() : self.device());
-  if (non_blocking) {
-    auto stream = dipu::getCurrentDIPUStream();
-    const bool is_default_stream = dipu::getDefaultDIPUStream() == stream;
-    if (self.is_cpu()) {
-      if (self.options().pinned_memory()) {
-        self.record_stream(stream);
-      }
-    } else if (!is_default_stream) {
-      self.record_stream(stream);
-    }
-    if (src.is_cpu()) {
-      if (src.options().pinned_memory()) {
-        src.record_stream(stream);
-      }
-    } else if (!is_default_stream) {
-      src.record_stream(stream);
-    }
-  }
-  if (use_slow_copy) {
-    return dipu::native::DIPUATenFunctions::copy_(self, src, non_blocking);
-  } else {
-    return dipu::getDipuCopyInplace()->run(self, src, non_blocking);
-  }
-}
+at::Tensor& custom_fallback_dipu_copy_(at::Tensor& self, const at::Tensor& src,
+                                       bool non_blocking);
 
 void custom_fallback_dipu__amp_foreach_non_finite_check_and_unscale_(
     at::TensorList scaled_grads, at::Tensor& found_inf,
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctionsForCopy.cpp b/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctionsForCopy.cpp
new file mode 100644
index 000000000..fdcf9c969
--- /dev/null
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctionsForCopy.cpp
@@ -0,0 +1,22 @@
+// Copyright (c) 2023, DeepLink.
+#include
+
+#include "csrc_dipu/aten/RegisterDIPU.hpp"
+#include "csrc_dipu/aten/ops/DIPUCopy.hpp"
+
+namespace dipu {
+namespace native {
+
+at::Tensor& custom_fallback_dipu_copy_(at::Tensor& self, const at::Tensor& src,
+                                       bool non_blocking) {
+  DIPU_OP_LOG_WARNING_ONCE("custom fallback to dipu copy, name=copy_"
+                           << std::endl);
+  static DIPUCopyInpOnCPU onCpuCopy;
+
+  dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
+  onCpuCopy.run(self, src, non_blocking);
+  return self;
+}
+
+}  // namespace native
+}  // namespace dipu
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp
new file mode 100644
index 000000000..6efeea9a9
--- /dev/null
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp
@@ -0,0 +1,45 @@
+// Copyright (c) 2023, DeepLink.
+
+#include "DIPUCopy.hpp"
+
+#include
+
+#include
+
+#include
+
+namespace dipu {
+
+static DIPUCopyInpOnDIOPI default_copy_inplace_op;
+
+auto& dipu_copy_op() {
+  static DIPUCopyBase* dipu_copy_op_ = &default_copy_inplace_op;
+  return dipu_copy_op_;
+}
+
+DIPUCopyBase* getDipuCopyInstance() {
+  TORCH_CHECK(dipu_copy_op(), "dipu copy inplace not registered");
+  return dipu_copy_op();
+}
+
+void setDipuCopyInstance(DIPUCopyBase* op) { dipu_copy_op() = op; }
+
+}  // namespace dipu
+
+namespace dipu::native {
+at::Scalar DIPUATenFunctions::_local_scalar_dense_dipu(const at::Tensor& self) {
+  at::Scalar r;
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::kHalf, at::kBool, self.scalar_type(), "_local_scalar_dense_dipu",
+      [&] {
+        scalar_t value;
+        dipu::DIPUStream stream = dipu::getCurrentDIPUStream();
+        MemChecker::instance().check(self);
+        dipu::devproxy::memCopyD2HAsync(stream.rawstream(), sizeof(scalar_t),
+                                        &value, self.data_ptr());
+        dipu::devproxy::syncStream(stream.rawstream());
+        r = at::Scalar(value);
+      });
+  return r;
+}
+}  // namespace dipu::native
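Note: the getDipuCopyInstance/setDipuCopyInstance pair above is the hook the vendor files later in this diff (camb, cuda, supa) use. A sketch of that registration pattern, with a hypothetical vendor class:

// Hypothetical vendor registration via a file-scope static initializer,
// mirroring the camb/cuda/supa files below.
class MyVendorCopyInplace : public dipu::DIPUCopyInpOnDIOPI {};

static MyVendorCopyInplace my_vendor_copy_instance;
static int32_t my_vendor_copy_init = []() {
  dipu::setDipuCopyInstance(&my_vendor_copy_instance);
  return 1;
}();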
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp
new file mode 100644
index 000000000..285ea2171
--- /dev/null
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp
@@ -0,0 +1,516 @@
+// Copyright (c) 2023, DeepLink.
+#pragma once
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace dipu {
+namespace native {
+// NOTICE: these 2 funcs are defined in AutoGenedKernels.cpp.
+// if dipu autogen supports header file generation, remove this.
+at::Tensor dipu_wrap_diopi_cast_dtype(const at::Tensor& src,
+                                      at::ScalarType dtype);
+
+// if dipu autogen supports proxying one torch op to multiple diopi ops, remove
+// this.
+at::Tensor& dipu_wrap_diopi_copy_inp(at::Tensor& dst, const at::Tensor& src,
+                                     bool non_blocking);
+
+}  // namespace native
+
+enum class DIPUCopyType {
+  // src and dst tensors are on the same device
+  D2Self,
+  // from one device to another device. Not fully tested
+  D2OtherD,
+  // from device to host
+  D2H,
+  // from host to device
+  H2D,
+};
+
+// Align with pytorch's behavior, see TensorIterator.cpp compute_mem_overlaps()
+inline void checkOverlap(const at::Tensor& dst, const at::Tensor& src) {
+  assert_no_internal_overlap(dst);
+  assert_no_partial_overlap(dst, src);
+}
+
+inline void tryRecordStream(const at::Tensor& tensor, DIPUStream& curStream,
+                            bool is_default_stream) {
+  if (tensor.is_cpu()) {
+    // only pinned host memory participates in stream bookkeeping.
+    if (tensor.options().pinned_memory()) {
+      tensor.record_stream(curStream);
+    }
+  } else if (!is_default_stream) {
+    tensor.record_stream(curStream);
+  }
+}
+
+inline DIPUCopyType getCopyType(const at::Tensor& dst, const at::Tensor& src) {
+  bool isSrcDevice = dipu::isDeviceTensor(src);
+  bool isDstDevice = dipu::isDeviceTensor(dst);
+  if (!isSrcDevice) {
+    return DIPUCopyType::H2D;  // this op does not handle h2h, dst is always on device
+  } else if (!isDstDevice) {
+    return DIPUCopyType::D2H;  // here src is always on device
+  } else if (src.device().index() != dst.device().index()) {
+    return DIPUCopyType::D2OtherD;
+  }
+  return DIPUCopyType::D2Self;
+}
+
+inline int64_t getMemCopyBytes(const at::Tensor& dst, const at::Tensor& src,
+                               bool nonOverlappingAndDense) {
+  // outer byte counts must match; copying between different dtypes is unsupported here
+  if (dst.nbytes() != src.nbytes()) {
+    TORCH_CHECK(false, "mem copy with different tensor size is not allowed");
+  }
+  if (nonOverlappingAndDense) {
+    return dst.nbytes();
+  }
+  int64_t dstBytes = dst.unsafeGetTensorImpl()->unsafe_storage().nbytes();
+  int64_t srcBytes = src.unsafeGetTensorImpl()->unsafe_storage().nbytes();
+  return std::min(srcBytes, dstBytes);
+}
+
+inline void memCopyH2D(const at::Tensor& dst, const at::Tensor& src,
+                       dipu::DIPUStream& stream, int64_t nbytes) {
+  void* src_ptr = src.data_ptr();
+  void* dst_ptr = dst.data_ptr();
+
+  MemChecker::instance().check(dst);
+  dipu::devproxy::memCopyH2DAsync(stream.rawstream(), nbytes, dst_ptr, src_ptr);
+}
+
+inline void memCopyD2H(const at::Tensor& dst, const at::Tensor& src,
+                       dipu::DIPUStream& stream, int64_t nbytes) {
+  void* src_ptr = src.data_ptr();
+  void* dst_ptr = dst.data_ptr();
+
+  MemChecker::instance().check(src);
+  dipu::devproxy::memCopyD2HAsync(stream.rawstream(), nbytes, dst_ptr, src_ptr);
+}
+
+inline void memCopyD2D(const at::Tensor& dst, const at::Tensor& src,
+                       dipu::DIPUStream& stream, int64_t nbytes) {
+  void* src_ptr = src.data_ptr();
+  void* dst_ptr = dst.data_ptr();
+
+  MemChecker::instance().check(src);
+  MemChecker::instance().check(dst);
+  dipu::devproxy::memCopyD2DAsync(stream.rawstream(), nbytes,
+                                  dst.device().index(), dst_ptr,
+                                  src.device().index(), src_ptr);
+}
+
+inline void memCopy(const at::Tensor& dst, const at::Tensor& src,
+                    dipu::DIPUStream& stream, DIPUCopyType copyType,
+                    bool needMemCpSync, bool nonOverlappingAndDense) {
+  int64_t nbytes = getMemCopyBytes(dst, src, nonOverlappingAndDense);
+  switch (copyType) {
+    case DIPUCopyType::H2D:
+      // src is cpu.
+      memCopyH2D(dst, src, stream, nbytes);
+      break;
+    case DIPUCopyType::D2H:
+      // dst is cpu.
+      memCopyD2H(dst, src, stream, nbytes);
+      break;
+    default:  // device to device
+      memCopyD2D(dst, src, stream, nbytes);
+  }
+  // this sync differs from copy_'s non_blocking semantics: it is used inside
+  // one copy op when an intermediate cpu copy follows some stream op, to
+  // guarantee the cpu copy reads correct data.
+  if (needMemCpSync) {
+    dipu::devproxy::syncStream(stream.rawstream());
+  }
+}
+
+class CopyParamsInfo {
+ public:
+  DIPUCopyType copyType_;
+  DIPUStream curStream_;
+  // basic info
+  // whether a dtype cast is needed
+  bool sameDtype_ = false;
+  // determines whether an expand is needed.
+  bool sameSize_ = false;
+  bool sameStride_ = false;
+  bool denseAndNoOverlap_ = false;
+
+  // composite info: true when a direct mem copy is possible
+  bool directMemCopy_ = false;
+
+  void recomputeTensorsInfo(const at::Tensor& dst, const at::Tensor& src) {
+    sameDtype_ = dst.scalar_type() == src.scalar_type();
+    sameSize_ = dst.sizes().equals(src.sizes());
+    sameStride_ = dst.strides().equals(src.strides());
+    denseAndNoOverlap_ = dst.is_non_overlapping_and_dense() &&
+                         src.is_non_overlapping_and_dense();
+    directMemCopy_ =
+        sameDtype_ && sameSize_ && sameStride_ && denseAndNoOverlap_;
+  }
+
+  explicit CopyParamsInfo(const at::Tensor& dst, const at::Tensor& src,
+                          const DIPUStream& curStream) {
+    // assume layout is always Strided; Sparse layout is not supported
+    TORCH_CHECK(dst.options().layout() == c10::Layout::Strided,
+                "only Strided layout is supported");
+    copyType_ = getCopyType(dst, src);
+    curStream_ = curStream;
+
+    recomputeTensorsInfo(dst, src);
+  }
+
+  void updateCopyType(DIPUCopyType copyType) { copyType_ = copyType; }
+};
+
+class DIPUCopyBase {
+ public:
+  DIPUCopyBase() = default;
+  virtual ~DIPUCopyBase() = default;
+  // may throw (any exception type)
+  virtual void run(at::Tensor& dst, const at::Tensor& src,
+                   bool non_blocking) = 0;
+};
+
+/*
+NOTICE: if the input and output tensors occupy the same storage size and have
+the same mem-format, DIPUCopyInplace directly calls mem_copy; otherwise it
+calls copyNodirectXX.
+
+DiopiCast: means calling a separate diopiCast func. It is a forward-compatible
+solution because some vendors' DiopiCopy does not support casting; the new
+DiopiCopy api requires cast support.
+*/
+template <bool DiopiCopy, bool DiopiCast>
+class DIPUCopyInplace : public DIPUCopyBase {
+ public:
+  DIPUCopyInplace() = default;
+  void run(at::Tensor& dst, const at::Tensor& src, bool non_blocking) override {
+    TORCH_CHECK(dst.defined(), "dst is undefined");
+    TORCH_CHECK(src.defined(), "src is undefined");
+    if (dst.numel() == 0 || dst.is_same(src)) {
+      return;
+    }
+    const c10::DeviceGuard guard(dst.is_cpu() ? src.device() : dst.device());
+    auto curStream = dipu::getCurrentDIPUStream();
+
+    auto info = CopyParamsInfo(dst, src, curStream);
+    // Exit early if dst and src are views of the same data
+    if ((dst.is_alias_of(src) && dst.storage_offset() == src.storage_offset() &&
+         info.sameStride_ && info.sameDtype_)) {
+      return;
+    }
+    checkOverlap(dst, src);
+    if (native::dumpOpArgLevel() > 1) {
+      std::cout << "  DIPUCopyInplace.run: dst:" << native::dumpArg(dst)
+                << std::endl;
+      std::cout << "  DIPUCopyInplace.run: src:" << native::dumpArg(src)
+                << std::endl;
+    }
+
+    // record streams before the copy
+    if (non_blocking) {
+      const bool is_default_stream = dipu::getDefaultDIPUStream() == curStream;
+      tryRecordStream(dst, curStream, is_default_stream);
+      tryRecordStream(src, curStream, is_default_stream);
+    }
+
+    copyAll(dst, src, non_blocking, info);
+    // sync after the copy when blocking
+    if (!non_blocking) {
+      dipu::devapis::syncStream(curStream.rawstream());
+    }
+  }
+
+ protected:
+  /*
+  NOTICE: the memory area of the dst tensor (including hollow areas that do not
+  actually belong to the tensor) is overwritten wholesale by the same-size src
+  mem area. Copy between two tensors with the same stride and dtype is
+  supported; non-dense and overlapped tensors are also supported, but the two
+  cannot both be views (that would cause data outside the mem area to be
+  overwritten).
+  */
+  void doDirectMemFill(at::Tensor& dst, const at::Tensor& src,
+                       DIPUStream& curStream, DIPUCopyType copyType,
+                       bool needMemCpSync = true) {
+    if (dst.is_view() && src.is_view()) {
+      TORCH_CHECK(false, "doDirectMemFill cannot support all view-view copy");
+    }
+    memCopy(dst, src, curStream, copyType, needMemCpSync, false);
+  }
+
+  // supports mem copy between 2 nonOverlappingAndDense tensors with the same
+  // stride and dtype. both may be views.
+  void doDirectMemCopy(at::Tensor& dst, const at::Tensor& src,
+                       DIPUStream& curStream, DIPUCopyType copyType,
+                       bool needMemCpSync = true) {
+    if (native::dumpOpArgLevel() > 0) {
+      printf("--%-50s %-30s \n", "[copy_]:", "doDirectMemCopy");
+    }
+
+    memCopy(dst, src, curStream, copyType, needMemCpSync, true);
+  }
+
+  at::Tensor makeSameStrideTensor(const at::Tensor& src, DIPUStream& curStream,
+                                  at::Device newDevice,
+                                  bool willBackfillSrc = false) {
+    if (src.is_contiguous(c10::MemoryFormat::ChannelsLast) ||
+        src.is_contiguous()) {
+      auto sameAsSrc = at::empty(src.sizes(), src.options().device(newDevice),
+                                 src.suggest_memory_format());
+      return sameAsSrc;
+    } else {
+      // empty_strided is much more expensive than empty_memory_format().
+      // see src/ATen/EmptyTensor.cpp computeStorageNbytes()
+      auto sameAsSrc = at::empty_strided(src.sizes(), src.strides(),
+                                         src.options().device(newDevice));
+      // prefill the new tensor to support a later backfill.
+      if (willBackfillSrc && !src.is_non_overlapping_and_dense()) {
+        doDirectMemFill(sameAsSrc, src, curStream, getCopyType(sameAsSrc, src));
+      }
+      return sameAsSrc;
+    }
+  }
+
+  // NOTICE: this func maximizes leverage of on-device copy (D2Self)
+  // as a relay in d2h, h2d, d2d copies. it cannot be used for on-device copy
+  // (D2Self) itself.
+  void doDeviceRelayCopy(at::Tensor& dst, const at::Tensor& src,
+                         bool non_blocking, CopyParamsInfo& info) {
+    switch (info.copyType_) {
+      // create dst_device (relay, same stride)
+      // 1. direct dst_cpu/otherdevice -> dst_device (if view).
+      // 2. src_device -> dst_device. 3. direct dst_device ->
+      // dst_cpu/otherdevice.
+      case DIPUCopyType::D2OtherD:
+      case DIPUCopyType::D2H: {
+        auto curCopyType = info.copyType_;
+        // same stride as dst.
+        // TODO(fandaoyi): check if D2OtherD needs a device guard change.
+        auto dstInDevSrc =
+            makeSameStrideTensor(dst, info.curStream_, src.device(), true);
+        info.updateCopyType(DIPUCopyType::D2Self);
+        copyNodirectOnDevice(dstInDevSrc, src, non_blocking, info);
+        doDirectMemFill(dst, dstInDevSrc, info.curStream_, curCopyType, true);
+      } break;
+      // create src_device (relay, same stride)
+      // direct src_cpu -> src_device, src_device -> dst(device)
+      case DIPUCopyType::H2D: {
+        auto srcInDstdev =
+            makeSameStrideTensor(src, info.curStream_, dst.device());
+        doDirectMemFill(srcInDstdev, src, info.curStream_, DIPUCopyType::H2D);
+        info.updateCopyType(DIPUCopyType::D2Self);
+        copyNodirectOnDevice(dst, srcInDstdev, non_blocking, info);
+      } break;
+      default:
+        TORCH_CHECK(false,
+                    "doDeviceRelayCopy does not support on-device copy, it's "
+                    "a proxy method");
+    }
+  }
+
+  // NOTICE: doDeviceRelayCopy needs to create a relay tensor with the same
+  // stride as dst/src, which is expensive when the tensor is a view with a big
+  // hollow, so this simple wrapper method helps d2h, h2d, d2d copies. it
+  // cannot be used for on-device copy (D2Self). logical approach:
+  // 1. create dst_contig. 2. create src_contig and src -> src_contig.
+  // 3. direct src_contig -> dst_contig.
+  // 4. dst_contig -> dst
+  // (TODO(fandaoyi): use this automatically)
+  void doContigTensorRelayCopy(at::Tensor& dst, const at::Tensor& src,
+                               bool non_blocking, CopyParamsInfo& info) {
+    switch (info.copyType_) {
+      case DIPUCopyType::D2OtherD:
+      case DIPUCopyType::D2H: {
+        // 1. create dst_contig. same device.
+        auto dstContig =
+            dst.is_contiguous()
+                ? dst
+                : at::empty_like(dst, c10::MemoryFormat::Contiguous);
+        // TODO(fandaoyi): check if D2OtherD needs a device guard change.
+        auto newInfo = CopyParamsInfo(dstContig, src, info.curStream_);
+        if (newInfo.directMemCopy_) {
+          doDirectMemCopy(dstContig, src, newInfo.curStream_,
+                          newInfo.copyType_);
+        } else {
+          // equivalent to the logical approach:
+          // 2. create dst contig in src device and do src -> dst_contig_2.
+          // 3. direct: dst_contig_2(D) -> dst_contig(cpu/otherD).
+          doDeviceRelayCopy(dstContig, src, non_blocking, newInfo);
+        }
+        // 4. dst_contig -> dst (on the same device/cpu). this step needs a
+        // recursive kernel call; direct copy cannot handle it.
+        if (!dstContig.is_same(dst)) {
+          dst.copy_(dstContig);
+        }
+      } break;
+      case DIPUCopyType::H2D: {
+        // 2. create src_contig and src -> src_contig (both cpu).
+        auto srcContig = src.contiguous(c10::MemoryFormat::Contiguous);
+        auto newInfo = CopyParamsInfo(dst, srcContig, info.curStream_);
+        if (newInfo.directMemCopy_) {
+          doDirectMemCopy(dst, srcContig, newInfo.curStream_,
+                          newInfo.copyType_);
+        } else {
+          // equivalent to the logical approach:
+          // 1. create src_contig_2(D). 3. direct src_contig(CPU) ->
+          // src_contig_2(D). 4. src_contig_2 (device) -> dst (device)
+          doDeviceRelayCopy(dst, srcContig, non_blocking, newInfo);
+        }
+      } break;
+      default:
+        TORCH_CHECK(false,
+                    "doContigTensorRelayCopy does not support on-device copy, "
+                    "it's a proxy method");
+    }
+  }
+
+  /*
+  NOTICE:
+  d2h: direct src (device) -> src_cpu. src_cpu -> dst (cpu)
+  h2d: direct dst (device) -> dst_cpu (if view). src (cpu) -> dst_cpu.
+       direct dst_cpu -> dst (device)
+  d2d: direct src (device) -> src_cpu. direct dst (device) -> dst_cpu (if
+       view). src_cpu -> dst_cpu. direct dst_cpu -> dst (device). very very
+       slow, but it can handle any case; it's the fallback solution.
+  */
+  void doCpuRelayCopy(at::Tensor& dst, const at::Tensor& src,
+                      DIPUStream& curStream, bool non_blocking) {
+    if (native::dumpOpArgLevel() > 0) {
+      printf("--%-50s %-30s \n", "[copy_]:", "doCpuRelayCopy");
+    }
+
+    at::Tensor src_cpu = src;
+    if (dipu::isDeviceTensor(src)) {
+      src_cpu = makeSameStrideTensor(src, curStream,
+                                     c10::Device(c10::DeviceType::CPU), false);
+      // src storage size may be bigger than src_cpu's when src is a partial
+      // view, but not smaller, because src_cpu uses the same stride as src.
+      // src -> src_cpu
+      doDirectMemFill(src_cpu, src, curStream, DIPUCopyType::D2H);
+    }
+
+    if (dipu::isDeviceTensor(dst)) {
+      auto dst_cpu = makeSameStrideTensor(
+          dst, curStream, c10::Device(c10::DeviceType::CPU), true);
+      // proxy to cpu to handle different type/view problems
+      dst_cpu.copy_(src_cpu);
+      // TODO(fandaoyi): needs further checking.
+      // forcing a sync in doDirectMemFill slows down performance.
+      // it seems the dipu CachedAllocator recycles the storage of a temp
+      // tensor when the tensor instance leaves scope, even if the (default)
+      // stream has not finished (not sure). is that correct behavior?
+      // function doDirectMemCopy has the same problem!
+      doDirectMemFill(dst, dst_cpu, curStream, DIPUCopyType::H2D, true);
+      return;
+    }
+    // dst is cpu
+    dst.copy_(src_cpu);
+  }
+
+  // NOTICE: handles non-direct mem copy on one device. dipu has a simple
+  // configurable template strategy that uses DIOPI copy/cast correctly. if a
+  // vendor has an incomplete implementation of DIOPI copy/cast, please
+  // override copyNodirectOnDevice to decide which cases diopiCopy should
+  // execute, and proxy the other cases back to 'doCpuRelayCopy', which
+  // contains a slow implementation.
+
+  // 1. type cast. 2. expand/bcast. 3.1 special non-contiguous (stride
+  // hollow/overlap). 3.2 mem-format non-contiguous (storage contiguous but
+  // not nchw); we don't handle this.
+  virtual void copyNodirectOnDevice(at::Tensor& dst, const at::Tensor& src,
+                                    bool non_blocking, CopyParamsInfo& info) {
+    if (DiopiCast) {
+      at::Tensor tmpSrc = src;
+      if (!info.sameDtype_) {
+        tmpSrc = native::dipu_wrap_diopi_cast_dtype(src, dst.scalar_type());
+        info.recomputeTensorsInfo(dst, tmpSrc);
+      }
+      // after cast
+      if (info.directMemCopy_) {
+        doDirectMemCopy(dst, tmpSrc, info.curStream_, info.copyType_,
+                        !tmpSrc.is_same(src));
+      } else if (DiopiCopy) {
+        native::dipu_wrap_diopi_copy_inp(dst, tmpSrc, non_blocking);
+      } else {
+        doCpuRelayCopy(dst, src, info.curStream_, non_blocking);
+      }
+    } else if (DiopiCopy) {  // !DiopiCast
+      native::dipu_wrap_diopi_copy_inp(dst, src, non_blocking);
+    } else {
+      doCpuRelayCopy(dst, src, info.curStream_, non_blocking);
+    }
+  }
+
+  // NOTICE: handles non-direct mem copy between different devices. dipu's
+  // default strategy uses an intermediate tensor, which is slow. a vendor with
+  // a more efficient p2p device copy can override it (e.g. a device with
+  // unified addressing that supports passing addresses on different devices to
+  // one kernel can use copyNodirectOnDevice() to do the between-device copy).
+  virtual void copyNodirectBetweenDevices(at::Tensor& dst,
+                                          const at::Tensor& src,
+                                          bool non_blocking,
+                                          CopyParamsInfo& info) {
+    if (DiopiCopy) {
+      // doContigTensorRelayCopy(dst, src, non_blocking, info);
+      doDeviceRelayCopy(dst, src, non_blocking, info);
+      return;
+    }
+    // if DiopiCopy = false, a direct cpu copy is best.
+    doCpuRelayCopy(dst, src, info.curStream_, non_blocking);
+  }
+
+  // NOTICE: handles non-direct mem copy between cpu and device. dipu's default
+  // strategy uses an intermediate tensor, which is slow. a vendor with a more
+  // efficient solution can override it.
+  virtual void copyNodirectDeviceHost(at::Tensor& dst, const at::Tensor& src,
+                                      bool non_blocking, CopyParamsInfo& info) {
+    if (DiopiCopy) {  // try to maximize leverage of device copy,
+      // doContigTensorRelayCopy(dst, src, non_blocking, info);
+      doDeviceRelayCopy(dst, src, non_blocking, info);
+      return;
+    }
+    // if DiopiCopy = false, a direct cpu copy is best.
+    doCpuRelayCopy(dst, src, info.curStream_, non_blocking);
+  }
+
+  // overriding this func is possible but not recommended
+  virtual void copyAll(at::Tensor& dst, const at::Tensor& src,
+                       bool non_blocking, CopyParamsInfo& info) {
+    at::Tensor tmpSrc = src;
+    if (!info.sameSize_) {
+      tmpSrc = src.expand_as(dst);
+      info.recomputeTensorsInfo(dst, tmpSrc);
+    }
+    if (info.directMemCopy_) {
+      doDirectMemCopy(dst, tmpSrc, info.curStream_, info.copyType_,
+                      info.copyType_ != DIPUCopyType::D2Self);
+      return;
+    }
+    switch (info.copyType_) {
+      case DIPUCopyType::D2Self:
+        copyNodirectOnDevice(dst, tmpSrc, non_blocking, info);
+        break;
+      case DIPUCopyType::D2OtherD:
+        copyNodirectBetweenDevices(dst, tmpSrc, non_blocking, info);
+        break;
+      default:
+        copyNodirectDeviceHost(dst, tmpSrc, non_blocking, info);
+    }
+  }
+};
+
+using DIPUCopyInpOnCPU = DIPUCopyInplace<false, false>;
+using DIPUCopyInpOnDIOPI = DIPUCopyInplace<true, false>;
+using DIPUCopyInpOnDIOPIWithCast = DIPUCopyInplace<true, true>;
+
+DIPUCopyBase* getDipuCopyInstance();
+
+void setDipuCopyInstance(DIPUCopyBase* op);
+
+}  // namespace dipu
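Note: to make the override points of the template above concrete, here is a minimal sketch of a vendor specialization (hypothetical VendorX; the real CambCopyInplace below follows the same shape):

// Hypothetical: route cases the vendor's DIOPI copy cannot handle to the CPU
// relay; everything else goes through the DIOPI-based base implementation.
class VendorXCopyInplace : public dipu::DIPUCopyInpOnDIOPI {
 protected:
  void copyNodirectOnDevice(at::Tensor& dst, const at::Tensor& src,
                            bool non_blocking,
                            dipu::CopyParamsInfo& info) override {
    if (dst.is_complex() || src.is_complex()) {  // assumed unsupported case
      doCpuRelayCopy(dst, src, info.curStream_, non_blocking);
    } else {
      DIPUCopyInplace::copyNodirectOnDevice(dst, src, non_blocking, info);
    }
  }
};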
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp b/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp
index 8f0b19cb5..213f3ff3e 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp
@@ -1,5 +1,8 @@
 #pragma once
 
+#include
+#include
+
 namespace dipu::native {
 
 inline bool checkDiopiReturnValue() {
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/StorageShapeKernel.cpp b/dipu/torch_dipu/csrc_dipu/aten/ops/StorageShapeKernel.cpp
index 81ece6856..4b33d68fa 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/StorageShapeKernel.cpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/StorageShapeKernel.cpp
@@ -58,6 +58,7 @@ static inline TensorImpl* _resize_impl_dipu_(TensorImpl* self, IntArrayRef size,
   if (self->sizes() == size && (!stride || self->strides() == stride)) {
     return self;
   }
+  const DIPUGuard device_guard(self->device());  // add a guard to support device changes.
 
   const auto itemsize = self->dtype().itemsize();
   const auto storage_offset = self->storage_offset();
diff --git a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUCopyInplace.cpp b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUCopyInplace.cpp
deleted file mode 100644
index b3eb9a011..000000000
--- a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUCopyInplace.cpp
+++ /dev/null
@@ -1,177 +0,0 @@
-// Copyright (c) 2023, DeepLink.
-#include "DIPUCopyInplace.h"
-
-#include
-
-#include
-
-#include
-#include
-#include
-
-namespace dipu {
-
-at::Tensor& DIPUCopyInplace::run(at::Tensor& self, const at::Tensor& src,
-                                 bool non_blocking) {
-  TORCH_CHECK(self.defined(), "self is undefined");
-  TORCH_CHECK(src.defined(), "src is undefined");
-
-  c10::optional names = src.opt_names();
-  if (names.has_value()) {
-    internal_set_names_inplace(self, names);
-  }
-
-  if (self.numel() == 0 || self.is_same(src)) {
-    return self;
-  }
-
-  // Exit early if self and src are views of the same data
-  const bool is_same_data =
-      (self.is_alias_of(src) && self.storage_offset() == src.storage_offset() &&
-       self.strides().equals(src.strides()) &&
-       self.sizes().equals(src.sizes()) &&
-       self.scalar_type() == src.scalar_type());
-  if (is_same_data) {
-    return self;
-  }
-
-  auto iter = at::TensorIteratorConfig()
-                  .add_output(self)
-                  .add_input(src)
-                  .resize_outputs(false)
-                  .check_all_same_dtype(false)
-                  .check_all_same_device(false)
-                  .build();
-  if (iter.numel() == 0) {
-    return self;
-  }
-
-  c10::Device dst_device = iter.device(0);
-  c10::Device src_device = iter.device(1);
-  // 1. copy between devices
-  if (dst_device.type() == DIPU_DEVICE_TYPE &&
-      src_device.type() == DIPU_DEVICE_TYPE) {
-    return copy_between_devices(iter, self, src, non_blocking);
-  }
-
-  // 2. copy between cpu and device, same dtype and shape, contiguous
-  bool same_dtype = iter.dtype(0) == iter.dtype(1);
-  if (same_dtype && iter.is_contiguous()) {
-    return copy_contiguous(iter, self, src, non_blocking);
-  }
-
-  // 3. copy between cpu and device, different dtype or view
-  return copy_uncontiguous(iter, self, src, non_blocking);
-}
-
-at::Tensor& DIPUCopyInplace::copy_between_devices(at::TensorIterator& iter,
-                                                  at::Tensor& self,
-                                                  const at::Tensor& src,
-                                                  bool non_blocking) {
-  int64_t numel = iter.numel();
-  c10::Device dst_device = iter.device(0);
-  c10::Device src_device = iter.device(1);
-
-  bool same_type = iter.dtype(0) == iter.dtype(1);
-  bool memcpy_eligible = same_type && iter.is_contiguous();
-  if (!memcpy_eligible) {
-    return native::DIPUATenFunctions::copy_(self, src, non_blocking);
-  }
-
-  void* dst_ptr = iter.data_ptr(0);
-  void* src_ptr = iter.data_ptr(1);
-  if (src_ptr == dst_ptr && src_device == dst_device) {
-    return self;
-  }
-
-  size_t size = numel * iter.element_size(0);
-  dipu::DIPUStream stream = dipu::getCurrentDIPUStream();
-  dipu::devproxy::memCopyD2DAsync(stream.rawstream(), size, dst_device.index(),
-                                  dst_ptr, src_device.index(), src_ptr);
-
-  if (!non_blocking) {
-    dipu::devproxy::syncStream(stream.rawstream());
-  }
-  return self;
-}
-
-at::Tensor& DIPUCopyInplace::copy_contiguous(at::TensorIterator& iter,
-                                             at::Tensor& self,
-                                             const at::Tensor& src,
-                                             bool non_blocking) {
-  c10::Device dst_device = iter.device(0);
-  c10::Device src_device = iter.device(1);
-
-  int64_t nbytes = iter.numel() * iter.element_size(0);
-  dipu::DIPUStream stream = dipu::getCurrentDIPUStream();
-  if (dst_device.type() == DIPU_DEVICE_TYPE && src_device.is_cpu()) {
-    dipu::devproxy::memCopyH2DAsync(stream.rawstream(), nbytes,
-                                    iter.data_ptr(0), iter.data_ptr(1));
-  } else if (dst_device.is_cpu() && src_device.type() == DIPU_DEVICE_TYPE) {
-    dipu::devproxy::memCopyD2HAsync(stream.rawstream(), nbytes,
-                                    iter.data_ptr(0), iter.data_ptr(1));
-  } else {
-    TORCH_CHECK(false, "unsupported devices in copy_");
-  }
-
-  if (!non_blocking) {
-    dipu::devproxy::syncStream(stream.rawstream());
-  }
-  return self;
-}
-
-at::Tensor& DIPUCopyInplace::copy_uncontiguous(at::TensorIterator& iter,
-                                               at::Tensor& self,
-                                               const at::Tensor& src,
-                                               bool non_blocking) {
-  auto& dst = iter.tensor(0);
-  at::Tensor dst_contig;
-  at::Tensor src_contig;
-  if (iter.device_type(0) == DIPU_DEVICE_TYPE || non_blocking) {
-    dst_contig = dst.is_contiguous()
-                     ? dst
-                     : at::empty_like(dst, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-    src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous();
-  } else {
-    bool same_type = iter.dtype(0) == iter.dtype(1);
-    dst_contig = (dst.is_contiguous() && same_type)
-                     ? dst
-                     : at::empty_like(dst, iter.dtype(1),
-                                      LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-    src_contig = iter.tensor(1).expand_as(dst).contiguous();
-  }
-  // perform a same-dtype copy on contiguous tensors
-  TORCH_CHECK(dst_contig.sizes().equals(src_contig.sizes()));
-  TORCH_CHECK(dst_contig.scalar_type() == src_contig.scalar_type());
-  dst_contig.copy_(src_contig, non_blocking);
-
-  // if necessary, copy back into dst
-  if (!dst_contig.is_same(dst)) {
-    TORCH_CHECK(dst_contig.device() == dst.device());
-    dst.copy_(dst_contig, non_blocking);
-  }
-  return self;
-}
-
-static DIPUCopyInplace default_copy_inplace_op;
-static DIPUCopyInplace* dipu_copy_inplace_op = nullptr;
-
-DIPUCopyInplace* getDipuCopyInplace() {
-  TORCH_CHECK(dipu_copy_inplace_op, "dipu copy inplace not registered");
-  return dipu_copy_inplace_op;
-}
-
-void setDipuCopyInplace(DIPUCopyInplace* op) {
-  if (dipu_copy_inplace_op == nullptr) {
-    dipu_copy_inplace_op = op;
-  } else if (dipu_copy_inplace_op == &default_copy_inplace_op) {
-    dipu_copy_inplace_op = op;
-  }
-}
-
-static int32_t default_init = []() {
-  setDipuCopyInplace(&default_copy_inplace_op);
-  return 1;
-}();
-
-}  // namespace dipu
diff --git a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUCopyInplace.h b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUCopyInplace.h
deleted file mode 100644
index 8d2d889bd..000000000
--- a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUCopyInplace.h
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright (c) 2023, DeepLink.
-#pragma once
-
-#include
-#include
-#include
-
-namespace dipu {
-
-class DIPUCopyInplace {
- public:
-  DIPUCopyInplace() = default;
-  virtual ~DIPUCopyInplace() = default;
-
-  virtual at::Tensor& run(at::Tensor& self, const at::Tensor& src,
-                          bool non_blocking);
-
-  // copy between devices
-  // 1. dtype & shape & stride all equal, use memCopyD2DAsync
-  // 2. use DIPUATenFunctions::copy_, proxy device tensor to cpu to handle
-  // different dtype/view problem
-  virtual at::Tensor& copy_between_devices(at::TensorIterator& iter,
-                                           at::Tensor& self,
-                                           const at::Tensor& src,
-                                           bool non_blocking);
-
-  // copy between cpu and device, dtype & shape & stride all equal
-  // 1. host to device, use memCopyH2DAsync
-  // 2. device to host, use memCopyD2HAsync
-  virtual at::Tensor& copy_contiguous(at::TensorIterator& iter,
-                                      at::Tensor& self, const at::Tensor& src,
-                                      bool non_blocking);
-
-  // copy between cpu and device, different dtype or view
-  virtual at::Tensor& copy_uncontiguous(at::TensorIterator& iter,
-                                        at::Tensor& self, const at::Tensor& src,
-                                        bool non_blocking);
-};
-
-DIPUCopyInplace* getDipuCopyInplace();
-void setDipuCopyInplace(DIPUCopyInplace* op);
-
-}  // namespace dipu
diff --git a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUStream.h b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUStream.h
index bd6aa2e3e..b2d9e58ae 100644
--- a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUStream.h
+++ b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUStream.h
@@ -17,19 +17,18 @@ namespace dipu {
 
 class DIPU_API DIPUStream {
  public:
-  enum Unchecked { UNCHECKED };
-  explicit DIPUStream(c10::Stream stream) : stream_(stream) {
+  explicit DIPUStream(c10::Stream stream)
+      : stream_(stream), initialized_(true) {
     TORCH_CHECK(stream_.device_type() == dipu::DIPU_DEVICE_TYPE);
   }
 
-  explicit DIPUStream(Unchecked, c10::Stream stream) : stream_(stream) {}
-
   explicit DIPUStream(devapis::deviceId_t devidx, c10::StreamId stream_id)
-      : DIPUStream(Unchecked::UNCHECKED,
-                   c10::Stream(c10::Stream::UNSAFE,
+      : DIPUStream(c10::Stream(c10::Stream::UNSAFE,
                    c10::Device(dipu::DIPU_DEVICE_TYPE, devidx),
                    stream_id)) {}
 
+  explicit DIPUStream() : DIPUStream(-1, 0) { initialized_ = false; }
+
   ~DIPUStream() {}
 
   bool operator==(const DIPUStream& other) const noexcept {
@@ -83,6 +82,7 @@
 
  private:
   c10::Stream stream_;
+  bool initialized_;
 };
 
 DIPU_API DIPUStream getDIPUStreamFromPool(c10::DeviceIndex device = -1);
diff --git a/dipu/torch_dipu/csrc_dipu/vendor/camb/CambCopyInplace.cpp b/dipu/torch_dipu/csrc_dipu/vendor/camb/CambCopyInplace.cpp
new file mode 100644
index 000000000..255148219
--- /dev/null
+++ b/dipu/torch_dipu/csrc_dipu/vendor/camb/CambCopyInplace.cpp
@@ -0,0 +1,126 @@
+// Copyright (c) 2023, DeepLink.
+
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+
+namespace dipu {
+
+// NOTICE: diopi-camb copy has many restrictions.
+// this code is based on diopi cnnl_helper.cpp gCnnlCastDataTypeMapping and
+// only covers commonly used cases, leaving diopi developers to optimize.
+namespace {
+struct HashCnnlCastDType {
+  size_t operator()(const std::vector<diopiDtype_t>& vec) const {
+    size_t ret = 0;
+    for (auto it : vec) {
+      ret = (ret ^ static_cast<size_t>(it)) + 0x9e3779b9 + (ret << 6) +
+            (ret >> 2);
+    }
+    return ret;
+  }
+};
+
+const static std::unordered_set<std::vector<diopiDtype_t>, HashCnnlCastDType>
+    cnnlCastDataTypeMapping{
+        {{diopi_dtype_bool, diopi_dtype_int32}},
+        {{diopi_dtype_bool, diopi_dtype_float16}},
+        {{diopi_dtype_bool, diopi_dtype_float32}},
+
+        {{diopi_dtype_int8, diopi_dtype_int16}},
+        {{diopi_dtype_int8, diopi_dtype_int32}},
+        {{diopi_dtype_int8, diopi_dtype_float16}},
+        {{diopi_dtype_int8, diopi_dtype_float32}},
+
+        {{diopi_dtype_uint8, diopi_dtype_int32}},
+        {{diopi_dtype_uint8, diopi_dtype_int64}},
+        {{diopi_dtype_uint8, diopi_dtype_float16}},
+        {{diopi_dtype_uint8, diopi_dtype_float32}},
+
+        {{diopi_dtype_int16, diopi_dtype_int32}},
+        {{diopi_dtype_int16, diopi_dtype_float16}},
+        {{diopi_dtype_int16, diopi_dtype_float32}},
+        // no uint16 cast
+
+        {{diopi_dtype_int32, diopi_dtype_bool}},
+        {{diopi_dtype_int32, diopi_dtype_int8}},
+        {{diopi_dtype_int32, diopi_dtype_int16}},
+        {{diopi_dtype_int32, diopi_dtype_int64}},
+        {{diopi_dtype_int32, diopi_dtype_float16}},
+        {{diopi_dtype_int32, diopi_dtype_float32}},
+
+        {{diopi_dtype_uint32, diopi_dtype_int64}},
+        {{diopi_dtype_uint32, diopi_dtype_uint64}},
+
+        {{diopi_dtype_int64, diopi_dtype_int32}},
+        {{diopi_dtype_int64, diopi_dtype_uint32}},
+        {{diopi_dtype_int64, diopi_dtype_float16}},
+        {{diopi_dtype_int64, diopi_dtype_float32}},
+
+        {{diopi_dtype_uint64, diopi_dtype_uint32}},
+
+        // CNNL_CAST_HALF_TO_FLOAT_INF = 129, /*!< Converts half to float for
+        // amp training. */
+        {{diopi_dtype_float16, diopi_dtype_bool}},
+        {{diopi_dtype_float16, diopi_dtype_int8}},
+        {{diopi_dtype_float16, diopi_dtype_uint8}},
+        {{diopi_dtype_float16, diopi_dtype_int16}},
+        {{diopi_dtype_float16, diopi_dtype_int32}},
+        {{diopi_dtype_float16, diopi_dtype_int64}},
+        {{diopi_dtype_float16, diopi_dtype_float32}},
+
+        // CNNL_CAST_FLOAT_TO_HALF_IEEE754 = 219, /*!< Converts float to half
+        // for ieee754. */
+        {{diopi_dtype_float32, diopi_dtype_bool}},
+        {{diopi_dtype_float32, diopi_dtype_int8}},
+        {{diopi_dtype_float32, diopi_dtype_uint8}},
+        {{diopi_dtype_float32, diopi_dtype_int16}},
+        {{diopi_dtype_float32, diopi_dtype_int32}},
+        {{diopi_dtype_float32, diopi_dtype_int64}},
+        {{diopi_dtype_float32, diopi_dtype_float16}},
+        {{diopi_dtype_float32, diopi_dtype_float64}},
+
+        {{diopi_dtype_float64, diopi_dtype_float32}},
+    };
+}  // namespace
+
+using dipu::native::dipu_wrap_diopi_copy_inp;
+class CambCopyInplace : public DIPUCopyInpOnDIOPI {
+ public:
+  CambCopyInplace() = default;
+  ~CambCopyInplace() = default;
+
+  void copyNodirectOnDevice(at::Tensor& dst, const at::Tensor& src,
+                            bool non_blocking, CopyParamsInfo& info) override {
+    diopiDtype_t dstDtype = dipu::diopi_helper::toDiopiDtype(dst.scalar_type());
+    diopiDtype_t srcDtype = dipu::diopi_helper::toDiopiDtype(src.scalar_type());
+    // diopiCopy(cnnlTranspose_v2) cannot handle float64 stride-based
+    // copy; failed test: TestTorchDeviceType.test_memory_format_to
+    bool noSupportedTranspose = !dst.is_contiguous() && !src.is_contiguous() &&
+                                dstDtype == diopi_dtype_float64;
+    bool noSupportedDtype = dst.is_complex() || src.is_complex();
+    // cnnl only handles a limited set of cast types.
+    bool noSupportedCast = !info.sameDtype_ &&
+                           cnnlCastDataTypeMapping.find({srcDtype, dstDtype}) ==
+                               cnnlCastDataTypeMapping.end();
+
+    if (noSupportedTranspose || noSupportedDtype || noSupportedCast) {
+      doCpuRelayCopy(dst, src, info.curStream_, non_blocking);
+    } else {
+      DIPUCopyInplace::copyNodirectOnDevice(dst, src, non_blocking, info);
+    }
+  }
+};
+
+static CambCopyInplace camb_copy_inplace;
+static int32_t camb_init = []() {
+  setDipuCopyInstance(&camb_copy_inplace);
+  return 1;
+}();
+
+}  // namespace dipu
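Note: a worked illustration of the cast table above (assuming the file-local cnnlCastDataTypeMapping is in scope): float32 -> int16 is listed, so that cast-copy stays on the DIOPI path, while a pair missing from the table, e.g. float16 -> uint32, makes copyNodirectOnDevice fall back to doCpuRelayCopy.

// Hypothetical lookups against the table above:
bool onDiopiPath = cnnlCastDataTypeMapping.count(
                       {diopi_dtype_float32, diopi_dtype_int16}) > 0;    // true
bool cpuRelayPath = cnnlCastDataTypeMapping.count(
                        {diopi_dtype_float16, diopi_dtype_uint32}) == 0;  // true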
diff --git a/dipu/torch_dipu/csrc_dipu/vendor/cuda/CUDACopyInplace.cpp b/dipu/torch_dipu/csrc_dipu/vendor/cuda/CUDACopyInplace.cpp
index faf619c8d..3b119d36b 100644
--- a/dipu/torch_dipu/csrc_dipu/vendor/cuda/CUDACopyInplace.cpp
+++ b/dipu/torch_dipu/csrc_dipu/vendor/cuda/CUDACopyInplace.cpp
@@ -2,70 +2,50 @@
 
 #include
 
+#include
 #include
 #include
-#include
 #include
 
 namespace dipu {
 
-at::Tensor& copy_(at::Tensor& self, const at::Tensor& src, bool non_blocking) {
-  if (self.numel() == 0) {
-    return self;
-  }
-
-  dipu::DIPUStream stream = getCurrentDIPUStream();
-  ::diopiContext context(stream.rawstream());
-  auto ctx = &context;
-
-  ::diopiConstTensorHandle_t srcDiopiTensorHandle =
-      dipu::diopi_helper::toDiopiTensorHandle(src);
-  ::diopiTensorHandle_t selfDiopiTensorHandle =
-      dipu::diopi_helper::toDiopiTensorHandle(self);
-  ::diopiError_t ret =
-      ::diopiCopyInp(ctx, srcDiopiTensorHandle, selfDiopiTensorHandle);
-  TORCH_CHECK(ret == ::diopiSuccess, __FILE__, ":", __LINE__,
-              R"(::diopiCopyInp(ctx, src, dst);)", " error, error code is ",
-              ret, "error message is ", diopiGetLastErrorString());
-
-  if (!non_blocking) {
-    dipu::devapis::syncStream(stream.rawstream());
-  }
-  return self;
-}
-
-class CUDACopyInplace : public DIPUCopyInplace {
+using dipu::native::dipu_wrap_diopi_copy_inp;
+class CUDACopyInplace : public DIPUCopyInpOnDIOPI {
  public:
   CUDACopyInplace() = default;
   ~CUDACopyInplace() = default;
 
-  at::Tensor& run(at::Tensor& self, const at::Tensor& src,
-                  bool non_blocking) override {
-    return copy_(self, src, non_blocking);
-  }
-
-  at::Tensor& copy_between_devices(at::TensorIterator& iter, at::Tensor& self,
-                                   const at::Tensor& src,
-                                   bool non_blocking) override {
-    return copy_(self, src, non_blocking);
-  }
-
-  at::Tensor& copy_contiguous(at::TensorIterator& iter, at::Tensor& self,
-                              const at::Tensor& src,
-                              bool non_blocking) override {
-    return copy_(self, src, non_blocking);
+  // diopi-cuda copy uses aten, so it can handle the between-device case.
+  void copyNodirectBetweenDevices(at::Tensor& dst, const at::Tensor& src,
+                                  bool non_blocking,
+                                  CopyParamsInfo& info) override {
+    dipu_wrap_diopi_copy_inp(dst, src, non_blocking);
   }
+};
 
-  at::Tensor& copy_uncontiguous(at::TensorIterator& iter, at::Tensor& self,
-                                const at::Tensor& src,
-                                bool non_blocking) override {
-    return copy_(self, src, non_blocking);
-  }
-};
+// a vendor with an incomplete diopiCopy implementation needs to write a
+// subclass and override copyNodirectOnDevice like this:
+/*
+class VendorCopyInplace : public DIPUCopyInpOnDIOPI {
+ public:
+  VendorCopyInplace() = default;
+  ~VendorCopyInplace() = default;
+  void copyNodirectOnDevice(at::Tensor& dst, const at::Tensor& src,
+                            bool non_blocking, CopyParamsInfo& info) override {
+    if (diopi_copy_can_handle_this_case) {
+      dipu_wrap_diopi_copy_inp(dst, src, non_blocking);
+      // or: DIPUCopyInplace::copyNodirectOnDevice(...);
+    } else {
+      doCpuRelayCopy(...);
+    }
+  }
+};
+*/
 
 static CUDACopyInplace cuda_copy_inplace;
 static int32_t cuda_init = []() {
-  setDipuCopyInplace(&cuda_copy_inplace);
+  setDipuCopyInstance(&cuda_copy_inplace);
   return 1;
 }();
diff --git a/dipu/torch_dipu/csrc_dipu/vendor/cuda/communiatorimpl.cpp b/dipu/torch_dipu/csrc_dipu/vendor/cuda/communiatorimpl.cpp
index 4ae90d631..3ac321a48 100644
--- a/dipu/torch_dipu/csrc_dipu/vendor/cuda/communiatorimpl.cpp
+++ b/dipu/torch_dipu/csrc_dipu/vendor/cuda/communiatorimpl.cpp
@@ -109,6 +109,7 @@ DIPU_API diclResult_t diclReduceScatter(
   NCCL_THROW(ncclReduceScatter(sendBuf, recvBuf, recvCount,
                                ncclDataType[datatype], ncclOp[reduceOp], comm,
                                stream));
+  return DICL_SUCCESS;
 }
 
 DIPU_API diclResult_t diclSend(void* sendbuff, size_t count,
diff --git a/dipu/torch_dipu/csrc_dipu/vendor/supa/copyinplace.cpp b/dipu/torch_dipu/csrc_dipu/vendor/supa/copyinplace.cpp
index 436332ea3..0b84a9e8a 100644
--- a/dipu/torch_dipu/csrc_dipu/vendor/supa/copyinplace.cpp
+++ b/dipu/torch_dipu/csrc_dipu/vendor/supa/copyinplace.cpp
@@ -1,68 +1,34 @@
 // Copyright (c) 2023, DeepLink.
+#include
 #include
 #include
-#include
 #include
 
 namespace dipu {
 namespace supa {
 
-at::Tensor& copy_(at::Tensor& self, const at::Tensor& src, bool non_blocking) {
-  if (self.numel() == 0) {
-    return self;
-  }
-  dipu::DIPUStream stream = getCurrentDIPUStream();
-  ::diopiContext context(stream.rawstream());
-  auto ctx = &context;
-  ::diopiConstTensorHandle_t srcDiopiTensorHandle =
-      dipu::diopi_helper::toDiopiTensorHandle(src);
-  ::diopiTensorHandle_t selfDiopiTensorHandle =
-      dipu::diopi_helper::toDiopiTensorHandle(self);
-  ::diopiError_t ret =
-      ::diopiCopyInp(ctx, srcDiopiTensorHandle, selfDiopiTensorHandle);
-  TORCH_CHECK(ret == ::diopiSuccess, __FILE__, ":", __LINE__,
-              R"(::diopiCopyInp(ctx, src, dst);)", " error, error code is ",
-              ret, "error message is ", diopiGetLastErrorString());
-  // TODO(caikun): remove syncStream when cache allocator is ready
-  if (non_blocking) {
-    dipu::devapis::syncStream(stream.rawstream());
-  }
-  return self;
-}
+using dipu::native::dipu_wrap_diopi_copy_inp;
 
-class SUPACopyInplace : public DIPUCopyInplace {
+// supa's existing implementation, like cuda's, proxies every copy case to
+// diopi. That differs from the diopiCopy doc's requirement (handle device copy
+// only), so we change its behavior to do device copy only.
+class SUPACopyInplace : public DIPUCopyInpOnDIOPI {
  public:
   SUPACopyInplace() = default;
   ~SUPACopyInplace() = default;
 
-  at::Tensor& run(at::Tensor& self, const at::Tensor& src,
-                  bool non_blocking) override {
-    return copy_(self, src, non_blocking);
-  }
-
-  at::Tensor& copy_between_devices(at::TensorIterator& iter, at::Tensor& self,
-                                   const at::Tensor& src,
-                                   bool non_blocking) override {
-    return copy_(self, src, non_blocking);
-  }
-
-  at::Tensor& copy_contiguous(at::TensorIterator& iter, at::Tensor& self,
-                              const at::Tensor& src,
-                              bool non_blocking) override {
-    return copy_(self, src, non_blocking);
-  }
-
-  at::Tensor& copy_uncontiguous(at::TensorIterator& iter, at::Tensor& self,
-                                const at::Tensor& src,
-                                bool non_blocking) override {
-    return copy_(self, src, non_blocking);
+  // assume it can handle the between-device case.
+  void copyNodirectBetweenDevices(at::Tensor& dst, const at::Tensor& src,
+                                  bool non_blocking,
+                                  CopyParamsInfo& info) override {
+    dipu_wrap_diopi_copy_inp(dst, src, non_blocking);
   }
 };
 
 static SUPACopyInplace copy_inplace;
-static int32_t suap_copy_inplace_init = []() {
-  setDipuCopyInplace(&copy_inplace);
+static int32_t supa_copy_inplace_init = []() {
+  setDipuCopyInstance(&copy_inplace);
   return 1;
 }();