From d7857f5e419e3c15aff309a372c6519fd06be750 Mon Sep 17 00:00:00 2001 From: Shimmy Balsam Date: Thu, 31 Aug 2023 14:08:07 +0300 Subject: [PATCH] TL/UCP: reorder ranks for SRA --- .../tl/ucp/allgather/allgather_knomial.c | 48 ++++++++----- .../reduce_scatter/reduce_scatter_knomial.c | 70 +++++++++++-------- 2 files changed, 73 insertions(+), 45 deletions(-) diff --git a/src/components/tl/ucp/allgather/allgather_knomial.c b/src/components/tl/ucp/allgather/allgather_knomial.c index 94531945e9..d5a760a23a 100644 --- a/src/components/tl/ucp/allgather/allgather_knomial.c +++ b/src/components/tl/ucp/allgather/allgather_knomial.c @@ -45,10 +45,10 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) size_t count = args->dst.info.count; size_t dt_size = ucc_dt_size(args->dst.info.datatype); size_t data_size = count * dt_size; - ucc_rank_t size = UCC_TL_TEAM_SIZE(team); + ucc_rank_t size = task->subset.map.ep_num; ucc_rank_t broot = args->coll_type == UCC_COLL_TYPE_BCAST ? args->root : 0; - ucc_rank_t rank = VRANK(UCC_TL_TEAM_RANK(team), broot, size); + ucc_rank_t rank = VRANK(task->subset.myrank, broot, size); size_t local = GET_LOCAL_COUNT(args, size, rank); void *sbuf; ptrdiff_t peer_seg_offset, local_seg_offset; @@ -67,18 +67,21 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) if (p->type != KN_PATTERN_ALLGATHERX) { UCPCHECK_GOTO(ucc_tl_ucp_send_nb(task->allgather_kn.sbuf, local * dt_size, mem_type, - INV_VRANK(peer, broot, size), team, - task), + ucc_ep_map_eval(task->subset.map, + INV_VRANK(peer,broot,size)), + team, task), task, out); } UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(rbuf, data_size, mem_type, - INV_VRANK(peer, broot, size), team, - task), + ucc_ep_map_eval(task->subset.map, + INV_VRANK(peer,broot,size)), + team, task), task, out); } if ((p->type != KN_PATTERN_ALLGATHERX) && (node_type == KN_NODE_PROXY)) { peer = ucc_knomial_pattern_get_extra(p, rank); extra_count = GET_LOCAL_COUNT(args, size, peer); + peer = ucc_ep_map_eval(task->subset.map, peer); UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(task->allgather_kn.sbuf, local * dt_size), extra_count * dt_size, mem_type, peer, team, task), @@ -113,8 +116,9 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) } UCPCHECK_GOTO(ucc_tl_ucp_send_nb(sbuf, local_seg_count * dt_size, mem_type, - INV_VRANK(peer, broot, size), team, - task), + ucc_ep_map_eval(task->subset.map, + INV_VRANK(peer, broot, size)), + team, task), task, out); } @@ -135,7 +139,9 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) UCPCHECK_GOTO( ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, peer_seg_offset * dt_size), peer_seg_count * dt_size, mem_type, - INV_VRANK(peer, broot, size), team, task), + ucc_ep_map_eval(task->subset.map, + INV_VRANK(peer, broot, size)), + team, task), task, out); } UCC_KN_PHASE_LOOP: @@ -149,7 +155,9 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) if (KN_NODE_PROXY == node_type) { peer = ucc_knomial_pattern_get_extra(p, rank); UCPCHECK_GOTO(ucc_tl_ucp_send_nb(args->dst.info.buffer, data_size, - mem_type, INV_VRANK(peer, broot, size), + mem_type, + ucc_ep_map_eval(task->subset.map, + INV_VRANK(peer, broot, size)), team, task), task, out); } @@ -172,10 +180,10 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task) ucc_coll_args_t *args = &TASK_ARGS(task); ucc_tl_ucp_team_t *team = TASK_TEAM(task); ucc_coll_type_t ct = args->coll_type; - ucc_rank_t size = UCC_TL_TEAM_SIZE(team); + ucc_rank_t size = task->subset.map.ep_num; ucc_kn_radix_t radix = task->allgather_kn.p.radix; ucc_knomial_pattern_t *p = &task->allgather_kn.p; - ucc_rank_t rank = VRANK(UCC_TL_TEAM_RANK(team), + ucc_rank_t rank = VRANK(task->subset.myrank, ct == UCC_COLL_TYPE_BCAST ? args->root : 0, size); ucc_ee_executor_task_args_t eargs = {0}; @@ -230,14 +238,22 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r( ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h, ucc_kn_radix_t radix) { + ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); ucc_tl_ucp_task_t *task; + ucc_sbgp_t *sbgp; task = ucc_tl_ucp_init_task(coll_args, team); + if (tl_team->cfg.use_reordering && + coll_args->args.coll_type == UCC_COLL_TYPE_ALLREDUCE) { + sbgp = ucc_topo_get_sbgp(tl_team->topo, UCC_SBGP_FULL_HOST_ORDERED); + task->subset.myrank = sbgp->group_rank; + task->subset.map = sbgp->map; + } task->allgather_kn.p.radix = radix; - task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; - task->super.post = ucc_tl_ucp_allgather_knomial_start; - task->super.progress = ucc_tl_ucp_allgather_knomial_progress; - *task_h = &task->super; + task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; + task->super.post = ucc_tl_ucp_allgather_knomial_start; + task->super.progress = ucc_tl_ucp_allgather_knomial_progress; + *task_h = &task->super; return UCC_OK; } diff --git a/src/components/tl/ucp/reduce_scatter/reduce_scatter_knomial.c b/src/components/tl/ucp/reduce_scatter/reduce_scatter_knomial.c index d484ba6892..a70b1a303c 100644 --- a/src/components/tl/ucp/reduce_scatter/reduce_scatter_knomial.c +++ b/src/components/tl/ucp/reduce_scatter/reduce_scatter_knomial.c @@ -54,44 +54,46 @@ static inline void get_sbuf_rbuf(ucc_tl_ucp_task_t *task, size_t block_count, void ucc_tl_ucp_reduce_scatter_knomial_progress(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, - ucc_tl_ucp_task_t); - ucc_coll_args_t *args = &TASK_ARGS(task); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_kn_radix_t radix = task->reduce_scatter_kn.p.radix; - ucc_knomial_pattern_t *p = &task->reduce_scatter_kn.p; - uint8_t node_type = p->node_type; - void *scratch = task->reduce_scatter_kn.scratch; - void *rbuf = args->dst.info.buffer; - ucc_memory_type_t mem_type = args->dst.info.mem_type; - size_t count = args->dst.info.count; - ucc_datatype_t dt = args->dst.info.datatype; - void *sbuf = UCC_IS_INPLACE(*args) ? + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, + ucc_tl_ucp_task_t); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_kn_radix_t radix = task->reduce_scatter_kn.p.radix; + ucc_knomial_pattern_t *p = &task->reduce_scatter_kn.p; + uint8_t node_type = p->node_type; + void *scratch = task->reduce_scatter_kn.scratch; + void *rbuf = args->dst.info.buffer; + ucc_memory_type_t mem_type = args->dst.info.mem_type; + size_t count = args->dst.info.count; + ucc_datatype_t dt = args->dst.info.datatype; + void *sbuf = UCC_IS_INPLACE(*args) ? rbuf : args->src.info.buffer; - size_t dt_size = ucc_dt_size(dt); - size_t data_size = count * dt_size; - ucc_rank_t rank = UCC_TL_TEAM_RANK(team); - ucc_rank_t size = UCC_TL_TEAM_SIZE(team); - ptrdiff_t peer_seg_offset, local_seg_offset, offset; - ucc_rank_t peer, step_radix; - ucc_status_t status; + size_t dt_size = ucc_dt_size(dt); + size_t data_size = count * dt_size; + ucc_rank_t rank = task->subset.myrank; + ucc_rank_t size = task->subset.map.ep_num; + ptrdiff_t peer_seg_offset, local_seg_offset, offset; + ucc_rank_t peer, step_radix; + ucc_status_t status; ucc_kn_radix_t loop_step; - size_t block_count, peer_seg_count, local_seg_count; - void *reduce_data, *local_data; - int is_avg; + size_t block_count, peer_seg_count, local_seg_count; + void *reduce_data, *local_data; + int is_avg; local_seg_count = 0; block_count = ucc_sra_kn_compute_block_count(count, rank, p); UCC_KN_REDUCE_GOTO_PHASE(task->reduce_scatter_kn.phase); if (KN_NODE_EXTRA == node_type) { - peer = ucc_knomial_pattern_get_proxy(p, rank); + peer = ucc_ep_map_eval(task->subset.map, + ucc_knomial_pattern_get_proxy(p, rank)); UCPCHECK_GOTO( ucc_tl_ucp_send_nb(sbuf, data_size, mem_type, peer, team, task), task, out); } if (KN_NODE_PROXY == node_type) { - peer = ucc_knomial_pattern_get_extra(p, rank); + peer = ucc_ep_map_eval(task->subset.map, + ucc_knomial_pattern_get_extra(p, rank)); UCPCHECK_GOTO( ucc_tl_ucp_recv_nb(scratch, data_size, mem_type, peer, team, task), task, out); @@ -132,6 +134,7 @@ void ucc_tl_ucp_reduce_scatter_knomial_progress(ucc_coll_task_t *coll_task) } ucc_kn_rs_pattern_peer_seg(peer, p, &peer_seg_count, &peer_seg_offset); + peer = ucc_ep_map_eval(task->subset.map, peer); UCPCHECK_GOTO( ucc_tl_ucp_send_nb(PTR_OFFSET(sbuf, peer_seg_offset * dt_size), peer_seg_count * dt_size, mem_type, peer, @@ -201,8 +204,8 @@ ucc_status_t ucc_tl_ucp_reduce_scatter_knomial_start(ucc_coll_task_t *coll_task) ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); ucc_coll_args_t *args = &TASK_ARGS(task); ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t rank = UCC_TL_TEAM_RANK(team); - ucc_rank_t size = UCC_TL_TEAM_SIZE(team); + ucc_rank_t rank = task->subset.myrank; + ucc_rank_t size = task->subset.map.ep_num; ucc_status_t status; UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_reduce_scatter_kn_start", @@ -242,16 +245,16 @@ ucc_tl_ucp_reduce_scatter_knomial_init_r(ucc_base_coll_args_t *coll_args, ucc_kn_radix_t radix) { ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); - ucc_rank_t rank = UCC_TL_TEAM_RANK(tl_team); - ucc_rank_t size = UCC_TL_TEAM_SIZE(tl_team); size_t count = coll_args->args.dst.info.count; ucc_datatype_t dt = coll_args->args.dst.info.datatype; size_t dt_size = ucc_dt_size(dt); ucc_memory_type_t mem_type = coll_args->args.dst.info.mem_type; + ucc_sbgp_t *sbgp; ucc_tl_ucp_task_t *task; ucc_status_t status; size_t max_recv_size, data_size; ucc_kn_radix_t step_radix; + ucc_rank_t rank, size; task = ucc_tl_ucp_init_task(coll_args, team); task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; @@ -259,6 +262,15 @@ ucc_tl_ucp_reduce_scatter_knomial_init_r(ucc_base_coll_args_t *coll_args, task->super.progress = ucc_tl_ucp_reduce_scatter_knomial_progress; task->super.finalize = ucc_tl_ucp_reduce_scatter_knomial_finalize; + if (tl_team->cfg.use_reordering) { + sbgp = ucc_topo_get_sbgp(tl_team->topo, UCC_SBGP_FULL_HOST_ORDERED); + task->subset.myrank = sbgp->group_rank; + task->subset.map = sbgp->map; + } + + rank = task->subset.myrank; + size = task->subset.map.ep_num; + ucc_assert(coll_args->args.src.info.mem_type == coll_args->args.dst.info.mem_type); ucc_knomial_pattern_init(size, rank, radix, &task->reduce_scatter_kn.p);