Skip to content

Commit

Permalink
fixes, passing
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarkauskas committed Nov 26, 2024
1 parent 59073fa commit df6394a
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 11 deletions.
77 changes: 67 additions & 10 deletions src/components/cl/hier/allgatherv/allgatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "allgatherv.h"
#include "../cl_hier_coll.h"
#include "core/ucc_team.h"

#define MAX_ALLGATHERV_TASKS 3

Expand All @@ -30,7 +31,9 @@ static ucc_status_t ucc_cl_hier_allgatherv_finalize(ucc_coll_task_t *task)
ucc_cl_hier_schedule_t *cl_schedule =
ucc_derived_of(task, ucc_cl_hier_schedule_t);
ucc_status_t status;

/*
// manual validation
ucc_cl_hier_team_t *cl_team = ucc_derived_of(task->team, ucc_cl_hier_team_t);
ucc_rank_t team_size = UCC_CL_TEAM_SIZE(cl_team);
ucc_rank_t rank = UCC_CL_TEAM_RANK(cl_team);
Expand All @@ -45,7 +48,10 @@ static ucc_status_t ucc_cl_hier_allgatherv_finalize(ucc_coll_task_t *task)
printf("%d ", ((char*)task->bargs.args.dst.info_v.buffer)[i]);
}
printf("\n");
// end manual validation
*/


ucc_mc_free(cl_schedule->scratch);

UCC_CL_HIER_PROFILE_REQUEST_EVENT(task, "cl_hier_allgatherv_finalize",
Expand All @@ -55,21 +61,50 @@ static ucc_status_t ucc_cl_hier_allgatherv_finalize(ucc_coll_task_t *task)
return status;
}

static inline int is_leader(ucc_base_team_t *team, ucc_rank_t rank)
{
ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
ucc_rank_t leader_sbgp_size = SBGP_SIZE(cl_team, NODE_LEADERS);
ucc_rank_t i;
for (i = 0; i < leader_sbgp_size; i++) {
if (ucc_ep_map_eval(SBGP_MAP(cl_team, NODE_LEADERS), i) == rank) {
return 1;
}
}
return 0;
}

static inline ucc_rank_t find_leader_rank(ucc_base_team_t *team, ucc_rank_t team_rank)
{
ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
ucc_team_t *core_team = team->params.team;
ucc_rank_t i;

for (i = 0; i < UCC_CL_TEAM_SIZE(cl_team); i++) {
if (ucc_team_ranks_on_same_node(i, team_rank, core_team) && is_leader(team, i)) {
return i;
}
}

return UCC_RANK_INVALID;
}

UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
(coll_args, team, task),
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
ucc_coll_task_t **task)
{
ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
ucc_coll_task_t *tasks[MAX_ALLGATHERV_TASKS] = {NULL};
//ucc_rank_t rank = UCC_CL_TEAM_RANK(cl_team);
ucc_rank_t rank = UCC_CL_TEAM_RANK(cl_team);
ucc_rank_t node_sbgp_size = SBGP_SIZE(cl_team, NODE);
ucc_rank_t leader_sbgp_size = SBGP_SIZE(cl_team, NODE_LEADERS);
ucc_rank_t team_size = UCC_CL_TEAM_SIZE(cl_team);
ucc_aint_t *node_disps = NULL;
ucc_count_t *node_counts = NULL;
ucc_aint_t *leader_disps = NULL;
ucc_count_t *leader_counts = NULL;
int in_place = 0;
ucc_schedule_t *schedule;
ucc_cl_hier_schedule_t *cl_schedule;
ucc_status_t status;
Expand Down Expand Up @@ -98,7 +133,7 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,

memcpy(&args, coll_args, sizeof(args));
memcpy(&args_old, coll_args, sizeof(args));
args.args.root = 0;
in_place = UCC_IS_INPLACE(args.args);
n_tasks = 0;
UCC_CHECK_GOTO(ucc_schedule_init(schedule, &args, team), out, status);

Expand Down Expand Up @@ -132,7 +167,16 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
disp_counter += ucc_coll_args_get_count(&args.args, node_counts, i);
}

if (in_place) {
args.args.src.info.buffer = PTR_OFFSET(args.args.dst.info_v.buffer, ucc_coll_args_get_displacement(&args.args, args.args.dst.info_v.displacements, rank));
args.args.src.info.count = ucc_coll_args_get_count(&args.args, args.args.dst.info_v.counts, rank);
args.args.src.info.datatype = args.args.dst.info_v.datatype;
args.args.src.info.mem_type = args.args.dst.info_v.mem_type;
}

