From 15bef23e4c997e8c036da0c2946d1198c11db146 Mon Sep 17 00:00:00 2001 From: nsarkauskas Date: Wed, 4 Dec 2024 16:54:50 -0800 Subject: [PATCH] pr feedback --- src/components/cl/hier/Makefile.am | 68 +++++++------- .../cl/hier/allgatherv/allgatherv.c | 94 +++++++++---------- src/components/cl/hier/allgatherv/unpack.c | 2 +- src/components/cl/hier/cl_hier.h | 9 ++ src/components/cl/hier/cl_hier_team.c | 7 ++ test/gtest/coll/test_allgatherv.cc | 2 +- 6 files changed, 95 insertions(+), 87 deletions(-) diff --git a/src/components/cl/hier/Makefile.am b/src/components/cl/hier/Makefile.am index ed86a4e395..705dd08b78 100644 --- a/src/components/cl/hier/Makefile.am +++ b/src/components/cl/hier/Makefile.am @@ -3,54 +3,54 @@ # allgatherv = \ - allgatherv/unpack.h \ - allgatherv/unpack.c \ - allgatherv/allgatherv.h \ - allgatherv/allgatherv.c + allgatherv/unpack.h \ + allgatherv/unpack.c \ + allgatherv/allgatherv.h \ + allgatherv/allgatherv.c allreduce = \ - allreduce/allreduce.h \ - allreduce/allreduce.c \ - allreduce/allreduce_rab.c \ - allreduce/allreduce_split_rail.c + allreduce/allreduce.h \ + allreduce/allreduce.c \ + allreduce/allreduce_rab.c \ + allreduce/allreduce_split_rail.c alltoallv = \ - alltoallv/alltoallv.h \ - alltoallv/alltoallv.c + alltoallv/alltoallv.h \ + alltoallv/alltoallv.c alltoall = \ - alltoall/alltoall.h \ - alltoall/alltoall.c + alltoall/alltoall.h \ + alltoall/alltoall.c barrier = \ - barrier/barrier.h \ - barrier/barrier.c + barrier/barrier.h \ + barrier/barrier.c bcast = \ - bcast/bcast.h \ - bcast/bcast.c \ - bcast/bcast_2step.c + bcast/bcast.h \ + bcast/bcast.c \ + bcast/bcast_2step.c reduce = \ - reduce/reduce.h \ - reduce/reduce.c \ - reduce/reduce_2step.c + reduce/reduce.h \ + reduce/reduce.c \ + reduce/reduce_2step.c sources = \ - cl_hier.h \ - cl_hier.c \ - cl_hier_lib.c \ - cl_hier_context.c \ - cl_hier_team.c \ - cl_hier_coll.c \ - cl_hier_coll.h \ - $(allgatherv) \ - $(allreduce) \ - $(alltoallv) \ - $(alltoall) \ - $(barrier) \ - 
$(bcast) \ - $(reduce) + cl_hier.h \ + cl_hier.c \ + cl_hier_lib.c \ + cl_hier_context.c \ + cl_hier_team.c \ + cl_hier_coll.c \ + cl_hier_coll.h \ + $(allgatherv) \ + $(allreduce) \ + $(alltoallv) \ + $(alltoall) \ + $(barrier) \ + $(bcast) \ + $(reduce) module_LTLIBRARIES = libucc_cl_hier.la libucc_cl_hier_la_SOURCES = $(sources) diff --git a/src/components/cl/hier/allgatherv/allgatherv.c b/src/components/cl/hier/allgatherv/allgatherv.c index 0df210218d..81c47dca63 100755 --- a/src/components/cl/hier/allgatherv/allgatherv.c +++ b/src/components/cl/hier/allgatherv/allgatherv.c @@ -41,34 +41,34 @@ static ucc_status_t ucc_cl_hier_allgatherv_finalize(ucc_coll_task_t *task) return status; } -static inline int is_leader(ucc_base_team_t *team, ucc_rank_t rank) +static inline ucc_rank_t find_leader_rank(ucc_base_team_t *team, + ucc_rank_t team_rank) { ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t); + ucc_team_t *core_team = team->params.team; + ucc_rank_t team_size = UCC_CL_TEAM_SIZE(cl_team); ucc_rank_t ldr_sbgp_size = SBGP_SIZE(cl_team, NODE_LEADERS); - ucc_rank_t i; - for (i = 0; i < ldr_sbgp_size; i++) { - if (ucc_ep_map_eval(SBGP_MAP(cl_team, NODE_LEADERS), i) == rank) { - return 1; - } - } - return 0; -} + ucc_rank_t i, j; -/* TODO: is there a better way to do this? 
*/ -static inline ucc_rank_t find_leader_rank(ucc_base_team_t *team, ucc_rank_t team_rank) -{ - ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t); - ucc_team_t *core_team = team->params.team; - ucc_rank_t i; - - for (i = 0; i < UCC_CL_TEAM_SIZE(cl_team); i++) { - if (ucc_team_ranks_on_same_node(i, team_rank, core_team) && - is_leader(team, i)) { - return i; + /* Allocate and populate node_leaders and leader_list */ + if (ucc_unlikely(cl_team->node_leaders == NULL)) { + cl_team->node_leaders = ucc_malloc(sizeof(ucc_rank_t) * team_size); + cl_team->leader_list = ucc_malloc(sizeof(ucc_rank_t) * ldr_sbgp_size); + for (i = 0; i < team_size; i++) { + for (j = 0; j < ldr_sbgp_size; j++) { + ucc_rank_t ldr_team_rank = ucc_ep_map_eval( + SBGP_MAP(cl_team, NODE_LEADERS), j); + if (ucc_team_ranks_on_same_node(i, ldr_team_rank, core_team)) { + cl_team->node_leaders[i] = ldr_team_rank; + cl_team->leader_list[j] = ldr_team_rank; + break; + } + } } } - return UCC_RANK_INVALID; + /* Return team_rank's node leader */ + return cl_team->node_leaders[team_rank]; } UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init, @@ -195,42 +195,33 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init, /* Sum up the counts on each node to get the count for each node leader */ for (i = 0; i < team_size; i++) { ucc_rank_t leader_team_rank = find_leader_rank(team, i); - size_t leader_old_count = - ucc_coll_args_get_count(&args.args, leader_counts, - ucc_ep_map_local_rank( + ucc_rank_t leader_sbgp_rank = ucc_ep_map_local_rank( SBGP_MAP(cl_team, NODE_LEADERS), - leader_team_rank)); - size_t add_count = - ucc_coll_args_get_count(&args.args, - args.args.dst.info_v.counts, i); + leader_team_rank); + size_t leader_old_count = ucc_coll_args_get_count( + &args.args, leader_counts, + leader_sbgp_rank); + size_t add_count = ucc_coll_args_get_count( + &args.args, + args.args.dst.info_v.counts, i); size_t new_count = add_count + leader_old_count; 
ucc_coll_args_set_count(&args.args, leader_counts, - ucc_ep_map_local_rank( - SBGP_MAP(cl_team, NODE_LEADERS), - leader_team_rank), - new_count); + leader_sbgp_rank, new_count); } - /* - Need to order leader displacements by their team rank, not their leader sbgp rank. - The reason is leaders are not always in the same order as they are in the team - e.g., 2n2ppn - team ranks = 0 1 2 3 with 0 and 2 as leaders - leader sbgp ranks can be 2 0 wrt their team ranks - */ - for (i = 0; i < team_size; i++) { - if (is_leader(team, i)) { - ucc_rank_t leader_sgbp_rank = - ucc_ep_map_local_rank(SBGP_MAP(cl_team, NODE_LEADERS), i); - ucc_coll_args_set_displacement(&args.args, leader_disps, - leader_sgbp_rank, disp_counter); - disp_counter += ucc_coll_args_get_count(&args.args, - leader_counts, - leader_sgbp_rank); - } + for (i = 0; i < leader_sbgp_size; i++) { + ucc_rank_t leader_sgbp_rank = ucc_ep_map_local_rank( + SBGP_MAP(cl_team, NODE_LEADERS), + cl_team->leader_list[i]); + ucc_coll_args_set_displacement(&args.args, leader_disps, + leader_sgbp_rank, disp_counter); + disp_counter += ucc_coll_args_get_count(&args.args, + leader_counts, + leader_sgbp_rank); } args.args.coll_type = UCC_COLL_TYPE_ALLGATHERV; args.args.flags &= ~UCC_COLL_ARGS_FLAG_IN_PLACE; + args.args.flags |= UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER; args.args.src.info.buffer = node_gathered_data; args.args.src.info.count = ucc_coll_args_get_total_count( &args.args, @@ -251,13 +242,14 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init, cl_team->top_sbgp != UCC_HIER_SBGP_NODE) { args = args_old; args.args.coll_type = UCC_COLL_TYPE_BCAST; - args.args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - args.args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; args.args.root = 0; args.args.src.info.buffer = args_old.args.dst.info_v.buffer; args.args.src.info.count = total_count; args.args.src.info.datatype = args_old.args.dst.info_v.datatype; args.args.src.info.mem_type = args_old.args.dst.info_v.mem_type; + + /* If using 
tl_shm and the shmem segment size is less than total_count, + this node-level bcast will cause the allgatherv to fail and fall back */ UCC_CHECK_GOTO( ucc_coll_init(SCORE_MAP(cl_team, NODE), &args, &tasks[n_tasks]), free_scratch, status); diff --git a/src/components/cl/hier/allgatherv/unpack.c b/src/components/cl/hier/allgatherv/unpack.c index 3281cc3314..5ef5bb2028 100644 --- a/src/components/cl/hier/allgatherv/unpack.c +++ b/src/components/cl/hier/allgatherv/unpack.c @@ -134,7 +134,7 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_init(ucc_base_coll_args_t *coll_args, ucc_mc_alloc(&cl_schedule->scratch, scratch_size, UCC_MEMORY_TYPE_HOST), free_schedule, status); - schedule->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; + schedule->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; schedule->super.post = ucc_cl_hier_allgatherv_unpack_start; schedule->super.progress = ucc_cl_hier_allgatherv_unpack_progress; schedule->super.finalize = ucc_cl_hier_allgatherv_unpack_finalize; diff --git a/src/components/cl/hier/cl_hier.h b/src/components/cl/hier/cl_hier.h index 489ec41578..21c6bbd228 100644 --- a/src/components/cl/hier/cl_hier.h +++ b/src/components/cl/hier/cl_hier.h @@ -106,6 +106,15 @@ typedef struct ucc_cl_hier_team { ucc_coll_score_t *score; ucc_hier_sbgp_t sbgps[UCC_HIER_SBGP_LAST]; ucc_hier_sbgp_type_t top_sbgp; + /* Array of size team_size, where node_leaders[i] = the rank of i's node + leader */ + ucc_rank_t *node_leaders; + /* Array of size node_leader_sbgp_size, with ranks in terms of the + team, sorted lowest to highest (NOTE(review): find_leader_rank() stores leader_list[j] at leader sbgp index j and has no sort step -- confirm the ordering). This is useful for allgatherv. + The reason is that iterating through the node leader sbgp and map eval'ing + the ranks can yield unsorted ranks, e.g. 
2n2ppn with ranks 0 and 2 as + leaders, leader 0 could map to rank 2 and leader 1 could map to rank 0 */ + ucc_rank_t *leader_list; } ucc_cl_hier_team_t; UCC_CLASS_DECLARE(ucc_cl_hier_team_t, ucc_base_context_t *, const ucc_base_team_params_t *); diff --git a/src/components/cl/hier/cl_hier_team.c b/src/components/cl/hier/cl_hier_team.c index dd671e14f5..dec986ac9f 100644 --- a/src/components/cl/hier/cl_hier_team.c +++ b/src/components/cl/hier/cl_hier_team.c @@ -203,6 +203,13 @@ ucc_status_t ucc_cl_hier_team_destroy(ucc_base_team_t *cl_team) ucc_hier_sbgp_t *hs; struct ucc_team_team_desc *d; + if (team->node_leaders) { + ucc_free(team->node_leaders); + ucc_free(team->leader_list); + team->node_leaders = NULL; + team->leader_list = NULL; + } + if (NULL == team->team_create_req) { status = ucc_team_multiple_req_alloc(&team->team_create_req, team->n_tl_teams); diff --git a/test/gtest/coll/test_allgatherv.cc b/test/gtest/coll/test_allgatherv.cc index 9ecc164f71..4178850442 100644 --- a/test/gtest/coll/test_allgatherv.cc +++ b/test/gtest/coll/test_allgatherv.cc @@ -340,7 +340,7 @@ INSTANTIATE_TEST_CASE_P( ::testing::Values(1,3,8192), // count ::testing::Values(TEST_INPLACE, TEST_NO_INPLACE), ::testing::Values("knomial", "ring"), - ::testing::Values(false, true)), // dst buf contig + ::testing::Bool()), // dst buf contig [](const testing::TestParamInfo& info) { std::string name; name += ucc_datatype_str(std::get<0>(info.param));