diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp index 5de98636d..6a0729721 100644 --- a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp +++ b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp @@ -4,6 +4,7 @@ #include +#include #include #include "csrc_dipu/aten/DIPUATenFunctions.h" @@ -36,8 +37,30 @@ void setDipuCopyInstance(DIPUCopyBase* op) { dipu_copy_op() = op; } namespace dipu { namespace native { namespace dipu_aten { + at::Scalar _local_scalar_dense_dipu(const at::Tensor& self) { at::Scalar r; +#if DIPU_VENDOR_NAME_SUPA + extern void setSupaDeviceCopyInfo(void* ptr, int64_t offset); + if (self.scalar_type() == c10::ScalarType::Bool) { + // on SUPA, bool type is represented by float. + using scalar_t = c10::impl::ScalarTypeToCPPTypeT; + float value; + dipu::DIPUStream stream = dipu::getCurrentDIPUStream(); + MemChecker::instance().check(self); + /* on SUPA, the offset can't be added to ptr directly since ptr is a virtual address. + the offset must be set in advance and ptr recalculated after translating it to + physical address.
+ */ + setSupaDeviceCopyInfo(self.storage().data(), + self.storage_offset() * sizeof(scalar_t)); + dipu::devproxy::memCopyD2HAsync(stream.rawstream(), sizeof(scalar_t), + &value, self.data_ptr()); + dipu::devproxy::syncStream(stream.rawstream()); + r = at::Scalar((std::abs(value) >= 1e-6f)); + return r; + } +#endif AT_DISPATCH_ALL_TYPES_AND3( at::kHalf, at::kBool, at::kBFloat16, self.scalar_type(), "_local_scalar_dense_dipu", [&] { @@ -49,6 +72,9 @@ at::Scalar _local_scalar_dense_dipu(const at::Tensor& self) { dipu::devproxy::memCopyD2H(sizeof(scalar_t), &value, self.data_ptr()); #else +#if DIPU_VENDOR_NAME_SUPA + setSupaDeviceCopyInfo(self.storage().data(), self.storage_offset()*sizeof(scalar_t)); +#endif dipu::devproxy::memCopyD2HAsync(stream.rawstream(), sizeof(scalar_t), &value, self.data_ptr()); dipu::devproxy::syncStream(stream.rawstream()); diff --git a/dipu/torch_dipu/csrc_dipu/vendor/supa/deviceimpl.cpp b/dipu/torch_dipu/csrc_dipu/vendor/supa/deviceimpl.cpp index 8b0a421cd..716b45569 100644 --- a/dipu/torch_dipu/csrc_dipu/vendor/supa/deviceimpl.cpp +++ b/dipu/torch_dipu/csrc_dipu/vendor/supa/deviceimpl.cpp @@ -1,3 +1,5 @@ +#include + #include #include @@ -187,6 +189,27 @@ void br_device_free(void* ptr); void* get_phy_ptr(const void* ptr); } +typedef struct _AddrInfo { + int64_t offset = 0; + void* ptr = nullptr; + void set(void* _ptr, int64_t _offset) { + ptr = _ptr; + offset = _offset; + } + void* get(const void* _ptr) { + if (ptr != nullptr) { + auto p = ptr; + ptr = nullptr; + return (uint8_t*)get_phy_ptr(p) + offset; + } + return get_phy_ptr(_ptr); + } +} AddrInfo; + +namespace { +thread_local AddrInfo addr0, addr1; +} // namespace + DIPU_API OpStatus mallocDevice(void** p, size_t nbytes, bool throwExcepion) { void* ptr = nullptr; ptr = br_device_malloc(nbytes); @@ -200,6 +223,33 @@ DIPU_API OpStatus mallocDevice(void** p, size_t nbytes, bool throwExcepion) { return OpStatus::ERR_NOMEM; } +/** + * @brief Set the Supa Device Copy Info into thread local
object. + * + * @param ptr base virtual addr. + * @param offset offset (in bytes) added to the base addr. + */ +inline void setSupaDeviceCopyInfo(void* ptr, int64_t offset, int index = 0) { + if (index == 0) { + addr0.set(ptr, offset); + } else { + addr1.set(ptr, offset); + } +} + +/** + * @brief Get the Device Ptr from virtual address and optional offset. + * + * @param src virtual address. if the thread-local addr info holds a pointer, the value of src is ignored. + * @return void* physical address + */ +inline void* getDevicePtr(const void* src, int index = 0) { + if (index == 0) { + return addr0.get(src); + } + return addr1.get(src); +} + DIPU_API void freeDevice(void* p) { br_device_free(p); } DIPU_API bool isPinnedPtr(const void* p) { return false; } @@ -207,7 +257,7 @@ DIPU_API bool isPinnedPtr(const void* p) { return false; } // (asynchronous) set val DIPU_API void memSetAsync(const deviceStream_t stream, void* ptr, int val, size_t size) { - auto phy_gpu_addr = get_phy_ptr(ptr); + auto phy_gpu_addr = getDevicePtr(ptr); SUPA_CALL(suMemsetAsync(phy_gpu_addr, val, size, stream)); } @@ -215,8 +265,8 @@ DIPU_API void memSetAsync(const deviceStream_t stream, void* ptr, int val, DIPU_API void memCopyD2D(size_t nbytes, deviceId_t dstDevId, void* dst, deviceId_t srcDevId, const void* src) { // SUPA uses Unified Virtual Address - auto phy_src_gpu_addr = get_phy_ptr(src); - auto phy_dst_gpu_addr = get_phy_ptr(dst); + auto phy_src_gpu_addr = getDevicePtr(src, 0); + auto phy_dst_gpu_addr = getDevicePtr(dst, 1); SUPA_CALL(suMemcpy(phy_dst_gpu_addr, phy_src_gpu_addr, nbytes, suMemcpyDeviceToDevice)); } @@ -224,14 +274,14 @@ DIPU_API void memCopyD2D(size_t nbytes, deviceId_t dstDevId, void* dst, // (synchronous) copy from host to a device DIPU_API void memCopyH2D(size_t nbytes, /*deviceId_t dstDevId,*/ void* dst, /*Host srcDev,*/ const void* src) { - auto phy_dst_gpu_addr = get_phy_ptr(dst); + auto phy_dst_gpu_addr = getDevicePtr(dst); SUPA_CALL(suMemcpy(phy_dst_gpu_addr, src, nbytes,
suMemcpyHostToDevice)); } // (synchronous) copy from a device to host DIPU_API void memCopyD2H(size_t nbytes, /*Host dstDev,*/ void* dst, /*deviceId_t srcDevId,*/ const void* src) { - auto phy_src_gpu_addr = get_phy_ptr(src); + auto phy_src_gpu_addr = getDevicePtr(src); SUPA_CALL(suMemcpy(dst, phy_src_gpu_addr, nbytes, suMemcpyDeviceToHost)); } @@ -239,8 +289,8 @@ DIPU_API void memCopyD2H(size_t nbytes, /*Host dstDev,*/ void* dst, DIPU_API void memCopyD2DAsync(const deviceStream_t stream, size_t nbytes, deviceId_t dstDevId, void* dst, deviceId_t srcDevId, const void* src) { - auto phy_src_gpu_addr = get_phy_ptr(src); - auto phy_dst_gpu_addr = get_phy_ptr(dst); + auto phy_src_gpu_addr = getDevicePtr(src, 0); + auto phy_dst_gpu_addr = getDevicePtr(dst, 1); SUPA_CALL(suMemcpyAsync(phy_dst_gpu_addr, phy_src_gpu_addr, nbytes, stream, suMemcpyDeviceToDevice)); } @@ -249,7 +299,7 @@ DIPU_API void memCopyD2DAsync(const deviceStream_t stream, size_t nbytes, DIPU_API void memCopyH2DAsync(const deviceStream_t stream, size_t nbytes, /*deviceId_t dstDevId,*/ void* dst, /*Host srcDev,*/ const void* src) { - auto phy_dst_gpu_addr = get_phy_ptr(dst); + auto phy_dst_gpu_addr = getDevicePtr(dst); SUPA_CALL(suMemcpyAsync(phy_dst_gpu_addr, src, nbytes, stream, suMemcpyHostToDevice)); } @@ -258,9 +308,16 @@ DIPU_API void memCopyH2DAsync(const deviceStream_t stream, size_t nbytes, DIPU_API void memCopyD2HAsync(const deviceStream_t stream, size_t nbytes, /*Host dstDev,*/ void* dst, /*deviceId_t srcDevId,*/ const void* src) { - auto phy_src_gpu_addr = get_phy_ptr(src); + auto phy_src_gpu_addr = getDevicePtr(src); SUPA_CALL(suMemcpyAsync(dst, phy_src_gpu_addr, nbytes, stream, suMemcpyDeviceToHost)); } } // end namespace devapis + +namespace native::dipu_aten { +void setSupaDeviceCopyInfo(void* ptr, int64_t offset) { + devapis::setSupaDeviceCopyInfo(ptr, offset); +} +} // namespace native::dipu_aten + } // end namespace dipu