Skip to content

Commit

Permalink
[SUPA] fix crash during copy (#924)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaron20000101 authored Aug 2, 2024
1 parent ae11430 commit 9a1f901
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 9 deletions.
26 changes: 26 additions & 0 deletions dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <algorithm>

#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>

#include "csrc_dipu/aten/DIPUATenFunctions.h"
Expand Down Expand Up @@ -36,8 +37,30 @@ void setDipuCopyInstance(DIPUCopyBase* op) { dipu_copy_op() = op; }
namespace dipu {
namespace native {
namespace dipu_aten {

at::Scalar _local_scalar_dense_dipu(const at::Tensor& self) {
at::Scalar r;
#if DIPU_VENDOR_NAME_SUPA
extern void setSupaDeviceCopyInfo(void* ptr, int64_t offset);
if (self.scalar_type() == c10::ScalarType::Bool) {
// on SUPA, bool type is represents by float.
using scalar_t = c10::impl::ScalarTypeToCPPTypeT<at::kFloat>;
float value;
dipu::DIPUStream stream = dipu::getCurrentDIPUStream();
MemChecker::instance().check(self);
/* on SUPA, it can't plus offset to ptr since it is virtual address.
must set offset in advance and recaculate ptr after translating it to
physical address.
*/
setSupaDeviceCopyInfo(self.storage().data(),
self.storage_offset() * sizeof(scalar_t));
dipu::devproxy::memCopyD2HAsync(stream.rawstream(), sizeof(scalar_t),
&value, self.data_ptr());
dipu::devproxy::syncStream(stream.rawstream());
r = at::Scalar((std::abs(value) >= 1e-6f));
return r;
}
#endif
AT_DISPATCH_ALL_TYPES_AND3(
at::kHalf, at::kBool, at::kBFloat16, self.scalar_type(),
"_local_scalar_dense_dipu", [&] {
Expand All @@ -49,6 +72,9 @@ at::Scalar _local_scalar_dense_dipu(const at::Tensor& self) {
dipu::devproxy::memCopyD2H(sizeof(scalar_t), &value,
self.data_ptr<scalar_t>());
#else
#if DIPU_VENDOR_NAME_SUPA
setSupaDeviceCopyInfo(self.storage().data(), self.storage_offset()*sizeof(scalar_t));
#endif
dipu::devproxy::memCopyD2HAsync(stream.rawstream(), sizeof(scalar_t),
&value, self.data_ptr<scalar_t>());
dipu::devproxy::syncStream(stream.rawstream());
Expand Down
75 changes: 66 additions & 9 deletions dipu/torch_dipu/csrc_dipu/vendor/supa/deviceimpl.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <cstdint>

#include <supa.h>

#include <csrc_dipu/runtime/device/deviceapis.h>
Expand Down Expand Up @@ -187,6 +189,27 @@ void br_device_free(void* ptr);
void* get_phy_ptr(const void* ptr);
}

typedef struct _AddrInfo {
int64_t offset = 0;
void* ptr = nullptr;
void set(void* _ptr, int64_t _offset) {
ptr = _ptr;
offset = _offset;
}
void* get(const void* _ptr) {
if (ptr != nullptr) {
auto p = ptr;
ptr = nullptr;
return (uint8_t*)get_phy_ptr(p) + offset;
}
return get_phy_ptr(_ptr);
}
} AddrInfo;

namespace {
thread_local AddrInfo addr0, addr1;
} // namespace

DIPU_API OpStatus mallocDevice(void** p, size_t nbytes, bool throwExcepion) {
void* ptr = nullptr;
ptr = br_device_malloc(nbytes);
Expand All @@ -200,47 +223,74 @@ DIPU_API OpStatus mallocDevice(void** p, size_t nbytes, bool throwExcepion) {
return OpStatus::ERR_NOMEM;
}

/**
* @brief Set the Supa Device Copy Info into thread local object.
*
* @param ptr base virtual addr.
* @param offset offset (in bytes) plus to base addr.
*/
inline void setSupaDeviceCopyInfo(void* ptr, int64_t offset, int index = 0) {
if (index == 0) {
addr0.set(ptr, offset);
} else {
addr1.set(ptr, offset);
}
}

/**
* @brief Get the Device Ptr from virtual address and optional offset.
*
* @param src virtual address. if addr has valid data, value of src is ignored.
* @return void* physical address
*/
inline void* getDevicePtr(const void* src, int index = 0) {
if (index == 0) {
return addr0.get(src);
}
return addr1.get(src);
}

DIPU_API void freeDevice(void* p) { br_device_free(p); }

DIPU_API bool isPinnedPtr(const void* p) { return false; }

// (asynchronous) set val
DIPU_API void memSetAsync(const deviceStream_t stream, void* ptr, int val,
size_t size) {
auto phy_gpu_addr = get_phy_ptr(ptr);
auto phy_gpu_addr = getDevicePtr(ptr);
SUPA_CALL(suMemsetAsync(phy_gpu_addr, val, size, stream));
}

// (synchronous) copy from device to a device
DIPU_API void memCopyD2D(size_t nbytes, deviceId_t dstDevId, void* dst,
deviceId_t srcDevId, const void* src) {
// SUPA uses Unified Virtual Address
auto phy_src_gpu_addr = get_phy_ptr(src);
auto phy_dst_gpu_addr = get_phy_ptr(dst);
auto phy_src_gpu_addr = getDevicePtr(src, 0);
auto phy_dst_gpu_addr = getDevicePtr(dst, 1);
SUPA_CALL(suMemcpy(phy_dst_gpu_addr, phy_src_gpu_addr, nbytes,
suMemcpyDeviceToDevice));
}

// (synchronous) copy from host to a device
DIPU_API void memCopyH2D(size_t nbytes, /*deviceId_t dstDevId,*/ void* dst,
/*Host srcDev,*/ const void* src) {
auto phy_dst_gpu_addr = get_phy_ptr(dst);
auto phy_dst_gpu_addr = getDevicePtr(dst);
SUPA_CALL(suMemcpy(phy_dst_gpu_addr, src, nbytes, suMemcpyHostToDevice));
}

// (synchronous) copy from a device to host
DIPU_API void memCopyD2H(size_t nbytes, /*Host dstDev,*/ void* dst,
/*deviceId_t srcDevId,*/ const void* src) {
auto phy_src_gpu_addr = get_phy_ptr(src);
auto phy_src_gpu_addr = getDevicePtr(src);
SUPA_CALL(suMemcpy(dst, phy_src_gpu_addr, nbytes, suMemcpyDeviceToHost));
}

// (asynchronous) copy from device to a device
DIPU_API void memCopyD2DAsync(const deviceStream_t stream, size_t nbytes,
deviceId_t dstDevId, void* dst,
deviceId_t srcDevId, const void* src) {
auto phy_src_gpu_addr = get_phy_ptr(src);
auto phy_dst_gpu_addr = get_phy_ptr(dst);
auto phy_src_gpu_addr = getDevicePtr(src, 0);
auto phy_dst_gpu_addr = getDevicePtr(dst, 1);
SUPA_CALL(suMemcpyAsync(phy_dst_gpu_addr, phy_src_gpu_addr, nbytes, stream,
suMemcpyDeviceToDevice));
}
Expand All @@ -249,7 +299,7 @@ DIPU_API void memCopyD2DAsync(const deviceStream_t stream, size_t nbytes,
DIPU_API void memCopyH2DAsync(const deviceStream_t stream, size_t nbytes,
/*deviceId_t dstDevId,*/ void* dst,
/*Host srcDev,*/ const void* src) {
auto phy_dst_gpu_addr = get_phy_ptr(dst);
auto phy_dst_gpu_addr = getDevicePtr(dst);
SUPA_CALL(suMemcpyAsync(phy_dst_gpu_addr, src, nbytes, stream,
suMemcpyHostToDevice));
}
Expand All @@ -258,9 +308,16 @@ DIPU_API void memCopyH2DAsync(const deviceStream_t stream, size_t nbytes,
DIPU_API void memCopyD2HAsync(const deviceStream_t stream, size_t nbytes,
/*Host dstDev,*/ void* dst,
/*deviceId_t srcDevId,*/ const void* src) {
auto phy_src_gpu_addr = get_phy_ptr(src);
auto phy_src_gpu_addr = getDevicePtr(src);
SUPA_CALL(suMemcpyAsync(dst, phy_src_gpu_addr, nbytes, stream,
suMemcpyDeviceToHost));
}
} // end namespace devapis

namespace native::dipu_aten {
void setSupaDeviceCopyInfo(void* ptr, int64_t offset) {
devapis::setSupaDeviceCopyInfo(ptr, offset);
}
} // namespace native::dipu_aten

} // end namespace dipu

0 comments on commit 9a1f901

Please sign in to comment.