diff --git a/src/components/cl/hier/allgatherv/allgatherv.c b/src/components/cl/hier/allgatherv/allgatherv.c index 3377dc6313..2c77812f1b 100755 --- a/src/components/cl/hier/allgatherv/allgatherv.c +++ b/src/components/cl/hier/allgatherv/allgatherv.c @@ -102,6 +102,7 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init, dst.info_v.datatype); int in_place = 0; int is_contig = 1; + size_t disp_counter = 0; ucc_schedule_t *schedule; ucc_cl_hier_schedule_t *cl_schedule; ucc_status_t status; @@ -113,6 +114,7 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init, size_t leader_counts_size; size_t leader_disps_size; size_t total_count; + void *buffer; void *node_gathered_data; schedule = &ucc_cl_hier_get_schedule(cl_team)->super.super; @@ -134,8 +136,10 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init, leader_disps_size = leader_sbgp_size * sizeof(ucc_aint_t); total_count = ucc_coll_args_get_total_count(&args.args, args.args.dst.info_v.counts, team_size); - scratch_size = node_counts_size + node_disps_size + leader_counts_size - + leader_disps_size + (total_count * dt_size); + scratch_size = node_counts_size + node_disps_size + + leader_counts_size + leader_disps_size; + /* If the dst buf isn't contig, allocate and work on a contig scratch buffer */ + scratch_size += (is_contig ? 0 : (total_count * dt_size)); UCC_CHECK_GOTO( ucc_mc_alloc(&cl_schedule->scratch, scratch_size, UCC_MEMORY_TYPE_HOST), @@ -146,14 +150,64 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init, node_disps = PTR_OFFSET(node_counts, node_counts_size); leader_counts = PTR_OFFSET(node_disps, node_disps_size); leader_disps = PTR_OFFSET(leader_counts, leader_counts_size); - node_gathered_data = PTR_OFFSET(leader_disps, leader_disps_size); + if (is_contig) { + buffer = args.args.dst.info_v.buffer; + } else { + buffer = PTR_OFFSET(leader_disps, leader_disps_size); + } + node_gathered_data = NULL; + + /* If I'm a node leader, calculate leader_counts, leader_disps, and set the + dst buffer of the gatherv to the right displacements for the in-place + node-leader allgatherv */ + if(SBGP_ENABLED(cl_team, NODE) && SBGP_ENABLED(cl_team, NODE_LEADERS)) { + /* Sum up the counts on each node to get the count for each node leader */ + for (i = 0; i < team_size; i++) { + ucc_rank_t leader_team_rank = find_leader_rank(team, i); + ucc_rank_t leader_sbgp_rank = ucc_ep_map_local_rank( + SBGP_MAP(cl_team, NODE_LEADERS), + leader_team_rank); + size_t leader_old_count = ucc_coll_args_get_count( + &args.args, leader_counts, + leader_sbgp_rank); + size_t add_count = ucc_coll_args_get_count( + &args.args, + args.args.dst.info_v.counts, i); + size_t new_count = add_count + leader_old_count; + ucc_coll_args_set_count(&args.args, leader_counts, + leader_sbgp_rank, new_count); + } + + /* Calculate leader_disps by adding each count to disp_counter to make + a contiguous chunk */ + disp_counter = 0; + for (i = 0; i < leader_sbgp_size; i++) { + //NOLINTNEXTLINE + ucc_rank_t leader_sgbp_rank = ucc_ep_map_local_rank( + SBGP_MAP(cl_team, NODE_LEADERS), + cl_team->leader_list[i]); //NOLINT + ucc_coll_args_set_displacement(&args.args, leader_disps, + leader_sgbp_rank, disp_counter); + disp_counter += ucc_coll_args_get_count(&args.args, + leader_counts, + leader_sgbp_rank); + } + + node_gathered_data = PTR_OFFSET(buffer, + dt_size * + ucc_coll_args_get_displacement( + &args.args, + leader_disps, + SBGP_RANK(cl_team, NODE_LEADERS)) + ); + } if (SBGP_ENABLED(cl_team, NODE)) { ucc_assert(n_tasks == 0); if (cl_team->top_sbgp == UCC_HIER_SBGP_NODE) { args.args.coll_type = UCC_COLL_TYPE_ALLGATHERV; } else { - size_t disp_counter = 0; + disp_counter = 0; for (i = 0; i < node_sbgp_size; i++) { ucc_rank_t team_rank = ucc_ep_map_eval(SBGP_MAP(cl_team, NODE), i); @@ -200,49 +254,12 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init, if (SBGP_ENABLED(cl_team, NODE_LEADERS)) { ucc_assert(cl_team->top_sbgp == UCC_HIER_SBGP_NODE_LEADERS); - size_t disp_counter = 0; - - /* Sum up the counts on each node to get the count for each node leader */ - for (i = 0; i < team_size; i++) { - ucc_rank_t leader_team_rank = find_leader_rank(team, i); - ucc_rank_t leader_sbgp_rank = ucc_ep_map_local_rank( - SBGP_MAP(cl_team, NODE_LEADERS), - leader_team_rank); - size_t leader_old_count = ucc_coll_args_get_count( - &args.args, leader_counts, - leader_sbgp_rank); - size_t add_count = ucc_coll_args_get_count( - &args.args, - args.args.dst.info_v.counts, i); - size_t new_count = add_count + leader_old_count; - ucc_coll_args_set_count(&args.args, leader_counts, - leader_sbgp_rank, new_count); - } - - for (i = 0; i < leader_sbgp_size; i++) { - //NOLINTNEXTLINE - ucc_rank_t leader_sgbp_rank = ucc_ep_map_local_rank( - SBGP_MAP(cl_team, NODE_LEADERS), - cl_team->leader_list[i]); //NOLINT - ucc_coll_args_set_displacement(&args.args, leader_disps, - leader_sgbp_rank, disp_counter); - disp_counter += ucc_coll_args_get_count(&args.args, - leader_counts, - leader_sgbp_rank); - } args.args.coll_type = UCC_COLL_TYPE_ALLGATHERV; - args.args.flags &= ~UCC_COLL_ARGS_FLAG_IN_PLACE; + args.args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; args.args.flags |= UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER; - args.args.src.info.buffer = node_gathered_data; - args.args.src.info.count = ucc_coll_args_get_total_count( - &args.args, - node_counts, - node_sbgp_size); - args.args.src.info.datatype = args.args.dst.info_v.datatype; - args.args.src.info.mem_type = UCC_MEMORY_TYPE_HOST; + args.args.dst.info_v.buffer = buffer; args.args.dst.info_v.displacements = leader_disps; args.args.dst.info_v.counts = leader_counts; - args.args.dst.info_v.buffer = args_old.args.dst.info_v.buffer; UCC_CHECK_GOTO(ucc_coll_init(SCORE_MAP(cl_team, NODE_LEADERS), &args, &tasks[n_tasks]), free_scratch, status); @@ -253,8 +270,9 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init, cl_team->top_sbgp != UCC_HIER_SBGP_NODE) { args = args_old; args.args.coll_type = UCC_COLL_TYPE_BCAST; + args.args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; args.args.root = 0; - args.args.src.info.buffer = args_old.args.dst.info_v.buffer; + args.args.src.info.buffer = buffer; args.args.src.info.count = total_count; args.args.src.info.datatype = args_old.args.dst.info_v.datatype; args.args.src.info.mem_type = args_old.args.dst.info_v.mem_type; @@ -267,7 +285,9 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init, n_tasks++; if (!is_contig) { - args = args_old; + args = args_old; + args.args.src.info_v = args.args.dst.info_v; + args.args.src.info_v.buffer = buffer; UCC_CHECK_GOTO( ucc_cl_hier_allgatherv_unpack_init(&args, team, &tasks[n_tasks]), free_scratch, status); diff --git a/src/components/cl/hier/allgatherv/unpack.c b/src/components/cl/hier/allgatherv/unpack.c index 490cb0ba3d..50424fb8e0 100644 --- a/src/components/cl/hier/allgatherv/unpack.c +++ b/src/components/cl/hier/allgatherv/unpack.c @@ -65,42 +65,49 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_start(ucc_coll_task_t *task) ucc_ee_executor_task_t **tasks = PTR_OFFSET( cl_schedule->scratch->addr, sizeof(ucc_rank_t)); - size_t dt_size = ucc_dt_size( + size_t src_dt_size = ucc_dt_size( + args->src.info_v.datatype); + size_t dst_dt_size = ucc_dt_size( args->dst.info_v.datatype); ucc_ee_executor_t *exec; ucc_status_t status; int i; - size_t disp_counter; - size_t this_rank_count; + size_t src_rank_count; + size_t dst_rank_count; + size_t src_rank_disp; + size_t dst_rank_disp; UCC_CHECK_GOTO( ucc_coll_task_get_executor(&schedule->super, &exec), out, status); eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; - disp_counter = ucc_coll_args_get_total_count(args, - args->dst.info_v.counts, - team_size); *n_tasks = 0; - - for (i = team_size - 1; i >= 0; i--) { - this_rank_count = ucc_coll_args_get_count(args, args->dst.info_v.counts, - i); - disp_counter -= this_rank_count; + src_rank_disp = 0; + + for (i = 0; i < team_size; i++) { + src_rank_count = ucc_coll_args_get_count(args, args->src.info_v.counts, + i); + dst_rank_count = ucc_coll_args_get_count(args, args->dst.info_v.counts, + i); + dst_rank_disp = ucc_coll_args_get_displacement( + args, args->dst.info_v.displacements, i); + ucc_assert(src_rank_count * src_dt_size == + dst_rank_count * dst_dt_size); eargs.copy.src = PTR_OFFSET( - args->dst.info_v.buffer, disp_counter * dt_size); + args->src.info_v.buffer, + src_rank_disp * src_dt_size); eargs.copy.dst = PTR_OFFSET( args->dst.info_v.buffer, - ucc_coll_args_get_displacement( - args, args->dst.info_v.displacements, i) * - dt_size); - eargs.copy.len = this_rank_count * dt_size; + dst_rank_disp * dst_dt_size); + eargs.copy.len = dst_rank_count * dst_dt_size; if (eargs.copy.src != eargs.copy.dst) { UCC_CHECK_GOTO( ucc_ee_executor_task_post(exec, &eargs, &tasks[*n_tasks]), out, status); (*n_tasks)++; } + src_rank_disp += src_rank_count; } schedule->super.status = UCC_INPROGRESS;