Skip to content

Commit

Permalink
change workspace name and add ERROR MACRO
Browse files Browse the repository at this point in the history
  • Loading branch information
Gong-air committed Sep 5, 2024
1 parent 2514828 commit 608ca55
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 46 deletions.
76 changes: 33 additions & 43 deletions dipu/torch_dipu/csrc_dipu/vendor/droplet/pccl.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "pccl.h"
#include <cstddef>
#include <stdexcept>
#include "pcclcommon.h"
#include <c10/core/ScalarType.h>

Expand All @@ -12,31 +13,29 @@


namespace {
// callPcclImpl: looks up a PCCL entry point by name through
// getCommPcclFuncAddr, casts the returned address to a function pointer that
// takes Args..., and calls it with args. The lookup result is stored in a
// function-local static, so it is performed once per template instantiation.
//
// Two revisions are interleaved in this span and share the final brace:
//   - the void revision (template parameter workspaceApi) throws
//     std::runtime_error when the callee's status differs from pcclSuccess;
//   - the ReturnType revision (template parameter PcclFuncName) performs no
//     status check and simply returns the callee's result.
//
// NOTE(review): in both revisions the address is cast and invoked without a
// visible null check — confirm getCommPcclFuncAddr never returns null.
template <const char* workspaceApi, typename... Args>
void callPcclImpl(Args... args) {
static const auto workspaceSizeFuncAddr = getCommPcclFuncAddr(workspaceApi);
using WorkspaceSizeFunc = int (*)(Args...);
static WorkspaceSizeFunc workspaceSizeFunc = reinterpret_cast<WorkspaceSizeFunc>(workspaceSizeFuncAddr);
auto workspaceStatus = workspaceSizeFunc(args...);
if (workspaceStatus != pcclSuccess) {
throw std::runtime_error(
std::string("[") + workspaceApi + "]'s return value is not equal to PCCL_SUCCESS. pcclStatus is " + std::to_string(workspaceStatus)
);
}
template <const char* PcclFuncName,typename ReturnType, typename... Args>
ReturnType callPcclImpl(Args... args) {
static const auto pcclFuncAddr = getCommPcclFuncAddr(PcclFuncName);
using dipuPcclFunc = ReturnType (*)(Args...);
static dipuPcclFunc pcclFunc = reinterpret_cast<dipuPcclFunc>(pcclFuncAddr);
auto pcclCallReturn = pcclFunc(args...);
return pcclCallReturn;
}

// DIPU_PCCL_IMPL expands to three pieces:
//   1. a definition of NAME that forwards its arguments to callPcclImpl,
//      passing the stringized NAME (held in fstr) as the template argument;
//   2. a static initializer (n_NAME) that stores the address of my__NAME in
//      the `fn` map under the key #NAME;
//   3. the signature of my__NAME, left open so that the braces written after
//      the macro invocation become its body.
//
// Two revisions are interleaved in this span: the one without a RETURN
// parameter hard-codes pcclResult_t and returns pcclSuccess after the call;
// the one with RETURN returns whatever callPcclImpl<fstr, RETURN> yields.
//
// DIPU_TYPE_PARAM, DIPU_PARAM and CONCAT are defined outside this view;
// presumably they turn the (type, name) pairs into a parameter list and an
// argument list, and paste tokens — TODO confirm.
#define DIPU_PCCL_IMPL(NAME, ...) \
pcclResult_t NAME(DIPU_TYPE_PARAM(__VA_ARGS__)) { \
#define DIPU_PCCL_IMPL(NAME, RETURN, ...) \
RETURN NAME(DIPU_TYPE_PARAM(__VA_ARGS__)) { \
static constexpr const char fstr[] = #NAME; \
callPcclImpl<fstr>(DIPU_PARAM(__VA_ARGS__)); \
return pcclSuccess; \
return callPcclImpl<fstr, RETURN>(DIPU_PARAM(__VA_ARGS__)); \
} \
static pcclResult_t CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__)); \
static RETURN CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__)); \
static const int CONCAT(n_, NAME) = []() { \
fn[#NAME] = reinterpret_cast<void*>(CONCAT(my__, NAME)); \
return 0; \
}(); \
pcclResult_t CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__))
RETURN CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__))

// Convenience wrappers that fix the RETURN argument of DIPU_PCCL_IMPL:
// COMM for entry points returning pcclResult_t, ERROR for entry points
// returning const char*.
#define DIPU_PCCL_COMM_IMPL(NAME, ...) DIPU_PCCL_IMPL(NAME, pcclResult_t, __VA_ARGS__)
#define DIPU_PCCL_ERROR_IMPL(NAME, ...) DIPU_PCCL_IMPL(NAME, const char*, __VA_ARGS__)

// Registry filled by the static initializers that DIPU_PCCL_IMPL emits:
// maps an API name to the address of its my__<NAME> function, stored as void*.
std::map<std::string, void*> fn;

