Skip to content

Commit

Permalink
TL/UCP: reorder ranks for SRA (#834)
Browse files Browse the repository at this point in the history
  • Loading branch information
shimmybalsam authored Oct 2, 2023
1 parent 34ad3ef commit fc60143
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 45 deletions.
48 changes: 32 additions & 16 deletions src/components/tl/ucp/allgather/allgather_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
size_t count = args->dst.info.count;
size_t dt_size = ucc_dt_size(args->dst.info.datatype);
size_t data_size = count * dt_size;
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
ucc_rank_t size = task->subset.map.ep_num;
ucc_rank_t broot = args->coll_type == UCC_COLL_TYPE_BCAST ?
args->root : 0;
ucc_rank_t rank = VRANK(UCC_TL_TEAM_RANK(team), broot, size);
ucc_rank_t rank = VRANK(task->subset.myrank, broot, size);
size_t local = GET_LOCAL_COUNT(args, size, rank);
void *sbuf;
ptrdiff_t peer_seg_offset, local_seg_offset;
Expand All @@ -67,18 +67,21 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
if (p->type != KN_PATTERN_ALLGATHERX) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(task->allgather_kn.sbuf,
local * dt_size, mem_type,
INV_VRANK(peer, broot, size), team,
task),
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer,broot,size)),
team, task),
task, out);
}
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(rbuf, data_size, mem_type,
INV_VRANK(peer, broot, size), team,
task),
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer,broot,size)),
team, task),
task, out);
}
if ((p->type != KN_PATTERN_ALLGATHERX) && (node_type == KN_NODE_PROXY)) {
peer = ucc_knomial_pattern_get_extra(p, rank);
extra_count = GET_LOCAL_COUNT(args, size, peer);
peer = ucc_ep_map_eval(task->subset.map, peer);
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(task->allgather_kn.sbuf,
local * dt_size), extra_count * dt_size,
mem_type, peer, team, task),
Expand Down Expand Up @@ -113,8 +116,9 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
}
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(sbuf, local_seg_count * dt_size,
mem_type,
INV_VRANK(peer, broot, size), team,
task),
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer, broot, size)),
team, task),
task, out);
}

Expand All @@ -135,7 +139,9 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
UCPCHECK_GOTO(
ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, peer_seg_offset * dt_size),
peer_seg_count * dt_size, mem_type,
INV_VRANK(peer, broot, size), team, task),
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer, broot, size)),
team, task),
task, out);
}
UCC_KN_PHASE_LOOP:
Expand All @@ -149,7 +155,9 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
if (KN_NODE_PROXY == node_type) {
peer = ucc_knomial_pattern_get_extra(p, rank);
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(args->dst.info.buffer, data_size,
mem_type, INV_VRANK(peer, broot, size),
mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer, broot, size)),
team, task),
task, out);
}
Expand All @@ -172,10 +180,10 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_type_t ct = args->coll_type;
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
ucc_rank_t size = task->subset.map.ep_num;
ucc_kn_radix_t radix = task->allgather_kn.p.radix;
ucc_knomial_pattern_t *p = &task->allgather_kn.p;
ucc_rank_t rank = VRANK(UCC_TL_TEAM_RANK(team),
ucc_rank_t rank = VRANK(task->subset.myrank,
ct == UCC_COLL_TYPE_BCAST ?
args->root : 0, size);
ucc_ee_executor_task_args_t eargs = {0};
Expand Down Expand Up @@ -230,14 +238,22 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
ucc_coll_task_t **task_h, ucc_kn_radix_t radix)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_tl_ucp_task_t *task;
ucc_sbgp_t *sbgp;

task = ucc_tl_ucp_init_task(coll_args, team);
if (tl_team->cfg.use_reordering &&
coll_args->args.coll_type == UCC_COLL_TYPE_ALLREDUCE) {
sbgp = ucc_topo_get_sbgp(tl_team->topo, UCC_SBGP_FULL_HOST_ORDERED);
task->subset.myrank = sbgp->group_rank;
task->subset.map = sbgp->map;
}
task->allgather_kn.p.radix = radix;
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
task->super.post = ucc_tl_ucp_allgather_knomial_start;
task->super.progress = ucc_tl_ucp_allgather_knomial_progress;
*task_h = &task->super;
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
task->super.post = ucc_tl_ucp_allgather_knomial_start;
task->super.progress = ucc_tl_ucp_allgather_knomial_progress;
*task_h = &task->super;
return UCC_OK;
}

