diff --git a/dipu/torch_dipu/csrc_dipu/vendor/droplet/pccl.cpp b/dipu/torch_dipu/csrc_dipu/vendor/droplet/pccl.cpp index d54829ce0..d5289ff25 100644 --- a/dipu/torch_dipu/csrc_dipu/vendor/droplet/pccl.cpp +++ b/dipu/torch_dipu/csrc_dipu/vendor/droplet/pccl.cpp @@ -1,5 +1,6 @@ #include "pccl.h" #include +#include #include "pcclcommon.h" #include @@ -12,31 +13,29 @@ namespace { -template -void callPcclImpl(Args... args) { - static const auto workspaceSizeFuncAddr = getCommPcclFuncAddr(workspaceApi); - using WorkspaceSizeFunc = int (*)(Args...); - static WorkspaceSizeFunc workspaceSizeFunc = reinterpret_cast(workspaceSizeFuncAddr); - auto workspaceStatus = workspaceSizeFunc(args...); - if (workspaceStatus != pcclSuccess) { - throw std::runtime_error( - std::string("[") + workspaceApi + "]'s return value is not equal to PCCL_SUCCESS. pcclStatus is " + std::to_string(workspaceStatus) - ); - } +template +ReturnType callPcclImpl(Args... args) { + static const auto pcclFuncAddr = getCommPcclFuncAddr(PcclFuncName); + using dipuPcclFunc = ReturnType (*)(Args...); + static dipuPcclFunc pcclFunc = reinterpret_cast(pcclFuncAddr); + auto pcclCallReturn = pcclFunc(args...); + return pcclCallReturn; } -#define DIPU_PCCL_IMPL(NAME, ...) \ - pcclResult_t NAME(DIPU_TYPE_PARAM(__VA_ARGS__)) { \ +#define DIPU_PCCL_IMPL(NAME, RETURN, ...) \ + RETURN NAME(DIPU_TYPE_PARAM(__VA_ARGS__)) { \ static constexpr const char fstr[] = #NAME; \ - callPcclImpl(DIPU_PARAM(__VA_ARGS__)); \ - return pcclSuccess; \ + return callPcclImpl(DIPU_PARAM(__VA_ARGS__)); \ } \ - static pcclResult_t CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__)); \ + static RETURN CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__)); \ static const int CONCAT(n_, NAME) = []() { \ fn[#NAME] = reinterpret_cast(CONCAT(my__, NAME)); \ return 0; \ }(); \ - pcclResult_t CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__)) + RETURN CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__)) + +#define DIPU_PCCL_COMM_IMPL(NAME, ...) DIPU_PCCL_IMPL(NAME, pcclResult_t, __VA_ARGS__) +#define DIPU_PCCL_ERROR_IMPL(NAME, ...) DIPU_PCCL_IMPL(NAME, const char*, __VA_ARGS__) std::map fn; @@ -100,11 +99,11 @@ void singleDeviceMemcpy(dipu::deviceStream_t stream, void* dst, const void* src, - DIPU_PCCL_IMPL(pcclGetUniqueId, (pcclUniqueId*, uniqueId)) { + DIPU_PCCL_COMM_IMPL(pcclGetUniqueId, (pcclUniqueId*, uniqueId)) { return pcclSuccess; } - DIPU_PCCL_IMPL(pcclCommInitRank, (pcclComm_t*, comm), (int, ndev), (pcclUniqueId, commIdI), (int, rank)) { + DIPU_PCCL_COMM_IMPL(pcclCommInitRank, (pcclComm_t*, comm), (int, ndev), (pcclUniqueId, commIdI), (int, rank)) { checkNrankOrThrow(ndev); checkRankOrThrow(rank); DIPU_LOGW( @@ -114,35 +113,26 @@ void singleDeviceMemcpy(dipu::deviceStream_t stream, void* dst, const void* src, return pcclSuccess; } - DIPU_PCCL_IMPL(pcclCommDestroy, (pcclComm_t, comm)) { + DIPU_PCCL_COMM_IMPL(pcclCommDestroy, (pcclComm_t, comm)) { checkCommOrThrow(comm); // destroyDiclComm(comm); return pcclSuccess; } - DIPU_PCCL_IMPL(pcclCommGetAsyncError, (pcclComm_t, comm), (pcclResult_t*, asyncError)) { + DIPU_PCCL_COMM_IMPL(pcclCommGetAsyncError, (pcclComm_t, comm), (pcclResult_t*, asyncError)) { checkCommOrThrow(comm); return pcclSuccess; } -const char* pcclGetErrorString(pcclResult_t result){ - // Not Fallback - static const char* apiName = "pcclGetErrorString"; - static const auto funcptr = getCommPcclFuncAddr(apiName); - using func = const char*(*)(pcclResult_t); - return reinterpret_cast(funcptr)(result); -} - -const char* pcclGetLastError(pcclComm_t comm){ - // Not Fallback - static const char* apiName = "pcclGetLastError"; - static const auto funcptr = getCommPcclFuncAddr(apiName); - using func = const char*(*)(pcclComm_t); - return reinterpret_cast(funcptr)(comm); -} + DIPU_PCCL_ERROR_IMPL(pcclGetErrorString, (pcclResult_t, result)){ + throw std::runtime_error("Fallback pccl impl should not call pcclGetErrorString"); + } + DIPU_PCCL_ERROR_IMPL(pcclGetLastError, (pcclComm_t, comm)){ + throw std::runtime_error("Fallback pccl impl should not call pcclGetLastError"); + } - DIPU_PCCL_IMPL(pcclReduce, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclRedOp_t, op), (int, root), (pcclComm_t, comm), (tangStream_t, stream)) { + DIPU_PCCL_COMM_IMPL(pcclReduce, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclRedOp_t, op), (int, root), (pcclComm_t, comm), (tangStream_t, stream)) { checkCommOrThrow(comm); checkRankOrThrow(root); singleDeviceMemcpy(stream, recvbuff, sendbuff, @@ -150,36 +140,36 @@ const char* pcclGetLastError(pcclComm_t comm){ return pcclSuccess; } - DIPU_PCCL_IMPL(pcclAllReduce, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclRedOp_t, op), (pcclComm_t, comm), (tangStream_t, stream)) { + DIPU_PCCL_COMM_IMPL(pcclAllReduce, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclRedOp_t, op), (pcclComm_t, comm), (tangStream_t, stream)) { checkCommOrThrow(comm); singleDeviceMemcpy(stream, recvbuff, sendbuff, count * at::elementSize(PcclDataTypeToScalarType(datatype))); return pcclSuccess; } - DIPU_PCCL_IMPL(pcclReduceScatter, (const void*, sendbuff), (void*, recvbuff), (size_t, recvcount), (pcclDataType_t, datatype), (pcclRedOp_t, op), (pcclComm_t, comm), (tangStream_t, stream)) { + DIPU_PCCL_COMM_IMPL(pcclReduceScatter, (const void*, sendbuff), (void*, recvbuff), (size_t, recvcount), (pcclDataType_t, datatype), (pcclRedOp_t, op), (pcclComm_t, comm), (tangStream_t, stream)) { singleDeviceMemcpy(stream, recvbuff, sendbuff, recvcount * at::elementSize(PcclDataTypeToScalarType(datatype))); return pcclSuccess; } - DIPU_PCCL_IMPL(pcclBroadcast, (const void *, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (int, root), (pcclComm_t, comm), (tangStream_t, stream)) { + DIPU_PCCL_COMM_IMPL(pcclBroadcast, (const void *, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (int, root), (pcclComm_t, comm), (tangStream_t, stream)) { checkCommOrThrow(comm); singleDeviceMemcpy(stream, recvbuff, sendbuff, count * at::elementSize(PcclDataTypeToScalarType(datatype))); return pcclSuccess; } - DIPU_PCCL_IMPL(pcclAllGather, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclComm_t, comm), (tangStream_t, stream)) { + DIPU_PCCL_COMM_IMPL(pcclAllGather, (const void*, sendbuff), (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (pcclComm_t, comm), (tangStream_t, stream)) { checkCommOrThrow(comm); singleDeviceMemcpy(stream, recvbuff, sendbuff, count * at::elementSize(PcclDataTypeToScalarType(datatype))); return pcclSuccess; } - DIPU_PCCL_IMPL(pcclSend, (const void*, sendbuff), (size_t, count), (pcclDataType_t, datatype), (int, peer), (pcclComm_t, comm), (tangStream_t, stream)) { + DIPU_PCCL_COMM_IMPL(pcclSend, (const void*, sendbuff), (size_t, count), (pcclDataType_t, datatype), (int, peer), (pcclComm_t, comm), (tangStream_t, stream)) { throwNotSupportedError(); return pcclInvalidUsage; } - DIPU_PCCL_IMPL(pcclRecv, (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (int, peer), (pcclComm_t, comm), (tangStream_t, stream)) { + DIPU_PCCL_COMM_IMPL(pcclRecv, (void*, recvbuff), (size_t, count), (pcclDataType_t, datatype), (int, peer), (pcclComm_t, comm), (tangStream_t, stream)) { throwNotSupportedError(); return pcclInvalidUsage; } diff --git a/dipu/torch_dipu/csrc_dipu/vendor/droplet/vendorapi.h b/dipu/torch_dipu/csrc_dipu/vendor/droplet/vendorapi.h index 6aa4d82ab..680c204a2 100644 --- a/dipu/torch_dipu/csrc_dipu/vendor/droplet/vendorapi.h +++ b/dipu/torch_dipu/csrc_dipu/vendor/droplet/vendorapi.h @@ -3,9 +3,7 @@ #include #include "csrc_dipu/vendor/droplet/pccl.h" -#ifdef USE_PCCL -#include -#endif // USE_PCCL +#include "pccl.h" #include #include