Skip to content

Commit

Permalink
add unpack task for noncontig
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarkauskas committed Nov 27, 2024
1 parent d53e608 commit 142070a
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 66 deletions.
94 changes: 48 additions & 46 deletions src/components/cl/hier/Makefile.am
Original file line number Diff line number Diff line change
@@ -1,54 +1,56 @@
#
# Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#

allgatherv = \
allgatherv/allgatherv.h \
allgatherv/allgatherv.c
allgatherv = \
allgatherv/unpack.h \
allgatherv/unpack.c \
allgatherv/allgatherv.h \
allgatherv/allgatherv.c

allreduce = \
allreduce/allreduce.h \
allreduce/allreduce.c \
allreduce/allreduce_rab.c \
allreduce/allreduce_split_rail.c

alltoallv = \
alltoallv/alltoallv.h \
alltoallv/alltoallv.c

alltoall = \
alltoall/alltoall.h \
alltoall/alltoall.c

barrier = \
barrier/barrier.h \
barrier/barrier.c

bcast = \
bcast/bcast.h \
bcast/bcast.c \
bcast/bcast_2step.c

reduce = \
reduce/reduce.h \
reduce/reduce.c \
reduce/reduce_2step.c

sources = \
cl_hier.h \
cl_hier.c \
cl_hier_lib.c \
cl_hier_context.c \
cl_hier_team.c \
cl_hier_coll.c \
cl_hier_coll.h \
$(allgatherv) \
$(allreduce) \
$(alltoallv) \
$(alltoall) \
$(barrier) \
$(bcast) \
$(reduce)
allreduce/allreduce.h \
allreduce/allreduce.c \
allreduce/allreduce_rab.c \
allreduce/allreduce_split_rail.c

alltoallv = \
alltoallv/alltoallv.h \
alltoallv/alltoallv.c

alltoall = \
alltoall/alltoall.h \
alltoall/alltoall.c

barrier = \
barrier/barrier.h \
barrier/barrier.c

bcast = \
bcast/bcast.h \
bcast/bcast.c \
bcast/bcast_2step.c

reduce = \
reduce/reduce.h \
reduce/reduce.c \
reduce/reduce_2step.c

sources = \
cl_hier.h \
cl_hier.c \
cl_hier_lib.c \
cl_hier_context.c \
cl_hier_team.c \
cl_hier_coll.c \
cl_hier_coll.h \
$(allgatherv) \
$(allreduce) \
$(alltoallv) \
$(alltoall) \
$(barrier) \
$(bcast) \
$(reduce)

module_LTLIBRARIES = libucc_cl_hier.la
libucc_cl_hier_la_SOURCES = $(sources)
Expand Down
28 changes: 8 additions & 20 deletions src/components/cl/hier/allgatherv/allgatherv.c
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
/**
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "allgatherv.h"
#include "unpack.h"
#include "../cl_hier_coll.h"
#include "core/ucc_team.h"

#define MAX_ALLGATHERV_TASKS 3
#define MAX_ALLGATHERV_TASKS 4

ucc_base_coll_alg_info_t
ucc_cl_hier_allgatherv_algs[UCC_CL_HIER_ALLGATHERV_ALG_LAST + 1] = {
Expand Down Expand Up @@ -119,8 +120,6 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
size_t leader_disps_size;
size_t total_count;
void *node_gathered_data;
//ucc_ee_executor_task_args_t eargs = {0};
//ucc_ee_executor_t *exec;

schedule = &ucc_cl_hier_get_schedule(cl_team)->super.super;
if (ucc_unlikely(!schedule)) {
Expand Down Expand Up @@ -230,10 +229,9 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
n_tasks++;
}

args = args_old;

if (SBGP_ENABLED(cl_team, NODE) &&
cl_team->top_sbgp != UCC_HIER_SBGP_NODE) {
args = args_old;
args.args.coll_type = UCC_COLL_TYPE_BCAST;
args.args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
args.args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
Expand All @@ -249,20 +247,9 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
n_tasks++;

if (!is_contig) {
printf("not contig, scheduling unpack operation\n"); /*
UCC_CHECK_GOTO(ucc_coll_task_get_executor(&schedule->super, &exec), out, status);
eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
size_t disp_counter = ucc_coll_args_get_total_count(&args.args, args.args.dst.info_v.counts, team_size);
for (i = team_size - 1; i >= 0; i++) {
size_t this_rank_count = ucc_coll_args_get_count(&args.args, args.args.dst.info_v.counts, i);
disp_counter -= this_rank_count;
eargs.copy.src = PTR_OFFSET(args.args.dst.info_v.buffer, disp_counter * dt_size);
eargs.copy.dst = PTR_OFFSET(args.args.dst.info_v.buffer, ucc_coll_args_get_displacement(&args.args, args.args.dst.info_v.displacements, i) * dt_size);
eargs.copy.len = this_rank_count * dt_size;
ucc_ee_executor_task_t **ee_task_pp = NULL;
UCC_CHECK_GOTO(ucc_ee_executor_task_post(exec, &eargs, &tasks[n_tasks]), out, status);
n_tasks++;
}*/
args = args_old;
UCC_CHECK_GOTO(ucc_cl_hier_allgatherv_unpack_init(&args, team, &tasks[n_tasks]), out, status);
n_tasks++;
}
}