Expand Down
69 changes: 40 additions & 29 deletions src/components/tl/ucp/reduce_scatter/reduce_scatter_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,44 +54,46 @@ static inline void get_sbuf_rbuf(ucc_tl_ucp_task_t *task, size_t block_count,

void ucc_tl_ucp_reduce_scatter_knomial_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
ucc_tl_ucp_task_t);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_kn_radix_t radix = task->reduce_scatter_kn.p.radix;
ucc_knomial_pattern_t *p = &task->reduce_scatter_kn.p;
uint8_t node_type = p->node_type;
void *scratch = task->reduce_scatter_kn.scratch;
void *rbuf = args->dst.info.buffer;
ucc_memory_type_t mem_type = args->dst.info.mem_type;
size_t count = args->dst.info.count;
ucc_datatype_t dt = args->dst.info.datatype;
void *sbuf = UCC_IS_INPLACE(*args) ?
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
ucc_tl_ucp_task_t);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_kn_radix_t radix = task->reduce_scatter_kn.p.radix;
ucc_knomial_pattern_t *p = &task->reduce_scatter_kn.p;
uint8_t node_type = p->node_type;
void *scratch = task->reduce_scatter_kn.scratch;
void *rbuf = args->dst.info.buffer;
ucc_memory_type_t mem_type = args->dst.info.mem_type;
size_t count = args->dst.info.count;
ucc_datatype_t dt = args->dst.info.datatype;
void *sbuf = UCC_IS_INPLACE(*args) ?
rbuf : args->src.info.buffer;
size_t dt_size = ucc_dt_size(dt);
size_t data_size = count * dt_size;
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
ptrdiff_t peer_seg_offset, local_seg_offset, offset;
ucc_rank_t peer, step_radix;
ucc_status_t status;
size_t dt_size = ucc_dt_size(dt);
size_t data_size = count * dt_size;
ucc_rank_t rank = task->subset.myrank;
ucc_rank_t size = task->subset.map.ep_num;
ptrdiff_t peer_seg_offset, local_seg_offset, offset;
ucc_rank_t peer, step_radix;
ucc_status_t status;
ucc_kn_radix_t loop_step;
size_t block_count, peer_seg_count, local_seg_count;
void *reduce_data, *local_data;
int is_avg;
size_t block_count, peer_seg_count, local_seg_count;
void *reduce_data, *local_data;
int is_avg;

local_seg_count = 0;
block_count = ucc_sra_kn_compute_block_count(count, rank, p);
UCC_KN_REDUCE_GOTO_PHASE(task->reduce_scatter_kn.phase);
if (KN_NODE_EXTRA == node_type) {
peer = ucc_knomial_pattern_get_proxy(p, rank);
peer = ucc_ep_map_eval(task->subset.map,
ucc_knomial_pattern_get_proxy(p, rank));
UCPCHECK_GOTO(
ucc_tl_ucp_send_nb(sbuf, data_size, mem_type, peer, team, task),
task, out);
}

if (KN_NODE_PROXY == node_type) {
peer = ucc_knomial_pattern_get_extra(p, rank);
peer = ucc_ep_map_eval(task->subset.map,
ucc_knomial_pattern_get_extra(p, rank));
UCPCHECK_GOTO(
ucc_tl_ucp_recv_nb(scratch, data_size, mem_type, peer, team, task),
task, out);
Expand Down Expand Up @@ -132,6 +134,7 @@ void ucc_tl_ucp_reduce_scatter_knomial_progress(ucc_coll_task_t *coll_task)
}
ucc_kn_rs_pattern_peer_seg(peer, p, &peer_seg_count,
&peer_seg_offset);
peer = ucc_ep_map_eval(task->subset.map, peer);
UCPCHECK_GOTO(
ucc_tl_ucp_send_nb(PTR_OFFSET(sbuf, peer_seg_offset * dt_size),
peer_seg_count * dt_size, mem_type, peer,
Expand Down Expand Up @@ -201,8 +204,8 @@ ucc_status_t ucc_tl_ucp_reduce_scatter_knomial_start(ucc_coll_task_t *coll_task)
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
ucc_rank_t rank = task->subset.myrank;
ucc_rank_t size = task->subset.map.ep_num;
ucc_status_t status;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_reduce_scatter_kn_start",
Expand Down Expand Up @@ -242,23 +245,31 @@ ucc_tl_ucp_reduce_scatter_knomial_init_r(ucc_base_coll_args_t *coll_args,
ucc_kn_radix_t radix)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_rank_t rank = UCC_TL_TEAM_RANK(tl_team);
ucc_rank_t size = UCC_TL_TEAM_SIZE(tl_team);
size_t count = coll_args->args.dst.info.count;
ucc_datatype_t dt = coll_args->args.dst.info.datatype;
size_t dt_size = ucc_dt_size(dt);
ucc_memory_type_t mem_type = coll_args->args.dst.info.mem_type;
ucc_sbgp_t *sbgp;
ucc_tl_ucp_task_t *task;
ucc_status_t status;
size_t max_recv_size, data_size;
ucc_kn_radix_t step_radix;
ucc_rank_t rank, size;

task = ucc_tl_ucp_init_task(coll_args, team);
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
task->super.post = ucc_tl_ucp_reduce_scatter_knomial_start;
task->super.progress = ucc_tl_ucp_reduce_scatter_knomial_progress;
task->super.finalize = ucc_tl_ucp_reduce_scatter_knomial_finalize;

if (tl_team->cfg.use_reordering) {
sbgp = ucc_topo_get_sbgp(tl_team->topo, UCC_SBGP_FULL_HOST_ORDERED);
task->subset.myrank = sbgp->group_rank;
task->subset.map = sbgp->map;
}

rank = task->subset.myrank;
size = task->subset.map.ep_num;
ucc_assert(coll_args->args.src.info.mem_type ==
coll_args->args.dst.info.mem_type);
ucc_knomial_pattern_init(size, rank, radix, &task->reduce_scatter_kn.p);
Expand Down

0 comments on commit fc60143

Please sign in to comment.