TL/UCP: add reorder ranks to reduce_scatter (openucx#820)
* TL/UCP: add reorder ranks to reduce_scatter

* TL/UCP: reorder ranks reduce_scatterv and cleanup

* REVIEW: code review fixes
shimmybalsam authored and nsarka committed Oct 24, 2023
1 parent a01535b commit 78458ec
Showing 5 changed files with 133 additions and 73 deletions.
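For orientation before the diff: the change makes the ring reduce-scatter honor a reordered rank layout. Ring neighbors and block indices, already translated through reduce_scatter_ring.inv_map, are additionally remapped through task->subset.map when team->cfg.use_reordering is set, and the forward subset is built from the UCC_SBGP_FULL_HOST_ORDERED subgroup. Below is a minimal standalone C sketch of that remapping idea, with toy types and a made-up placement rather than this commit's actual structures:

/* Toy sketch of ring-neighbor remapping through an endpoint map,
 * loosely analogous to ucc_ep_map_eval(task->subset.map, rank).
 * The types and the host-ordered placement below are invented. */
#include <stdio.h>

typedef struct toy_ep_map {
    const int *ranks; /* ring position -> actual team rank */
    int        n;
} toy_ep_map_t;

static int toy_ep_map_eval(const toy_ep_map_t *map, int pos)
{
    return map->ranks[pos];
}

int main(void)
{
    const int    host_ordered[] = {0, 2, 1, 3}; /* hypothetical layout */
    toy_ep_map_t map            = {host_ordered, 4};
    int          rank           = 1; /* my position in the logical ring */
    int          sendto         = (rank + 1) % map.n;
    int          recvfrom       = (rank - 1 + map.n) % map.n;

    /* Logical neighbors become actual team ranks, so ring traffic
     * follows the reordered (e.g. host-ordered) placement. */
    printf("send to rank %d, recv from rank %d\n",
           toy_ep_map_eval(&map, sendto), toy_ep_map_eval(&map, recvfrom));
    return 0;
}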
87 changes: 55 additions & 32 deletions src/components/tl/ucp/reduce_scatter/reduce_scatter_ring.c
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
@@ -12,6 +12,8 @@
 #include "utils/ucc_coll_utils.h"
 #include "utils/ucc_dt_reduce.h"
 
+#define REVERSED_FRAG 1
+
 static inline void send_completion_common(void *request, ucs_status_t status,
                                           void *user_data)
 {
@@ -112,7 +114,10 @@ static void ucc_tl_ucp_reduce_scatter_ring_progress(ucc_coll_task_t *coll_task)
 
     sendto   = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, sendto);
     recvfrom = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, recvfrom);
-
+    if (team->cfg.use_reordering) {
+        sendto   = ucc_ep_map_eval(task->subset.map, sendto);
+        recvfrom = ucc_ep_map_eval(task->subset.map, recvfrom);
+    }
     max_block_size = task->reduce_scatter_ring.max_block_count * dt_size;
     busy           = task->reduce_scatter_ring.s_scratch_busy;
     r_scratch      = task->reduce_scatter_ring.scratch;
@@ -129,8 +134,11 @@ static void ucc_tl_ucp_reduce_scatter_ring_progress(ucc_coll_task_t *coll_task)
         reduce_target = s_scratch[id];
         step          = task->tagged.send_posted;
         prevblock     = (rank - 1 - step + size) % size;
-        prevblock     =
+        prevblock =
             ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, prevblock);
+        if (team->cfg.use_reordering) {
+            prevblock = ucc_ep_map_eval(task->subset.map, prevblock);
+        }
         /* reduction */
         ucc_assert(task->tagged.recv_posted == task->tagged.recv_completed);
         ucc_assert(task->tagged.recv_posted < size);
@@ -172,7 +180,9 @@ static void ucc_tl_ucp_reduce_scatter_ring_progress(ucc_coll_task_t *coll_task)
     recv_data_from = (rank - 2 - step + size) % size;
     recv_data_from =
         ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, recv_data_from);
-
+    if (team->cfg.use_reordering) {
+        recv_data_from = ucc_ep_map_eval(task->subset.map, recv_data_from);
+    }
     ucc_ring_frag_count(task, count, recv_data_from, &frag_count);
 
     UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(r_scratch, frag_count * dt_size, mem_type,
@@ -199,17 +209,18 @@ ucc_tl_ucp_reduce_scatter_ring_start(ucc_coll_task_t *coll_task)
     ucc_tl_ucp_team_t *team       = TASK_TEAM(task);
     ucc_rank_t         size       = task->subset.map.ep_num;
     ucc_rank_t         rank       = task->subset.myrank;
-    ucc_rank_t         sendto     = (rank + 1) % size;
-    ucc_rank_t         recvfrom   = (rank - 1 + size) % size;
     size_t             count      = args->dst.info.count * size;
     ucc_datatype_t     dt         = args->dst.info.datatype;
     size_t             dt_size    = ucc_dt_size(dt);
     ucc_memory_type_t  mem_type   = args->dst.info.mem_type;
     void              *sbuf       = args->src.info.buffer;
     int                step       = 0;
+    ucc_rank_t         sendto     = (rank + 1) % size;
+    ucc_rank_t         recvfrom   = (rank - 1 + size) % size;
+    ucc_rank_t         recv_block = (rank - 2 - step + size) % size;
+    ucc_rank_t         send_block = (rank - 1 - step + size) % size;
     size_t             block_offset, frag_count, frag_offset;
     void              *r_scratch;
-    ucc_rank_t         send_block, recv_block;
     ucc_status_t       status;
 
     ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
@@ -223,13 +234,17 @@ ucc_tl_ucp_reduce_scatter_ring_start(ucc_coll_task_t *coll_task)
         return status;
     }
 
-    r_scratch  = task->reduce_scatter_ring.scratch;
     sendto     = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, sendto);
     recvfrom   = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, recvfrom);
+    r_scratch  = task->reduce_scatter_ring.scratch;
-    recv_block = (rank - 2 - step + size) % size;
-    send_block = (rank - 1 - step + size) % size;
     recv_block = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, recv_block);
     send_block = ucc_ep_map_eval(task->reduce_scatter_ring.inv_map, send_block);
+    if (team->cfg.use_reordering) {
+        sendto     = ucc_ep_map_eval(task->subset.map, sendto);
+        recvfrom   = ucc_ep_map_eval(task->subset.map, recvfrom);
+        recv_block = ucc_ep_map_eval(task->subset.map, recv_block);
+        send_block = ucc_ep_map_eval(task->subset.map, send_block);
+    }
 
     ucc_ring_frag_count(task, count, recv_block, &frag_count);
     UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(r_scratch, frag_count * dt_size, mem_type,
@@ -253,30 +268,35 @@ static ucc_status_t
 ucc_tl_ucp_reduce_scatter_ring_finalize(ucc_coll_task_t *coll_task)
 {
     ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
 
-    if (task->reduce_scatter_ring.inv_map.type != UCC_EP_MAP_FULL) {
+    if (task->reduce_scatter_ring.frag == REVERSED_FRAG) {
         ucc_ep_map_destroy(&task->reduce_scatter_ring.inv_map);
     }
     return ucc_tl_ucp_coll_finalize(coll_task);
 }
 
 static ucc_status_t ucc_tl_ucp_reduce_scatter_ring_init_subset(
     ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
-    ucc_coll_task_t **task_h, ucc_subset_t subset, int n_frags, int frag,
+    ucc_coll_task_t **task_h, ucc_subset_t *subsets, int n_frags, int frag,
     void *scratch, size_t max_block_count)
 {
     ucc_tl_ucp_task_t *task;
+    ucc_tl_ucp_team_t *tl_team;
     ucc_status_t       status;
 
     task    = ucc_tl_ucp_init_task(coll_args, team);
+    tl_team = TASK_TEAM(task);
     task->super.post     = ucc_tl_ucp_reduce_scatter_ring_start;
     task->super.progress = ucc_tl_ucp_reduce_scatter_ring_progress;
     task->super.finalize = ucc_tl_ucp_reduce_scatter_ring_finalize;
-    task->subset = subset;
-
-    if (task->subset.map.type != UCC_EP_MAP_FULL) {
-        status = ucc_ep_map_create_inverse(task->subset.map,
-                                           &task->reduce_scatter_ring.inv_map);
+    task->subset.map    = subsets[frag].map;
+    task->subset.myrank = subsets[frag].myrank;
+    if (frag == REVERSED_FRAG) {
+        if (tl_team->cfg.use_reordering) {
+            task->subset.map = subsets[0].map;
+        }
+        status = ucc_ep_map_create_inverse(subsets[frag].map,
+                                           &task->reduce_scatter_ring.inv_map,
+                                           frag && tl_team->cfg.use_reordering);
         if (UCC_OK != status) {
             return status;
         }
@@ -330,6 +350,7 @@ ucc_tl_ucp_reduce_scatter_ring_init(ucc_base_coll_args_t *coll_args,
     ucc_tl_ucp_schedule_t *tl_schedule;
     ucc_schedule_t        *schedule;
     ucc_coll_task_t       *ctask;
+    ucc_sbgp_t            *sbgp;
     ucc_status_t           status;
     ucc_subset_t           s[2];
     int                    i, n_subsets;
@@ -348,32 +369,34 @@ ucc_tl_ucp_reduce_scatter_ring_init(ucc_base_coll_args_t *coll_args,
         return status;
     }
 
-    schedule  = &tl_schedule->super.super;
+    schedule = &tl_schedule->super.super;
     /* if count == size then we have 1 elem per rank, not enough
        to split into 2 sets */
-    n_subsets = (bidir && (count > size)) ? 2 : 1;
-
-    s[0].myrank     = UCC_TL_TEAM_RANK(tl_team);
-    s[0].map.type   = UCC_EP_MAP_FULL;
-    s[0].map.ep_num = UCC_TL_TEAM_SIZE(tl_team);
+    n_subsets = (bidir && (count > size)) ? 2 : 1;
-
-    s[1].map    = ucc_ep_map_create_reverse(UCC_TL_TEAM_SIZE(tl_team));
-    s[1].myrank = ucc_ep_map_eval(s[1].map, UCC_TL_TEAM_RANK(tl_team));
-
-
-    count_per_set = (count + n_subsets - 1) / n_subsets;
-    max_segcount  = ucc_buffer_block_count(count_per_set, size, 0);
+    if (tl_team->cfg.use_reordering) {
+        sbgp        = ucc_topo_get_sbgp(tl_team->topo, UCC_SBGP_FULL_HOST_ORDERED);
+        s[0].myrank = sbgp->group_rank;
+        s[0].map    = sbgp->map;
+    } else {
+        s[0].myrank     = UCC_TL_TEAM_RANK(tl_team);
+        s[0].map.type   = UCC_EP_MAP_FULL;
+        s[0].map.ep_num = UCC_TL_TEAM_SIZE(tl_team);
+    }
+    s[1].map    = ucc_ep_map_create_reverse(UCC_TL_TEAM_SIZE(tl_team));
+    s[1].myrank = ucc_ep_map_eval(s[1].map, s[0].myrank);
+    count_per_set = (count + n_subsets - 1) / n_subsets;
+    max_segcount  = ucc_buffer_block_count(count_per_set, size, 0);
     /* in flight we can have 2 sends from 2 different blocks and 1 recv:
        need 3 * max_segcount of scratch per set */
     to_alloc_per_set = max_segcount * 3;
     UCC_CHECK_GOTO(ucc_mc_alloc(&tl_schedule->scratch_mc_header,
                                 to_alloc_per_set * dt_size * n_subsets,
                                 mem_type),
                    out, status);
 
     for (i = 0; i < n_subsets; i++) {
         UCC_CHECK_GOTO(ucc_tl_ucp_reduce_scatter_ring_init_subset(
-            coll_args, team, &ctask, s[i], n_subsets, i,
+            coll_args, team, &ctask, s, n_subsets, i,
             PTR_OFFSET(tl_schedule->scratch_mc_header->addr,
                        to_alloc_per_set * i * dt_size),
             max_segcount),
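As a worked example of the scratch sizing in the last hunk, assuming ucc_buffer_block_count(count, size, 0) returns the largest per-rank block of a near-even split (an assumption about its semantics): with count = 1000 elements, size = 8 ranks, and bidirectional rings so n_subsets = 2, we get count_per_set = (1000 + 1) / 2 = 500, max_segcount = 63 (500 = 8 * 62 + 4, so the first blocks get 63), to_alloc_per_set = 3 * 63 = 189 elements, and the single allocation covers 189 * dt_size * 2 bytes.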