Expand All @@ -279,6 +266,7 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[i]), out, status);
}

schedule->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
schedule->super.post = ucc_cl_hier_allgatherv_start;
schedule->super.finalize = ucc_cl_hier_allgatherv_finalize;
*task = &schedule->super;
Expand Down
124 changes: 124 additions & 0 deletions src/components/cl/hier/allgatherv/unpack.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/**
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "unpack.h"

ucc_status_t ucc_cl_hier_allgatherv_unpack_finalize(ucc_coll_task_t *task)
{
ucc_schedule_t *schedule = ucc_derived_of(task, ucc_schedule_t);
ucc_cl_hier_schedule_t *cl_schedule = ucc_derived_of(schedule, ucc_cl_hier_schedule_t);

ucc_mc_free(cl_schedule->scratch);

return UCC_OK;
}

void ucc_cl_hier_allgatherv_unpack_progress(ucc_coll_task_t *task)
{
ucc_schedule_t *schedule = ucc_derived_of(task, ucc_schedule_t);
ucc_cl_hier_team_t *cl_team = ucc_derived_of(task->team, ucc_cl_hier_team_t);
ucc_rank_t team_size = UCC_CL_TEAM_SIZE(cl_team);
ucc_cl_hier_schedule_t *cl_schedule = ucc_derived_of(schedule, ucc_cl_hier_schedule_t);
ucc_ee_executor_task_t **tasks = cl_schedule->scratch->addr;
ucc_status_t st;
ucc_rank_t i;

for (i = 0; i < team_size; i++) {
ucc_ee_executor_task_t *etask = tasks[i];
if (etask != NULL) {
st = ucc_ee_executor_task_test(etask);
if (st == UCC_OK) {
ucc_ee_executor_task_finalize(etask);
tasks[i] = NULL;
} else {
if (ucc_likely(st > 0)) {
st = UCC_INPROGRESS;
}
goto out;
}
}
}

out:
schedule->super.status = st;
schedule->super.super.status = st;
}

ucc_status_t ucc_cl_hier_allgatherv_unpack_start(ucc_coll_task_t *task)
{
ucc_schedule_t *schedule = ucc_derived_of(task, ucc_schedule_t);
ucc_cl_hier_team_t *cl_team = ucc_derived_of(task->team, ucc_cl_hier_team_t);
ucc_rank_t team_size = UCC_CL_TEAM_SIZE(cl_team);
ucc_coll_args_t *args = &task->bargs.args;
//ucc_rank_t rank = UCC_CL_TEAM_RANK(cl_team);
ucc_ee_executor_task_args_t eargs = {0};
ucc_cl_hier_schedule_t *cl_schedule = ucc_derived_of(schedule, ucc_cl_hier_schedule_t);
ucc_ee_executor_task_t **tasks = cl_schedule->scratch->addr;
ucc_rank_t n_tasks = 0;
size_t dt_size = ucc_dt_size(args->dst.info_v.datatype);
ucc_ee_executor_t *exec;
ucc_status_t status;
int i;

UCC_CHECK_GOTO(ucc_coll_task_get_executor(&schedule->super, &exec), out, status);
eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
size_t disp_counter = ucc_coll_args_get_total_count(args, args->dst.info_v.counts, team_size);

for (i = team_size - 1; i >= 0; i--) {
size_t this_rank_count = ucc_coll_args_get_count(args, args->dst.info_v.counts, i);
disp_counter -= this_rank_count;
eargs.copy.src = PTR_OFFSET(args->dst.info_v.buffer, disp_counter * dt_size);
eargs.copy.dst = PTR_OFFSET(args->dst.info_v.buffer, ucc_coll_args_get_displacement(args, args->dst.info_v.displacements, i) * dt_size);
eargs.copy.len = this_rank_count * dt_size;
UCC_CHECK_GOTO(ucc_ee_executor_task_post(exec, &eargs, &tasks[n_tasks]), out, status);
n_tasks++;
}

schedule->super.status = UCC_INPROGRESS;
schedule->super.super.status = UCC_INPROGRESS;

ucc_progress_queue_enqueue(cl_team->super.super.context->ucc_context->pq, task);

return UCC_OK;
out:
return status;
}

ucc_status_t ucc_cl_hier_allgatherv_unpack_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
ucc_rank_t team_size = UCC_CL_TEAM_SIZE(cl_team);
//ucc_rank_t rank = UCC_CL_TEAM_RANK(cl_team);
ucc_status_t status;
ucc_schedule_t *schedule;
ucc_cl_hier_schedule_t *cl_schedule;
size_t scratch_size;


schedule = &ucc_cl_hier_get_schedule(cl_team)->super.super;
if (ucc_unlikely(!schedule)) {
return UCC_ERR_NO_MEMORY;
}
cl_schedule = ucc_derived_of(schedule, ucc_cl_hier_schedule_t);

UCC_CHECK_GOTO(ucc_schedule_init(schedule, coll_args, team), out, status);

scratch_size = team_size * sizeof(ucc_ee_executor_task_t*);
UCC_CHECK_GOTO(ucc_mc_alloc(&cl_schedule->scratch, scratch_size, UCC_MEMORY_TYPE_HOST), out, status);

schedule->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
schedule->super.post = ucc_cl_hier_allgatherv_unpack_start;
schedule->super.progress = ucc_cl_hier_allgatherv_unpack_progress;
schedule->super.finalize = ucc_cl_hier_allgatherv_unpack_finalize;

*task_h = &schedule->super;
return UCC_OK;
out:
ucc_cl_hier_put_schedule(schedule);
return status;
}
15 changes: 15 additions & 0 deletions src/components/cl/hier/allgatherv/unpack.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/**
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "../cl_hier_coll.h"
#include "core/ucc_team.h"

ucc_status_t ucc_cl_hier_allgatherv_unpack_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);
ucc_status_t ucc_cl_hier_allgatherv_unpack_start(ucc_coll_task_t *task);
void ucc_cl_hier_allgatherv_unpack_progress(ucc_coll_task_t *task);
ucc_status_t ucc_cl_hier_allgatherv_unpack_finalize(ucc_coll_task_t *task);
1 change: 1 addition & 0 deletions src/components/topo/ucc_sbgp.c
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ static ucc_status_t sbgp_create_node_leaders(ucc_topo_t *topo, ucc_sbgp_t *sbgp,
nl_array_3[sbgp_id + host_id * max_ctx_sbgp_size]++;
}

/* Find the first rank that maps to this node, store in nl_array_2 */
if (nl_array_1[host_id] == 0 || nl_array_1[host_id] == ctx_nlr) {
nl_array_2[host_id] = i;
}
Expand Down

0 comments on commit 142070a

Please sign in to comment.