CL/HIER: Use scratch buf for noncontig allgatherv
nsarkauskas authored and nsarka committed Dec 13, 2024
1 parent 9ee9a17 commit 0305f14
Showing 2 changed files with 88 additions and 61 deletions.
110 changes: 65 additions & 45 deletions src/components/cl/hier/allgatherv/allgatherv.c
@@ -102,6 +102,7 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
dst.info_v.datatype);
int in_place = 0;
int is_contig = 1;
+ size_t disp_counter = 0;
ucc_schedule_t *schedule;
ucc_cl_hier_schedule_t *cl_schedule;
ucc_status_t status;
@@ -113,6 +114,7 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
size_t leader_counts_size;
size_t leader_disps_size;
size_t total_count;
+ void *buffer;
void *node_gathered_data;

schedule = &ucc_cl_hier_get_schedule(cl_team)->super.super;
@@ -134,8 +136,10 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
leader_disps_size = leader_sbgp_size * sizeof(ucc_aint_t);
total_count = ucc_coll_args_get_total_count(&args.args,
args.args.dst.info_v.counts, team_size);
- scratch_size = node_counts_size + node_disps_size + leader_counts_size
-     + leader_disps_size + (total_count * dt_size);
+ scratch_size = node_counts_size + node_disps_size
+     + leader_counts_size + leader_disps_size;
+ /* If the dst buf isn't contig, allocate and work on a contig scratch buffer */
+ scratch_size += (is_contig ? 0 : (total_count * dt_size));

UCC_CHECK_GOTO(
ucc_mc_alloc(&cl_schedule->scratch, scratch_size, UCC_MEMORY_TYPE_HOST),
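Note on the layout: the scratch allocation packs four metadata regions (node counts/displacements, leader counts/displacements) back to back, and the contiguous data region is appended only when the user's dst buffer is noncontiguous. Below is a minimal standalone sketch of that arithmetic; plain malloc stands in for ucc_mc_alloc, fixed-width integers stand in for the UCC count and displacement types, and the sizes are made up:

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define PTR_OFFSET(p, off) ((void *)((uint8_t *)(p) + (off)))

int main(void)
{
    size_t node_sbgp_size = 4, leader_sbgp_size = 2; /* illustrative     */
    size_t total_count = 1024, dt_size = 8;          /* 1024 doubles     */
    int    is_contig = 0;                            /* noncontig dst    */

    size_t node_counts_size   = node_sbgp_size   * sizeof(uint64_t);
    size_t node_disps_size    = node_sbgp_size   * sizeof(int64_t);
    size_t leader_counts_size = leader_sbgp_size * sizeof(uint64_t);
    size_t leader_disps_size  = leader_sbgp_size * sizeof(int64_t);

    size_t scratch_size = node_counts_size + node_disps_size
                        + leader_counts_size + leader_disps_size;
    /* data region only exists when we must pack into contiguous space */
    scratch_size += is_contig ? 0 : total_count * dt_size;

    void *scratch       = malloc(scratch_size);
    void *node_counts   = scratch;
    void *node_disps    = PTR_OFFSET(node_counts, node_counts_size);
    void *leader_counts = PTR_OFFSET(node_disps, node_disps_size);
    void *leader_disps  = PTR_OFFSET(leader_counts, leader_counts_size);
    void *buffer        = PTR_OFFSET(leader_disps, leader_disps_size);

    printf("scratch=%zu bytes, data region at offset %zu\n", scratch_size,
           (size_t)((char *)buffer - (char *)scratch));
    free(scratch);
    return 0;
}

Keeping everything in one allocation means a single free on schedule teardown; in the contiguous case the same pointer math simply stops at leader_disps, and buffer aliases the user's dst instead.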
@@ -146,14 +150,64 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
node_disps = PTR_OFFSET(node_counts, node_counts_size);
leader_counts = PTR_OFFSET(node_disps, node_disps_size);
leader_disps = PTR_OFFSET(leader_counts, leader_counts_size);
- node_gathered_data = PTR_OFFSET(leader_disps, leader_disps_size);
+ if (is_contig) {
+     buffer = args.args.dst.info_v.buffer;
+ } else {
+     buffer = PTR_OFFSET(leader_disps, leader_disps_size);
+ }
+ node_gathered_data = NULL;

+ /* If I'm a node leader, calculate leader_counts, leader_disps, and set the
+    dst buffer of the gatherv to the right displacements for the in-place
+    node-leader allgatherv */
+ if (SBGP_ENABLED(cl_team, NODE) && SBGP_ENABLED(cl_team, NODE_LEADERS)) {
+     /* Sum up the counts on each node to get the count for each node leader */
+     for (i = 0; i < team_size; i++) {
+         ucc_rank_t leader_team_rank = find_leader_rank(team, i);
+         ucc_rank_t leader_sbgp_rank = ucc_ep_map_local_rank(
+                                           SBGP_MAP(cl_team, NODE_LEADERS),
+                                           leader_team_rank);
+         size_t leader_old_count = ucc_coll_args_get_count(
+                                       &args.args, leader_counts,
+                                       leader_sbgp_rank);
+         size_t add_count = ucc_coll_args_get_count(
+                                &args.args,
+                                args.args.dst.info_v.counts, i);
+         size_t new_count = add_count + leader_old_count;
+         ucc_coll_args_set_count(&args.args, leader_counts,
+                                 leader_sbgp_rank, new_count);
+     }
+
+     /* Calculate leader_disps by adding each count to disp_counter to make
+        a contiguous chunk */
+     disp_counter = 0;
+     for (i = 0; i < leader_sbgp_size; i++) {
+         //NOLINTNEXTLINE
+         ucc_rank_t leader_sgbp_rank = ucc_ep_map_local_rank(
+                                           SBGP_MAP(cl_team, NODE_LEADERS),
+                                           cl_team->leader_list[i]); //NOLINT
+         ucc_coll_args_set_displacement(&args.args, leader_disps,
+                                        leader_sgbp_rank, disp_counter);
+         disp_counter += ucc_coll_args_get_count(&args.args,
+                                                 leader_counts,
+                                                 leader_sgbp_rank);
+     }
+
+     node_gathered_data = PTR_OFFSET(buffer,
+                                     dt_size *
+                                     ucc_coll_args_get_displacement(
+                                         &args.args,
+                                         leader_disps,
+                                         SBGP_RANK(cl_team, NODE_LEADERS)));
+ }

if (SBGP_ENABLED(cl_team, NODE)) {
ucc_assert(n_tasks == 0);
if (cl_team->top_sbgp == UCC_HIER_SBGP_NODE) {
args.args.coll_type = UCC_COLL_TYPE_ALLGATHERV;
} else {
- size_t disp_counter = 0;
+ disp_counter = 0;
for (i = 0; i < node_sbgp_size; i++) {
ucc_rank_t team_rank =
ucc_ep_map_eval(SBGP_MAP(cl_team, NODE), i);
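The two loops hoisted into the block above convert per-rank counts into per-leader totals, then lay the leaders out back to back with a running displacement. A standalone sketch under simplifying assumptions: a plain node_of[] lookup replaces find_leader_rank/ucc_ep_map_local_rank, and bare size_t arrays replace the ucc_coll_args_* accessors:

#include <stdio.h>
#include <stddef.h>

int main(void)
{
    enum { TEAM = 6, LEADERS = 2 };
    size_t counts[TEAM]  = {3, 5, 2, 4, 1, 6}; /* per-rank recv counts */
    int    node_of[TEAM] = {0, 0, 0, 1, 1, 1}; /* rank -> leader index */
    size_t leader_counts[LEADERS] = {0};
    size_t leader_disps[LEADERS];

    /* pass 1: sum each node's member counts into its leader's count */
    for (int i = 0; i < TEAM; i++)
        leader_counts[node_of[i]] += counts[i];

    /* pass 2: prefix sum so each node owns one contiguous chunk */
    size_t disp_counter = 0;
    for (int l = 0; l < LEADERS; l++) {
        leader_disps[l] = disp_counter;
        disp_counter   += leader_counts[l];
    }

    for (int l = 0; l < LEADERS; l++)
        printf("leader %d: count=%zu disp=%zu\n", l, leader_counts[l],
               leader_disps[l]);
    return 0; /* prints count=10 disp=0 and count=11 disp=10 */
}

Because each node's chunk is contiguous in this layout, the leader-level exchange can run as a plain in-place allgatherv over the packed buffer.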
@@ -200,49 +254,12 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,

if (SBGP_ENABLED(cl_team, NODE_LEADERS)) {
ucc_assert(cl_team->top_sbgp == UCC_HIER_SBGP_NODE_LEADERS);
- size_t disp_counter = 0;
-
- /* Sum up the counts on each node to get the count for each node leader */
- for (i = 0; i < team_size; i++) {
-     ucc_rank_t leader_team_rank = find_leader_rank(team, i);
-     ucc_rank_t leader_sbgp_rank = ucc_ep_map_local_rank(
-                                       SBGP_MAP(cl_team, NODE_LEADERS),
-                                       leader_team_rank);
-     size_t leader_old_count = ucc_coll_args_get_count(
-                                   &args.args, leader_counts,
-                                   leader_sbgp_rank);
-     size_t add_count = ucc_coll_args_get_count(
-                            &args.args,
-                            args.args.dst.info_v.counts, i);
-     size_t new_count = add_count + leader_old_count;
-     ucc_coll_args_set_count(&args.args, leader_counts,
-                             leader_sbgp_rank, new_count);
- }
-
- for (i = 0; i < leader_sbgp_size; i++) {
-     //NOLINTNEXTLINE
-     ucc_rank_t leader_sgbp_rank = ucc_ep_map_local_rank(
-                                       SBGP_MAP(cl_team, NODE_LEADERS),
-                                       cl_team->leader_list[i]); //NOLINT
-     ucc_coll_args_set_displacement(&args.args, leader_disps,
-                                    leader_sgbp_rank, disp_counter);
-     disp_counter += ucc_coll_args_get_count(&args.args,
-                                             leader_counts,
-                                             leader_sgbp_rank);
- }
args.args.coll_type = UCC_COLL_TYPE_ALLGATHERV;
- args.args.flags &= ~UCC_COLL_ARGS_FLAG_IN_PLACE;
+ args.args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
+ args.args.flags |= UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER;
- args.args.src.info.buffer = node_gathered_data;
- args.args.src.info.count = ucc_coll_args_get_total_count(
-     &args.args,
-     node_counts,
-     node_sbgp_size);
- args.args.src.info.datatype = args.args.dst.info_v.datatype;
- args.args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
+ args.args.dst.info_v.buffer = buffer;
args.args.dst.info_v.displacements = leader_disps;
args.args.dst.info_v.counts = leader_counts;
- args.args.dst.info_v.buffer = args_old.args.dst.info_v.buffer;
UCC_CHECK_GOTO(ucc_coll_init(SCORE_MAP(cl_team, NODE_LEADERS), &args,
&tasks[n_tasks]),
free_scratch, status);
@@ -253,8 +270,9 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
cl_team->top_sbgp != UCC_HIER_SBGP_NODE) {
args = args_old;
args.args.coll_type = UCC_COLL_TYPE_BCAST;
+ args.args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
args.args.root = 0;
- args.args.src.info.buffer = args_old.args.dst.info_v.buffer;
+ args.args.src.info.buffer = buffer;
args.args.src.info.count = total_count;
args.args.src.info.datatype = args_old.args.dst.info_v.datatype;
args.args.src.info.mem_type = args_old.args.dst.info_v.mem_type;
@@ -267,7 +285,9 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
n_tasks++;

if (!is_contig) {
- args = args_old;
+ args = args_old;
+ args.args.src.info_v = args.args.dst.info_v;
+ args.args.src.info_v.buffer = buffer;
UCC_CHECK_GOTO(
ucc_cl_hier_allgatherv_unpack_init(&args, team, &tasks[n_tasks]),
free_scratch, status);
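A note on node_gathered_data: with UCC_COLL_ARGS_FLAG_IN_PLACE set on the leader-level allgatherv, each leader's own contribution is expected to already sit in the destination at its displacement, so the node-level gatherv is pointed directly at that slot. A toy illustration, with memcpy standing in for the node-level gatherv and the counts/displacements hard-coded from the sketch above:

#include <stdio.h>
#include <string.h>

int main(void)
{
    double dst[21];                    /* contiguous leader-level buffer */
    size_t leader_disps[2]  = {0, 10}; /* from the prefix-sum loop       */
    size_t leader_counts[2] = {10, 11};
    int    my_leader = 1;

    /* the node-level gatherv lands this node's data in our own slot */
    double *node_gathered_data = dst + leader_disps[my_leader];
    double  node_data[11] = {0};
    memcpy(node_gathered_data, node_data,
           leader_counts[my_leader] * sizeof(double));

    printf("leader %d owns dst[%zu..%zu)\n", my_leader,
           leader_disps[my_leader],
           leader_disps[my_leader] + leader_counts[my_leader]);
    return 0;
}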
39 changes: 23 additions & 16 deletions src/components/cl/hier/allgatherv/unpack.c
@@ -65,42 +65,49 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_start(ucc_coll_task_t *task)
ucc_ee_executor_task_t **tasks = PTR_OFFSET(
cl_schedule->scratch->addr,
sizeof(ucc_rank_t));
- size_t dt_size = ucc_dt_size(
+ size_t src_dt_size = ucc_dt_size(
+                          args->src.info_v.datatype);
+ size_t dst_dt_size = ucc_dt_size(
args->dst.info_v.datatype);
ucc_ee_executor_t *exec;
ucc_status_t status;
int i;
- size_t disp_counter;
- size_t this_rank_count;
+ size_t src_rank_count;
+ size_t dst_rank_count;
+ size_t src_rank_disp;
+ size_t dst_rank_disp;

UCC_CHECK_GOTO(
ucc_coll_task_get_executor(&schedule->super, &exec),
out, status);
eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;

- disp_counter = ucc_coll_args_get_total_count(args,
-                                              args->dst.info_v.counts,
-                                              team_size);
*n_tasks = 0;

- for (i = team_size - 1; i >= 0; i--) {
-     this_rank_count = ucc_coll_args_get_count(args, args->dst.info_v.counts,
-                                               i);
-     disp_counter -= this_rank_count;
+ src_rank_disp = 0;
+
+ for (i = 0; i < team_size; i++) {
+     src_rank_count = ucc_coll_args_get_count(args, args->src.info_v.counts,
+                                              i);
+     dst_rank_count = ucc_coll_args_get_count(args, args->dst.info_v.counts,
+                                              i);
+     dst_rank_disp = ucc_coll_args_get_displacement(
+         args, args->dst.info_v.displacements, i);
+     ucc_assert(src_rank_count * src_dt_size ==
+                dst_rank_count * dst_dt_size);
eargs.copy.src = PTR_OFFSET(
- args->dst.info_v.buffer, disp_counter * dt_size);
+ args->src.info_v.buffer,
+ src_rank_disp * src_dt_size);
eargs.copy.dst = PTR_OFFSET(
args->dst.info_v.buffer,
- ucc_coll_args_get_displacement(
-     args, args->dst.info_v.displacements, i) *
-     dt_size);
- eargs.copy.len = this_rank_count * dt_size;
+ dst_rank_disp * dst_dt_size);
+ eargs.copy.len = dst_rank_count * dst_dt_size;
if (eargs.copy.src != eargs.copy.dst) {
UCC_CHECK_GOTO(
ucc_ee_executor_task_post(exec, &eargs, &tasks[*n_tasks]),
out, status);
(*n_tasks)++;
}
+ src_rank_disp += src_rank_count;
}

schedule->super.status = UCC_INPROGRESS;
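The rewritten loop walks the packed source with a running src_rank_disp and scatters each rank's block out to the user's displacement, skipping copies whose source and destination already coincide; carrying separate src and dst datatype sizes keeps the byte math correct when the two datatypes differ. A standalone sketch with memcpy in place of the executor copy tasks and byte-sized elements, so both datatype sizes are 1:

#include <stdio.h>
#include <string.h>

int main(void)
{
    enum { TEAM = 3 };
    char   src[]        = "AAABBCCCC";    /* packed, rank-major data */
    char   dst[16];
    size_t counts[TEAM] = {3, 2, 4};
    size_t disps[TEAM]  = {6, 0, 10};     /* user's noncontig layout */
    size_t src_rank_disp = 0;

    memset(dst, '.', sizeof(dst));
    for (int i = 0; i < TEAM; i++) {
        /* skip the executor-style no-op copy when src == dst */
        if (src + src_rank_disp != dst + disps[i])
            memcpy(dst + disps[i], src + src_rank_disp, counts[i]);
        src_rank_disp += counts[i];
    }
    printf("%.*s\n", (int)sizeof(dst), dst); /* BB....AAA.CCCC.. */
    return 0;
}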
