Skip to content

Commit

Permalink
and file comment& add tang_shared_so load & clang format
Browse files Browse the repository at this point in the history
  • Loading branch information
Gong-air committed Sep 9, 2024
1 parent 608ca55 commit 6e8d69c
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
#include <string>
#include <type_traits>

#include "pccl.h"

#include <c10/core/ScalarType.h>
#include <torch/csrc/distributed/c10d/Types.hpp>

#include "csrc_dipu/runtime/device/basedef.h"
#include "csrc_dipu/runtime/devproxy/deviceproxy.h"
#include <csrc_dipu/common.h>
#include <csrc_dipu/runtime/device/diclapis.h>

#include "vendorapi.h"
#include "pccl.h"

namespace dipu {

Expand Down
237 changes: 137 additions & 100 deletions dipu/torch_dipu/csrc_dipu/vendor/droplet/pccl.cpp
Original file line number Diff line number Diff line change
@@ -1,45 +1,59 @@
/**
* pccl.cpp
*
* Description:
* This file implements the dynamic loading and invocation of the PCCL APIs
* required by DICL. If the libpccl.so library cannot be found, a log message
* is printed and a fallback implementation of the API is executed instead.
*
* Notes:
* - The PCCL header file has been copied into this project. If the upstream
* PCCL header is updated, please update the copy here accordingly.
*/
#include "pccl.h"

#include <cstddef>
#include <stdexcept>

#include "pcclcommon.h"

#include <c10/core/ScalarType.h>
#include <torch/csrc/distributed/c10d/Types.hpp>

#include "csrc_dipu/runtime/device/basedef.h"
#include "csrc_dipu/runtime/devproxy/deviceproxy.h"
#include <torch/csrc/distributed/c10d/Types.hpp>

#include <csrc_dipu/common.h>
#include <csrc_dipu/runtime/device/diclapis.h>


namespace {
template <const char* PcclFuncName,typename ReturnType, typename... Args>
template <const char* PcclFuncName, typename ReturnType, typename... Args>
ReturnType callPcclImpl(Args... args) {
static const auto pcclFuncAddr = getCommPcclFuncAddr(PcclFuncName);
using dipuPcclFunc = ReturnType (*)(Args...);
static dipuPcclFunc pcclFunc = reinterpret_cast<dipuPcclFunc>(pcclFuncAddr);
auto pcclCallReturn = pcclFunc(args...);
return pcclCallReturn;
}

#define DIPU_PCCL_IMPL(NAME, RETURN, ...) \
RETURN NAME(DIPU_TYPE_PARAM(__VA_ARGS__)) { \
static constexpr const char fstr[] = #NAME; \
return callPcclImpl<fstr, RETURN>(DIPU_PARAM(__VA_ARGS__)); \
} \
static const auto pcclFuncAddr = getCommPcclFuncAddr(PcclFuncName);
using dipuPcclFunc = ReturnType (*)(Args...);
static dipuPcclFunc pcclFunc = reinterpret_cast<dipuPcclFunc>(pcclFuncAddr);
auto pcclCallReturn = pcclFunc(args...);
return pcclCallReturn;
}

#define DIPU_PCCL_IMPL(NAME, RETURN, ...) \
RETURN NAME(DIPU_TYPE_PARAM(__VA_ARGS__)) { \
static constexpr const char fstr[] = #NAME; \
return callPcclImpl<fstr, RETURN>(DIPU_PARAM(__VA_ARGS__)); \
} \
static RETURN CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__)); \
static const int CONCAT(n_, NAME) = []() { \
fn[#NAME] = reinterpret_cast<void*>(CONCAT(my__, NAME)); \
return 0; \
}(); \
static const int CONCAT(n_, NAME) = []() { \
fn[#NAME] = reinterpret_cast<void*>(CONCAT(my__, NAME)); \
return 0; \
}(); \
RETURN CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__))