Expand Down Expand Up @@ -100,11 +99,11 @@ void singleDeviceMemcpy(dipu::deviceStream_t stream, void* dst, const void* src,



// Body registered for pcclGetUniqueId: reports pcclSuccess and does not write
// to uniqueId. (Old and new macro headers are both shown in this span.)
DIPU_PCCL_IMPL(pcclGetUniqueId, (pcclUniqueId*, uniqueId)) {
DIPU_PCCL_COMM_IMPL(pcclGetUniqueId, (pcclUniqueId*, uniqueId)) {
return pcclSuccess;
}

DIPU_PCCL_IMPL(pcclCommInitRank, (pcclComm_t*, comm), (int, ndev), (pcclUniqueId, commIdI), (int, rank)) {
DIPU_PCCL_COMM_IMPL(pcclCommInitRank, (pcclComm_t*, comm), (int, ndev), (pcclUniqueId, commIdI), (int, rank)) {
checkNrankOrThrow(ndev);
checkRankOrThrow(rank);
DIPU_LOGW(
Expand All @@ -114,72 +113,63 @@ void singleDeviceMemcpy(dipu::deviceStream_t stream, void* dst, const void* src,
return pcclSuccess;
}

// Body registered for pcclCommDestroy: runs checkCommOrThrow on comm and then
// reports pcclSuccess. The destroyDiclComm call is commented out, so this body
// itself releases nothing.
DIPU_PCCL_IMPL(pcclCommDestroy, (pcclComm_t, comm)) {
DIPU_PCCL_COMM_IMPL(pcclCommDestroy, (pcclComm_t, comm)) {
checkCommOrThrow(comm);
// destroyDiclComm(comm);
return pcclSuccess;
}

// Body registered for pcclCommGetAsyncError: runs checkCommOrThrow on comm and
// reports pcclSuccess.
// NOTE(review): *asyncError is never written here — confirm callers
// initialize it before reading.
DIPU_PCCL_IMPL(pcclCommGetAsyncError, (pcclComm_t, comm), (pcclResult_t*, asyncError)) {
DIPU_PCCL_COMM_IMPL(pcclCommGetAsyncError, (pcclComm_t, comm), (pcclResult_t*, asyncError)) {
checkCommOrThrow(comm);
return pcclSuccess;
}

// Returns the error string for `result` by calling the function whose address
// getCommPcclFuncAddr reports for "pcclGetErrorString". The address is looked
// up once and kept in a function-local static.
const char* pcclGetErrorString(pcclResult_t result) {
  // Not Fallback
  using ErrorStringFn = const char* (*)(pcclResult_t);
  static const char* symbolName = "pcclGetErrorString";
  static const auto resolvedAddr = getCommPcclFuncAddr(symbolName);
  const auto errorStringFn = reinterpret_cast<ErrorStringFn>(resolvedAddr);
  return errorStringFn(result);
}

// Returns the last-error string for `comm` by calling the function whose
// address getCommPcclFuncAddr reports for "pcclGetLastError". The address is
// looked up once and kept in a function-local static.
const char* pcclGetLastError(pcclComm_t comm) {
  // Not Fallback
  using LastErrorFn = const char* (*)(pcclComm_t);
  static const char* symbolName = "pcclGetLastError";
  static const auto resolvedAddr = getCommPcclFuncAddr(symbolName);
  const auto lastErrorFn = reinterpret_cast<LastErrorFn>(resolvedAddr);
  return lastErrorFn(comm);
}
// Body registered for pcclGetErrorString via DIPU_PCCL_ERROR_IMPL. It has no
// usable implementation here and always throws std::runtime_error.
DIPU_PCCL_ERROR_IMPL(pcclGetErrorString, (pcclResult_t, result)) {
  static constexpr const char kNotCallable[] =
      "Fallback pccl impl should not call pcclGetErrorString";
  throw std::runtime_error(kNotCallable);
}

// Body registered for pcclGetLastError via DIPU_PCCL_ERROR_IMPL. It has no
// usable implementation here and always throws std::runtime_error.
DIPU_PCCL_ERROR_IMPL(pcclGetLastError, (pcclComm_t, comm)) {
  static constexpr const char kNotCallable[] =
      "Fallback pccl impl should not call pcclGetLastError";
  throw std::runtime_error(kNotCallable);
}

// Body registered for pcclReduce: checks comm and root, then copies
// count * elementSize(datatype) bytes from sendbuff to recvbuff on stream via
// singleDeviceMemcpy. The reduction operator `op` is not used in this body.
DIPU_PCCL_IMPL(pcclReduce, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclRedOp_t, op), (int, root), (pcclComm_t, comm), (tangStream_t, stream)) {
DIPU_PCCL_COMM_IMPL(pcclReduce, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclRedOp_t, op), (int, root), (pcclComm_t, comm), (tangStream_t, stream)) {
checkCommOrThrow(comm);
checkRankOrThrow(root);
singleDeviceMemcpy(stream, recvbuff, sendbuff,
count * at::elementSize(PcclDataTypeToScalarType(datatype)));
return pcclSuccess;
}

// Body registered for pcclAllReduce: checks comm, then copies
// count * elementSize(datatype) bytes from sendbuff to recvbuff on stream via
// singleDeviceMemcpy. The reduction operator `op` is not used in this body.
DIPU_PCCL_IMPL(pcclAllReduce, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclRedOp_t, op), (pcclComm_t, comm), (tangStream_t, stream)) {
DIPU_PCCL_COMM_IMPL(pcclAllReduce, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclRedOp_t, op), (pcclComm_t, comm), (tangStream_t, stream)) {
checkCommOrThrow(comm);
singleDeviceMemcpy(stream, recvbuff, sendbuff,
count * at::elementSize(PcclDataTypeToScalarType(datatype)));
return pcclSuccess;
}

// Body registered for pcclReduceScatter: copies
// recvcount * elementSize(datatype) bytes from sendbuff to recvbuff on stream
// via singleDeviceMemcpy. The reduction operator `op` is not used in this body.
// NOTE(review): unlike the sibling collectives, this body does not call
// checkCommOrThrow(comm) — confirm whether the omission is intentional.
DIPU_PCCL_IMPL(pcclReduceScatter, (const void*, sendbuff), (void*, recvbuff), (size_t, recvcount), (pcclDataType_t, datatype), (pcclRedOp_t, op), (pcclComm_t, comm), (tangStream_t, stream)) {
DIPU_PCCL_COMM_IMPL(pcclReduceScatter, (const void*, sendbuff), (void*, recvbuff), (size_t, recvcount), (pcclDataType_t, datatype), (pcclRedOp_t, op), (pcclComm_t, comm), (tangStream_t, stream)) {
singleDeviceMemcpy(stream, recvbuff, sendbuff,
recvcount * at::elementSize(PcclDataTypeToScalarType(datatype)));
return pcclSuccess;
}

// Body registered for pcclBroadcast: checks comm, then copies
// count * elementSize(datatype) bytes from sendbuff to recvbuff on stream via
// singleDeviceMemcpy.
// NOTE(review): `root` is not validated here, whereas pcclReduce calls
// checkRankOrThrow(root) — confirm whether that difference is intended.
DIPU_PCCL_IMPL(pcclBroadcast, (const void *, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (int, root), (pcclComm_t, comm), (tangStream_t, stream)) {
DIPU_PCCL_COMM_IMPL(pcclBroadcast, (const void *, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (int, root), (pcclComm_t, comm), (tangStream_t, stream)) {
checkCommOrThrow(comm);
singleDeviceMemcpy(stream, recvbuff, sendbuff,
count * at::elementSize(PcclDataTypeToScalarType(datatype)));
return pcclSuccess;
}
// Body registered for pcclAllGather: checks comm, then copies
// count * elementSize(datatype) bytes from sendbuff to recvbuff on stream via
// singleDeviceMemcpy.
DIPU_PCCL_IMPL(pcclAllGather, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclComm_t, comm), (tangStream_t, stream)) {
DIPU_PCCL_COMM_IMPL(pcclAllGather, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclComm_t, comm), (tangStream_t, stream)) {
checkCommOrThrow(comm);
singleDeviceMemcpy(stream, recvbuff, sendbuff,
count * at::elementSize(PcclDataTypeToScalarType(datatype)));
return pcclSuccess;
}
// Body registered for pcclSend: calls throwNotSupportedError() and none of the
// arguments are used.
// The trailing return of pcclInvalidUsage is presumably unreachable if
// throwNotSupportedError always throws — TODO confirm against its definition.
DIPU_PCCL_IMPL(pcclSend, (const void*, sendbuff), (size_t, count), (pcclDataType_t, datatype), (int, peer), (pcclComm_t, comm), (tangStream_t, stream)) {
DIPU_PCCL_COMM_IMPL(pcclSend, (const void*, sendbuff), (size_t, count), (pcclDataType_t, datatype), (int, peer), (pcclComm_t, comm), (tangStream_t, stream)) {
throwNotSupportedError();
return pcclInvalidUsage;
}
// Body registered for pcclRecv: calls throwNotSupportedError() and none of the
// arguments are used.
// The trailing return of pcclInvalidUsage is presumably unreachable if
// throwNotSupportedError always throws — TODO confirm against its definition.
DIPU_PCCL_IMPL(pcclRecv, (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (int, peer), (pcclComm_t, comm), (tangStream_t, stream)) {
DIPU_PCCL_COMM_IMPL(pcclRecv, (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (int, peer), (pcclComm_t, comm), (tangStream_t, stream)) {
throwNotSupportedError();
return pcclInvalidUsage;
}
Expand Down
4 changes: 1 addition & 3 deletions dipu/torch_dipu/csrc_dipu/vendor/droplet/vendorapi.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

#include <c10/util/Exception.h>
#include "csrc_dipu/vendor/droplet/pccl.h"
#ifdef USE_PCCL
#include <pccl.h>
#endif // USE_PCCL
#include "pccl.h"
#include <tang_runtime.h>

#include <csrc_dipu/common.h>
Expand Down

0 comments on commit 608ca55

Please sign in to comment.