From 2ab6d076c21cf8ff8c66ac17a1a8fa542a7bb5c9 Mon Sep 17 00:00:00 2001
From: Devendar Bureddy
Date: Mon, 15 Jan 2024 09:13:11 -0800
Subject: [PATCH] TL/SHARP: Add reduce-scatter support (#891)

---

Summary: wires UCC_COLL_TYPE_REDUCE_SCATTER into the SHARP TL. configure
probes for the SHARP reduce-scatter API; when HAVE_DECL_SHARP_COLL_DO_REDUCE_SCATTER
is set, the collective is added to UCC_TL_SHARP_SUPPORTED_COLLS, dispatched from
ucc_tl_sharp_coll_init() and posted through sharp_coll_do_reduce_scatter_nb().

Review notes (to be confirmed against code outside this patch):
 - NOTE(review): the results of both ucc_tl_sharp_mem_register() calls are
   unused and the handles are dereferenced (->mr) right after; confirm that
   registration cannot fail or leave the handle unset.
 - NOTE(review): in the in-place path the send descriptor uses r_mem_h, which
   is registered with dst_data_size, while sbuf_desc.buffer.length is
   src_data_size; confirm the registered range covers the send buffer.
 - NOTE(review): configure probes sharp_coll_do_reduce_scatter, the code calls
   sharp_coll_do_reduce_scatter_nb; presumably both ship together - verify.
 - NOTE(review): HPCX_LINK now points at a cuda12 tarball while CUDA_VER stays
   11-4; confirm this combination is intended.

 .github/workflows/clang-tidy-nvidia.yaml |  6 +-
 config/m4/sharp.m4                       |  1 +
 src/components/tl/sharp/tl_sharp.h       | 13 +++-
 src/components/tl/sharp/tl_sharp_coll.c  | 94 ++++++++++++++++++++++++
 src/components/tl/sharp/tl_sharp_coll.h  |  3 +
 src/components/tl/sharp/tl_sharp_team.c  |  5 ++
 6 files changed, 118 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/clang-tidy-nvidia.yaml b/.github/workflows/clang-tidy-nvidia.yaml
index 3609a0a7a1..ae2cde7580 100644
--- a/.github/workflows/clang-tidy-nvidia.yaml
+++ b/.github/workflows/clang-tidy-nvidia.yaml
@@ -5,7 +5,7 @@ on: [push, pull_request]
 env:
   OPEN_UCX_LINK: https://github.com/openucx/ucx
   OPEN_UCX_BRANCH: master
-  HPCX_LINK: http://content.mellanox.com/hpc/hpc-x/v2.13/hpcx-v2.13-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl2.12-x86_64.tbz
+  HPCX_LINK: https://content.mellanox.com/hpc/hpc-x/v2.17.1rc2/hpcx-v2.17.1-gcc-mlnx_ofed-ubuntu20.04-cuda12-x86_64.tbz
   CLANG_VER: 12
   MLNX_OFED_VER: 5.9-0.5.6.0
   CUDA_VER: 11-4
@@ -45,8 +45,8 @@ jobs:
       run: |
         cd /tmp
         wget ${HPCX_LINK}
-        tar xjf hpcx-v2.13-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl2.12-x86_64.tbz
-        mv hpcx-v2.13-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl2.12-x86_64 hpcx
+        tar xjf hpcx-v2.17.1-gcc-mlnx_ofed-ubuntu20.04-cuda12-x86_64.tbz
+        mv hpcx-v2.17.1-gcc-mlnx_ofed-ubuntu20.04-cuda12-x86_64 hpcx
     - uses: actions/checkout@v1
     - name: Build UCC
       run: |
diff --git a/config/m4/sharp.m4 b/config/m4/sharp.m4
index bedc550476..45bcfd04e3 100644
--- a/config/m4/sharp.m4
+++ b/config/m4/sharp.m4
@@ -44,6 +44,7 @@ AS_IF([test "x$with_sharp" != "xno"],
                     AC_SUBST(SHARP_LDFLAGS, "-lsharp_coll -L$check_sharp_dir/lib")
                     AC_CHECK_DECLS([SHARP_COLL_HIDE_ERRORS], [], [], [[#include <sharp/api/sharp_coll.h>]])
                     AC_CHECK_DECLS([SHARP_COLL_DISABLE_LAZY_GROUP_RESOURCE_ALLOC], [], [], [[#include <sharp/api/sharp_coll.h>]])
+                    AC_CHECK_DECLS([sharp_coll_do_reduce_scatter], [], [], [[#include <sharp/api/sharp_coll.h>]])
                 ],
                 [
                     AS_IF([test "x$with_sharp" != "xguess"],
diff --git a/src/components/tl/sharp/tl_sharp.h b/src/components/tl/sharp/tl_sharp.h
index cc44e9e1f4..adfbc86036 100644
--- a/src/components/tl/sharp/tl_sharp.h
+++ b/src/components/tl/sharp/tl_sharp.h
@@ -108,6 +108,10 @@ typedef struct ucc_tl_sharp_task {
             ucc_tl_sharp_reg_t *s_mem_h;
             ucc_tl_sharp_reg_t *r_mem_h;
         } allreduce;
+        struct {
+            ucc_tl_sharp_reg_t *s_mem_h;
+            ucc_tl_sharp_reg_t *r_mem_h;
+        } reduce_scatter;
         struct {
             ucc_tl_sharp_reg_t *mem_h;
         } bcast;
@@ -131,9 +135,16 @@ ucc_status_t sharp_status_to_ucc_status(int status);
     (ucc_derived_of((_task)->super.team->context->lib, ucc_tl_sharp_lib_t))
 #define TASK_ARGS(_task) (_task)->super.bargs.args
 
-#define UCC_TL_SHARP_SUPPORTED_COLLS \
+#define UCC_TL_BASIC_SHARP_SUPPORTED_COLLS \
     (UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST)
 
+#if HAVE_DECL_SHARP_COLL_DO_REDUCE_SCATTER
+#define UCC_TL_SHARP_SUPPORTED_COLLS \
+    (UCC_TL_BASIC_SHARP_SUPPORTED_COLLS | UCC_COLL_TYPE_REDUCE_SCATTER)
+#else
+#define UCC_TL_SHARP_SUPPORTED_COLLS (UCC_TL_BASIC_SHARP_SUPPORTED_COLLS)
+#endif
+
 UCC_CLASS_DECLARE(ucc_tl_sharp_team_t, ucc_base_context_t *,
                   const ucc_base_team_params_t *);
 
diff --git a/src/components/tl/sharp/tl_sharp_coll.c b/src/components/tl/sharp/tl_sharp_coll.c
index 1dcf2465c1..5884e18918 100644
--- a/src/components/tl/sharp/tl_sharp_coll.c
+++ b/src/components/tl/sharp/tl_sharp_coll.c
@@ -308,6 +308,100 @@ ucc_status_t ucc_tl_sharp_bcast_start(ucc_coll_task_t *coll_task)
     return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
 }
 
+#if HAVE_DECL_SHARP_COLL_DO_REDUCE_SCATTER
+ucc_status_t ucc_tl_sharp_reduce_scatter_start(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_sharp_task_t *task  = ucc_derived_of(coll_task, ucc_tl_sharp_task_t);
+    ucc_tl_sharp_team_t *team  = TASK_TEAM(task);
+    ucc_coll_args_t     *args  = &TASK_ARGS(task);
+    size_t               count = args->dst.info.count;
+    ucc_datatype_t       dt    = args->dst.info.datatype;
+    struct sharp_coll_reduce_spec reduce_spec;
+    enum sharp_datatype           sharp_type;
+    enum sharp_reduce_op          op_type;
+    size_t                        src_data_size, dst_data_size;
+    int                           ret;
+
+    UCC_TL_SHARP_PROFILE_REQUEST_EVENT(coll_task, "sharp_reduce_scatter_start",
+                                       0);
+
+    sharp_type    = ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(dt)];
+    op_type       = ucc_to_sharp_reduce_op[args->op];
+    src_data_size = ucc_dt_size(dt) * count * UCC_TL_TEAM_SIZE(team);
+    dst_data_size = ucc_dt_size(dt) * count;
+
+    if (!UCC_IS_INPLACE(*args)) {
+        ucc_tl_sharp_mem_register(TASK_CTX(task), team, args->src.info.buffer,
+                                  src_data_size, &task->reduce_scatter.s_mem_h);
+    }
+    ucc_tl_sharp_mem_register(TASK_CTX(task), team, args->dst.info.buffer,
+                              dst_data_size, &task->reduce_scatter.r_mem_h);
+
+    if (!UCC_IS_INPLACE(*args)) {
+        reduce_spec.sbuf_desc.buffer.ptr = args->src.info.buffer;
+        reduce_spec.sbuf_desc.buffer.mem_handle =
+            task->reduce_scatter.s_mem_h->mr;
+        reduce_spec.sbuf_desc.mem_type =
+            ucc_to_sharp_memtype[args->src.info.mem_type];
+    } else {
+        reduce_spec.sbuf_desc.buffer.ptr = args->dst.info.buffer;
+        reduce_spec.sbuf_desc.buffer.mem_handle =
+            task->reduce_scatter.r_mem_h->mr;
+        reduce_spec.sbuf_desc.mem_type =
+            ucc_to_sharp_memtype[args->dst.info.mem_type];
+    }
+
+    reduce_spec.sbuf_desc.buffer.length     = src_data_size;
+    reduce_spec.sbuf_desc.type              = SHARP_DATA_BUFFER;
+    reduce_spec.rbuf_desc.buffer.ptr        = args->dst.info.buffer;
+    reduce_spec.rbuf_desc.buffer.length     = dst_data_size;
+    reduce_spec.rbuf_desc.buffer.mem_handle = task->reduce_scatter.r_mem_h->mr;
+    reduce_spec.rbuf_desc.type              = SHARP_DATA_BUFFER;
+    reduce_spec.rbuf_desc.mem_type =
+        ucc_to_sharp_memtype[args->dst.info.mem_type];
+    reduce_spec.aggr_mode = SHARP_AGGREGATION_NONE;
+    reduce_spec.length    = count;
+    reduce_spec.dtype     = sharp_type;
+    reduce_spec.op        = op_type;
+    reduce_spec.offset    = 0;
+
+    ret = sharp_coll_do_reduce_scatter_nb(team->sharp_comm, &reduce_spec,
+                                          &task->req_handle);
+    if (ret != SHARP_COLL_SUCCESS) {
+        tl_error(UCC_TASK_LIB(task),
+                 "sharp_coll_do_reduce_scatter_nb failed:%s",
+                 sharp_coll_strerror(ret));
+        coll_task->status = ucc_tl_sharp_status_to_ucc(ret);
+        return ucc_task_complete(coll_task);
+    }
+    coll_task->status = UCC_INPROGRESS;
+
+    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
+}
+
+ucc_status_t ucc_tl_sharp_reduce_scatter_init(ucc_tl_sharp_task_t *task)
+{
+    ucc_coll_args_t *args = &TASK_ARGS(task);
+
+    if (!ucc_coll_args_is_predefined_dt(args, UCC_RANK_INVALID)) {
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    if ((!UCC_IS_INPLACE(*args) &&
+         ucc_to_sharp_memtype[args->src.info.mem_type] == SHARP_MEM_TYPE_LAST) ||
+        ucc_to_sharp_memtype[args->dst.info.mem_type] == SHARP_MEM_TYPE_LAST ||
+        ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(args->dst.info.datatype)] ==
+            SHARP_DTYPE_NULL ||
+        ucc_to_sharp_reduce_op[args->op] == SHARP_OP_NULL) {
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    task->super.post     = ucc_tl_sharp_reduce_scatter_start;
+    task->super.progress = ucc_tl_sharp_collective_progress;
+    return UCC_OK;
+}
+#endif
+
 ucc_status_t ucc_tl_sharp_allreduce_init(ucc_tl_sharp_task_t *task)
 {
     ucc_coll_args_t *args = &TASK_ARGS(task);
diff --git a/src/components/tl/sharp/tl_sharp_coll.h b/src/components/tl/sharp/tl_sharp_coll.h
index 6b12c69900..6557dc56e8 100644
--- a/src/components/tl/sharp/tl_sharp_coll.h
+++ b/src/components/tl/sharp/tl_sharp_coll.h
@@ -20,4 +20,7 @@ ucc_status_t ucc_tl_sharp_barrier_init(ucc_tl_sharp_task_t *task);
 
 ucc_status_t ucc_tl_sharp_bcast_init(ucc_tl_sharp_task_t *task);
 
+#if HAVE_DECL_SHARP_COLL_DO_REDUCE_SCATTER
+ucc_status_t ucc_tl_sharp_reduce_scatter_init(ucc_tl_sharp_task_t *task);
+#endif
 #endif
diff --git a/src/components/tl/sharp/tl_sharp_team.c b/src/components/tl/sharp/tl_sharp_team.c
index 6b8f369c7c..a8bd380936 100644
--- a/src/components/tl/sharp/tl_sharp_team.c
+++ b/src/components/tl/sharp/tl_sharp_team.c
@@ -234,6 +234,11 @@ ucc_status_t ucc_tl_sharp_coll_init(ucc_base_coll_args_t *coll_args,
     case UCC_COLL_TYPE_BCAST:
         status = ucc_tl_sharp_bcast_init(task);
         break;
+#if HAVE_DECL_SHARP_COLL_DO_REDUCE_SCATTER
+    case UCC_COLL_TYPE_REDUCE_SCATTER:
+        status = ucc_tl_sharp_reduce_scatter_init(task);
+        break;
+#endif
     default:
         tl_debug(UCC_TASK_LIB(task),
                  "collective %d is not supported by sharp tl",