args.args.coll_type = UCC_COLL_TYPE_GATHERV;
args.args.root = 0;
args.args.flags &= ~UCC_COLL_ARGS_FLAG_IN_PLACE;
args.args.dst.info_v.displacements = node_disps;
args.args.dst.info_v.counts = node_counts;
args.args.dst.info_v.buffer = node_gathered_data;
Expand All @@ -151,25 +195,38 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,

// Sum up the counts on each node to get the count for each node leader
for (i = 0; i < team_size; i++) {
ucc_rank_t leader_team_rank =
ucc_ep_map_local_rank(SBGP_MAP(cl_team, NODE_LEADERS), i);
size_t leader_old_count = ucc_coll_args_get_count(&args.args, leader_counts, leader_team_rank);
ucc_rank_t leader_team_rank = find_leader_rank(team, i);
//printf("team rank %d mapping to leader %d\n", i, leader_team_rank);
size_t leader_old_count = ucc_coll_args_get_count(&args.args, leader_counts, ucc_ep_map_local_rank(SBGP_MAP(cl_team, NODE_LEADERS), leader_team_rank));
size_t add_count = ucc_coll_args_get_count(&args.args, args.args.dst.info_v.counts, i);
size_t new_count = add_count + leader_old_count;
ucc_coll_args_set_count(&args.args, leader_counts, leader_team_rank, new_count);
//printf("set leader count for leader %d (team rank %d) to %ld. map stride=%ld\n", leader_team_rank, i, new_count, SBGP_MAP(cl_team, NODE_LEADERS).strided.stride);
ucc_coll_args_set_count(&args.args, leader_counts, ucc_ep_map_local_rank(SBGP_MAP(cl_team, NODE_LEADERS), leader_team_rank), new_count);
//printf("set leader count (ptr=%p) for leader %d (team rank %d) to %ld. map stride=%ld. leader sbgp size=%d. confirm read value=%ld\n",
//leader_counts, leader_team_rank, i, new_count, SBGP_MAP(cl_team, NODE_LEADERS).strided.stride, leader_sbgp_size, ucc_coll_args_get_count(&args.args, leader_counts, ucc_ep_map_local_rank(SBGP_MAP(cl_team, NODE_LEADERS), leader_team_rank)));
}
for (i = 0; i < leader_sbgp_size; i++) {
ucc_coll_args_set_displacement(&args.args, leader_disps, i, disp_counter);
disp_counter += ucc_coll_args_get_count(&args.args, leader_counts, i);

// Need to order leader displacements by their team rank, not their leader sbgp rank
// The reason is leaders are not always in the same order as they are in the team
// e.g., 2n2ppn
// team ranks = 0 1 2 3 with 0 and 2 as leaders
// leader sbgp ranks can be 2 0 wrt their team ranks
for (i = 0; i < team_size; i++) {
if (is_leader(team, i)) {
ucc_rank_t leader_sgbp_rank = ucc_ep_map_local_rank(SBGP_MAP(cl_team, NODE_LEADERS), i);
ucc_coll_args_set_displacement(&args.args, leader_disps, leader_sgbp_rank, disp_counter);
//printf("set disp for leader %d to %ld\n", i, disp_counter);
disp_counter += ucc_coll_args_get_count(&args.args, leader_counts, leader_sgbp_rank);
}
}
args.args.coll_type = UCC_COLL_TYPE_ALLGATHERV;
args.args.flags &= ~UCC_COLL_ARGS_FLAG_IN_PLACE;
args.args.src.info.buffer = node_gathered_data;
args.args.src.info.count = ucc_coll_args_get_total_count(&args.args, node_counts, node_sbgp_size);
args.args.src.info.datatype = args.args.dst.info_v.datatype;
args.args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
args.args.dst.info_v.displacements = leader_disps;
args.args.dst.info_v.counts = leader_counts;
args.args.dst.info_v.buffer = args_old.args.dst.info_v.buffer;
//printf("nick: rank %d doing node leader allgatherv. node sbgp preserves_order=%d. node rank=%d, leader rank=%d\n", rank, cl_team->sbgps[UCC_HIER_SBGP_NODE].sbgp->preserves_order, SBGP_RANK(cl_team, NODE), SBGP_RANK(cl_team, NODE_LEADERS));
UCC_CHECK_GOTO(ucc_coll_init(SCORE_MAP(cl_team, NODE_LEADERS), &args,
&tasks[n_tasks]),
Expand Down
2 changes: 1 addition & 1 deletion src/utils/ucc_coll_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ ucc_coll_args_set_count(const ucc_coll_args_t *args, const ucc_count_t *counts,
if (UCC_COLL_ARGS_COUNT64(args)) {
((uint64_t *)counts)[idx] = (uint64_t)val;
} else {
((uint32_t *)counts)[idx] = (uint64_t)val;
((uint32_t *)counts)[idx] = (uint32_t)val;
}
}

Expand Down

0 comments on commit df6394a

Please sign in to comment.