#define DIPU_PCCL_COMM_IMPL(NAME, ...) DIPU_PCCL_IMPL(NAME, pcclResult_t, __VA_ARGS__)
#define DIPU_PCCL_ERROR_IMPL(NAME, ...) DIPU_PCCL_IMPL(NAME, const char*, __VA_ARGS__)
#define DIPU_PCCL_COMM_IMPL(NAME, ...) \
DIPU_PCCL_IMPL(NAME, pcclResult_t, __VA_ARGS__)
#define DIPU_PCCL_ERROR_IMPL(NAME, ...) \
DIPU_PCCL_IMPL(NAME, const char*, __VA_ARGS__)

std::map<std::string, void*> fn;


static const std::map<pcclDataType_t, at::ScalarType> toScalarType = {
{pcclInt8, at::kChar},
{pcclUint8, at::kByte},
Expand All @@ -54,17 +68,18 @@ static const std::map<pcclDataType_t, at::ScalarType> toScalarType = {
};

/**
 * Translates a PCCL element type into the corresponding ATen scalar type by
 * looking it up in the `toScalarType` table.
 *
 * @param pccl_data_type  PCCL data type to translate.
 * @return The mapped at::ScalarType.
 * @throws std::runtime_error if the type has no entry in `toScalarType`.
 */
at::ScalarType PcclDataTypeToScalarType(pcclDataType_t pccl_data_type) {
  const auto entry = toScalarType.find(pccl_data_type);
  if (entry == toScalarType.end()) {
    throw std::runtime_error("Not supported pcclDataType_t: " +
                             std::to_string(pccl_data_type));
  }
  return entry->second;
}

// Sentinel communicator handle handed out by the fallback pcclCommInitRank.
// The constant's bytes read "PCCL" in ASCII (0x50 0x43 0x43 0x4C).
static const pcclComm_t kMagicComm = reinterpret_cast<pcclComm_t>(0x5043434C);

/**
 * Validates a communicator handle passed to a fallback API.
 *
 * @param comm  Handle to validate.
 * @throws std::runtime_error if `comm` is not the fallback sentinel.
 */
void checkCommOrThrow(pcclComm_t comm) {
  // kMagicComm is non-null, so this one comparison also rejects nullptr.
  if (comm != kMagicComm) {
    throw std::runtime_error("Invalid comm.");
  }
}
Expand All @@ -88,89 +103,111 @@ void checkRankOrThrow(int rank) {

/**
 * Enqueues a device-to-device copy of `nbytes` bytes from `src` to `dst` on
 * `stream`, using the current device as both source and destination device.
 * Nothing is enqueued when `dst` and `src` are the same address.
 *
 * @param stream  Stream the asynchronous copy is issued on.
 * @param dst     Destination address.
 * @param src     Source address.
 * @param nbytes  Number of bytes to copy.
 */
void singleDeviceMemcpy(dipu::deviceStream_t stream, void* dst, const void* src,
                        size_t nbytes) {
  if (dst == src) {
    return;  // in-place call: the data is already where it belongs
  }
  const auto device = dipu::devproxy::current_device();
  dipu::devproxy::memCopyD2DAsync(stream, nbytes, device, dst, device, src);
}

} // namespace

// Fallback pcclGetUniqueId: reports success and leaves `uniqueId` untouched.
DIPU_PCCL_COMM_IMPL(pcclGetUniqueId, (pcclUniqueId*, uniqueId)) {
  return pcclSuccess;
}

// Fallback pcclCommInitRank: validates `ndev` and `rank`, logs that PCCL is
// unavailable, and hands back the sentinel communicator through `comm`.
DIPU_PCCL_COMM_IMPL(pcclCommInitRank, (pcclComm_t*, comm), (int, ndev),
                    (pcclUniqueId, commIdI), (int, rank)) {
  checkNrankOrThrow(ndev);
  checkRankOrThrow(rank);
  DIPU_LOGW(
      "PCCL is not enabled. DIPU will simulate single GPU "
      "communication using memcpy.");
  *comm = kMagicComm;
  return pcclSuccess;
}

// Fallback pcclCommDestroy: only validates the handle. The sentinel
// communicator is a constant, so this function releases nothing.
DIPU_PCCL_COMM_IMPL(pcclCommDestroy, (pcclComm_t, comm)) {
  checkCommOrThrow(comm);
  // destroyDiclComm(comm);
  return pcclSuccess;
}

// Fallback pcclCommGetAsyncError: validates the handle and reports that no
// asynchronous error is pending.
DIPU_PCCL_COMM_IMPL(pcclCommGetAsyncError, (pcclComm_t, comm),
                    (pcclResult_t*, asyncError)) {
  checkCommOrThrow(comm);
  // Write the out-parameter so the caller never reads an unset value.
  if (asyncError != nullptr) {
    *asyncError = pcclSuccess;
  }
  return pcclSuccess;
}

// Fallback pcclGetErrorString: always throws std::runtime_error.
DIPU_PCCL_ERROR_IMPL(pcclGetErrorString, (pcclResult_t, result)) {
  throw std::runtime_error(
      "Fallback pccl impl should not call pcclGetErrorString");
}

// Fallback pcclGetLastError: always throws std::runtime_error.
DIPU_PCCL_ERROR_IMPL(pcclGetLastError, (pcclComm_t, comm)) {
  throw std::runtime_error(
      "Fallback pccl impl should not call pcclGetLastError");
}

// Fallback pcclReduce: after validating `comm` and `root`, copies `count`
// elements of `datatype` from `sendbuff` to `recvbuff` on `stream`. The
// reduction operator `op` is not used.
DIPU_PCCL_COMM_IMPL(pcclReduce, (const void*, sendbuff), (void*, recvbuff),
                    (size_t, count), (pcclDataType_t, datatype),
                    (pcclRedOp_t, op), (int, root), (pcclComm_t, comm),
                    (tangStream_t, stream)) {
  checkCommOrThrow(comm);
  checkRankOrThrow(root);
  singleDeviceMemcpy(
      stream, recvbuff, sendbuff,
      count * at::elementSize(PcclDataTypeToScalarType(datatype)));
  return pcclSuccess;
}

// Fallback pcclReduceScatter: after validating `comm`, copies `recvcount`
// elements of `datatype` from `sendbuff` to `recvbuff` on `stream`. The
// reduction operator `op` is not used.
DIPU_PCCL_COMM_IMPL(pcclReduceScatter, (const void*, sendbuff),
                    (void*, recvbuff), (size_t, recvcount),
                    (pcclDataType_t, datatype), (pcclRedOp_t, op),
                    (pcclComm_t, comm), (tangStream_t, stream)) {
  // Validate the handle like every other fallback collective does.
  checkCommOrThrow(comm);
  singleDeviceMemcpy(
      stream, recvbuff, sendbuff,
      recvcount * at::elementSize(PcclDataTypeToScalarType(datatype)));
  return pcclSuccess;
}

// Fallback pcclAllReduce: after validating `comm`, copies `count` elements of
// `datatype` from `sendbuff` to `recvbuff` on `stream`. `op` is not used.
DIPU_PCCL_COMM_IMPL(pcclAllReduce, (const void*, sendbuff), (void*, recvbuff),
                    (size_t, count), (pcclDataType_t, datatype),
                    (pcclRedOp_t, op), (pcclComm_t, comm),
                    (tangStream_t, stream)) {
  checkCommOrThrow(comm);
  singleDeviceMemcpy(
      stream, recvbuff, sendbuff,
      count * at::elementSize(PcclDataTypeToScalarType(datatype)));
  return pcclSuccess;
}

// Fallback pcclBroadcast: after validating `comm`, copies `count` elements of
// `datatype` from `sendbuff` to `recvbuff` on `stream`. `root` is not used.
DIPU_PCCL_COMM_IMPL(pcclBroadcast, (const void*, sendbuff), (void*, recvbuff),
                    (size_t, count), (pcclDataType_t, datatype), (int, root),
                    (pcclComm_t, comm), (tangStream_t, stream)) {
  checkCommOrThrow(comm);
  singleDeviceMemcpy(
      stream, recvbuff, sendbuff,
      count * at::elementSize(PcclDataTypeToScalarType(datatype)));
  return pcclSuccess;
}

// Fallback pcclAllGather: after validating `comm`, copies `count` elements of
// `datatype` from `sendbuff` to `recvbuff` on `stream`.
DIPU_PCCL_COMM_IMPL(pcclAllGather, (const void*, sendbuff), (void*, recvbuff),
                    (size_t, count), (pcclDataType_t, datatype),
                    (pcclComm_t, comm), (tangStream_t, stream)) {
  checkCommOrThrow(comm);
  singleDeviceMemcpy(
      stream, recvbuff, sendbuff,
      count * at::elementSize(PcclDataTypeToScalarType(datatype)));
  return pcclSuccess;
}

// Fallback pcclSend: calls throwNotSupportedError(); the return statement
// after it supplies a value should that call return.
DIPU_PCCL_COMM_IMPL(pcclSend, (const void*, sendbuff), (size_t, count),
                    (pcclDataType_t, datatype), (int, peer), (pcclComm_t, comm),
                    (tangStream_t, stream)) {
  throwNotSupportedError();
  return pcclInvalidUsage;
}

// Fallback pcclRecv: calls throwNotSupportedError(); the return statement
// after it supplies a value should that call return.
DIPU_PCCL_COMM_IMPL(pcclRecv, (void*, recvbuff), (size_t, count),
                    (pcclDataType_t, datatype), (int, peer), (pcclComm_t, comm),
                    (tangStream_t, stream)) {
  throwNotSupportedError();
  return pcclInvalidUsage;
}
18 changes: 14 additions & 4 deletions dipu/torch_dipu/csrc_dipu/vendor/droplet/pcclcommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,22 @@ inline void* getCommPcclLibHandler(const char* libName) {
}

/**
 * Looks up the address of the PCCL API `apiName` in libpccl.so.
 *
 * libtangrt_shared.so is opened first with RTLD_GLOBAL, and only then is
 * libpccl.so opened through getCommPcclLibHandler(). Both handles are
 * obtained on the first call and cached in function-local statics.
 *
 * @param apiName  Symbol name to resolve.
 * @return The result of getCommPcclFuncAddrInLib() for `apiName`, or nullptr
 *         (after printing a message to stderr) when libpccl.so could not be
 *         opened.
 * @throws std::runtime_error if libtangrt_shared.so cannot be opened.
 *
 * NOTE(review): a missing libtangrt_shared.so throws before the nullptr
 * "fallback" path can be taken -- confirm this is intended for hosts that
 * have neither library installed.
 */
inline void* getCommPcclFuncAddr(const char* apiName) {
  constexpr const char pcclLibName[] = "libpccl.so";
  constexpr const char pcclLibDependName[] = "libtangrt_shared.so";

  static void* pcclDependHandler =
      dlopen(pcclLibDependName, RTLD_LAZY | RTLD_GLOBAL);
  if (pcclDependHandler == nullptr) {
    throw std::runtime_error(
        "Error: Failed to load libpccl.so. The required library 'libtangrt_shared.so' is missing.\n"
        "Please ensure that 'libtangrt_shared.so' is installed and its path is included in the LD_LIBRARY_PATH environment variable.\n"
        "Example: export LD_LIBRARY_PATH=/path/to/lib:$LD_LIBRARY_PATH");
  }

  static void* pcclHandler = getCommPcclLibHandler(pcclLibName);
  if (pcclHandler == nullptr) {
    std::cerr << "Fallback " << apiName << " will be called" << std::endl;
    return nullptr;
  }
  return getCommPcclFuncAddrInLib(pcclHandler, pcclLibName, apiName);
}

#define EXPAND(x) x
Expand Down

0 comments on commit 6e8d69c

Please sign in to comment.