Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/UCP: reduce dbt #888

Merged
merged 3 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions src/coll_patterns/double_binary_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ typedef struct ucc_dbt_single_tree {
ucc_rank_t root;
ucc_rank_t parent;
ucc_rank_t children[2];
int n_children;
int height;
int recv;
} ucc_dbt_single_tree_t;
Expand Down Expand Up @@ -86,6 +87,21 @@ static inline void get_children(ucc_rank_t size, ucc_rank_t rank, int height,
*r_c = get_right_child(size, rank, height, root);
}

/**
 * Count how many of the two child slots are populated.
 *
 * @param l_c  left child rank, or UCC_RANK_INVALID when absent
 * @param r_c  right child rank, or UCC_RANK_INVALID when absent
 * @return     0, 1 or 2
 */
static inline int get_n_children(ucc_rank_t l_c, ucc_rank_t r_c)
{
    /* each flag is 1 when the corresponding child exists, 0 otherwise */
    int has_left  = (l_c != UCC_RANK_INVALID) ? 1 : 0;
    int has_right = (r_c != UCC_RANK_INVALID) ? 1 : 0;

    return has_left + has_right;
}

static inline ucc_rank_t get_parent(int vsize, int vrank, int height, int troot)
{
if (vrank == troot) {
Expand Down Expand Up @@ -121,6 +137,8 @@ static inline void ucc_dbt_build_t2_mirror(ucc_dbt_single_tree_t t1,
t.children[RIGHT_CHILD] = (t1.children[LEFT_CHILD] == UCC_RANK_INVALID) ?
UCC_RANK_INVALID :
size - 1 - t1.children[LEFT_CHILD];
t.n_children = get_n_children(t.children[LEFT_CHILD],
t.children[RIGHT_CHILD]);
t.recv = 0;

*t2 = t;
Expand All @@ -144,6 +162,8 @@ static inline void ucc_dbt_build_t2_shift(ucc_dbt_single_tree_t t1,
t.children[RIGHT_CHILD] = (t1.children[RIGHT_CHILD] == UCC_RANK_INVALID) ?
UCC_RANK_INVALID :
(t1.children[RIGHT_CHILD] + 1) % size;
t.n_children = get_n_children(t.children[LEFT_CHILD],
t.children[RIGHT_CHILD]);
t.recv = 0;

*t2 = t;
Expand All @@ -158,12 +178,14 @@ static inline void ucc_dbt_build_t1(ucc_rank_t rank, ucc_rank_t size,

get_children(size, rank, height, root, &t1->children[LEFT_CHILD],
&t1->children[RIGHT_CHILD]);
t1->height = height;
t1->parent = parent;
t1->size = size;
t1->rank = rank;
t1->root = root;
t1->recv = 0;
t1->n_children = get_n_children(t1->children[LEFT_CHILD],
t1->children[RIGHT_CHILD]);
t1->height = height;
t1->parent = parent;
t1->size = size;
t1->rank = rank;
t1->root = root;
t1->recv = 0;
}

static inline ucc_rank_t ucc_dbt_convert_rank_for_shift(ucc_rank_t rank,
Expand Down
6 changes: 4 additions & 2 deletions src/components/tl/ucp/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ allreduce = \
allreduce/allreduce.h \
allreduce/allreduce.c \
allreduce/allreduce_knomial.c \
allreduce/allreduce_sra_knomial.c
allreduce/allreduce_sra_knomial.c \
allreduce/allreduce_dbt.c

barrier = \
barrier/barrier.h \
Expand Down Expand Up @@ -74,7 +75,8 @@ gatherv = \
reduce = \
reduce/reduce.h \
reduce/reduce.c \
reduce/reduce_knomial.c
reduce/reduce_knomial.c \
reduce/reduce_dbt.c

reduce_scatter = \
reduce_scatter/reduce_scatter.h \
Expand Down
5 changes: 5 additions & 0 deletions src/components/tl/ucp/allreduce/allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ ucc_base_coll_alg_info_t
.name = "sra_knomial",
.desc = "recursive knomial scatter-reduce followed by knomial "
"allgather (optimized for BW)"},
/* The id must match the slot's own enumerator; reusing the SRA_KNOMIAL id
 * here would give two table entries the same id. */
[UCC_TL_UCP_ALLREDUCE_ALG_DBT] =
    {.id   = UCC_TL_UCP_ALLREDUCE_ALG_DBT,
     .name = "dbt",
     .desc = "allreduce over double binary tree where a leaf in one tree "
             "will be intermediate in other (optimized for BW)"},
[UCC_TL_UCP_ALLREDUCE_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

Expand Down
17 changes: 13 additions & 4 deletions src/components/tl/ucp/allreduce/allreduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
enum {
UCC_TL_UCP_ALLREDUCE_ALG_KNOMIAL,
UCC_TL_UCP_ALLREDUCE_ALG_SRA_KNOMIAL,
UCC_TL_UCP_ALLREDUCE_ALG_DBT,
UCC_TL_UCP_ALLREDUCE_ALG_LAST
};

Expand All @@ -36,8 +37,8 @@ ucc_status_t ucc_tl_ucp_allreduce_init(ucc_tl_ucp_task_t *task);
CHECK_SAME_MEMTYPE((_args), (_team));

ucc_status_t ucc_tl_ucp_allreduce_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_allreduce_knomial_init_common(ucc_tl_ucp_task_t *task);

Expand All @@ -48,13 +49,21 @@ void ucc_tl_ucp_allreduce_knomial_progress(ucc_coll_task_t *task);
ucc_status_t ucc_tl_ucp_allreduce_knomial_finalize(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_start(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_progress(ucc_coll_task_t *task);

/* Double binary tree (DBT) allreduce, implemented in allreduce_dbt.c as a
 * schedule of a DBT reduce followed by a DBT bcast. */

/* Builds the schedule and returns it through task_h. Returns
 * UCC_ERR_NOT_SUPPORTED for in-place operation. */
ucc_status_t ucc_tl_ucp_allreduce_dbt_init(ucc_base_coll_args_t *coll_args,
                                           ucc_base_team_t      *team,
                                           ucc_coll_task_t     **task_h);

/* Post routine of the schedule: refreshes the sub-task buffers and counts
 * from the schedule arguments and starts the schedule. */
ucc_status_t ucc_tl_ucp_allreduce_dbt_start(ucc_coll_task_t *task);

/* NOTE(review): allreduce_dbt.c does not define this function and sets the
 * schedule's progress callback to NULL - confirm whether this declaration
 * is still needed. */
ucc_status_t ucc_tl_ucp_allreduce_dbt_progress(ucc_coll_task_t *task);

/* Finalizes the schedule and returns it to its pool. Declared here because
 * the definition in allreduce_dbt.c has external linkage. */
ucc_status_t ucc_tl_ucp_allreduce_dbt_finalize(ucc_coll_task_t *task);

static inline int ucc_tl_ucp_allreduce_alg_from_str(const char *str)
{
int i;
Expand Down
94 changes: 94 additions & 0 deletions src/components/tl/ucp/allreduce/allreduce_dbt.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/**
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "config.h"
#include "tl_ucp.h"
#include "allreduce.h"
#include "../reduce/reduce.h"
#include "../bcast/bcast.h"

/**
 * Post routine of the DBT allreduce schedule.
 *
 * The schedule carries two sub-tasks: tasks[0] (reduce) and tasks[1]
 * (bcast). Their buffers and counts are refreshed from the schedule's own
 * arguments before the schedule is started: the reduce goes from src to
 * dst, and the bcast then uses dst as its source.
 *
 * @param coll_task  schedule task created by ucc_tl_ucp_allreduce_dbt_init
 * @return           status of ucc_schedule_start()
 */
ucc_status_t ucc_tl_ucp_allreduce_dbt_start(ucc_coll_task_t *coll_task)
{
    ucc_schedule_t  *sched          = ucc_derived_of(coll_task, ucc_schedule_t);
    ucc_coll_args_t *allreduce_args = &sched->super.bargs.args;
    ucc_coll_args_t *reduce_args    = &sched->tasks[0]->bargs.args;
    ucc_coll_args_t *bcast_args     = &sched->tasks[1]->bargs.args;

    /* reduce stage: same source and destination as the allreduce itself */
    reduce_args->src.info.buffer = allreduce_args->src.info.buffer;
    reduce_args->src.info.count  = allreduce_args->src.info.count;
    reduce_args->dst.info.buffer = allreduce_args->dst.info.buffer;
    reduce_args->dst.info.count  = allreduce_args->dst.info.count;

    /* bcast stage: distributes what the reduce stage left in dst */
    bcast_args->src.info.buffer = allreduce_args->dst.info.buffer;
    bcast_args->src.info.count  = allreduce_args->dst.info.count;

    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_dbt_start", 0);
    return ucc_schedule_start(coll_task);
}

/**
 * Finalize routine of the DBT allreduce schedule.
 *
 * Finalizes the schedule and then hands the schedule object back via
 * ucc_tl_ucp_put_schedule(), whatever the finalize status was.
 *
 * @param coll_task  schedule task created by ucc_tl_ucp_allreduce_dbt_init
 * @return           status of ucc_schedule_finalize()
 */
ucc_status_t ucc_tl_ucp_allreduce_dbt_finalize(ucc_coll_task_t *coll_task)
{
    ucc_schedule_t *schedule = ucc_derived_of(coll_task, ucc_schedule_t);
    ucc_status_t    finalize_status;

    UCC_TL_UCP_PROFILE_REQUEST_EVENT(schedule, "ucp_allreduce_dbt_done", 0);

    finalize_status = ucc_schedule_finalize(coll_task);
    /* the schedule is released even when finalize reported an error */
    ucc_tl_ucp_put_schedule(schedule);

    return finalize_status;
}

/**
 * Build the DBT allreduce schedule: a DBT reduce followed by a DBT bcast.
 *
 * Both sub-collectives are initialized from a private copy of the
 * arguments whose root is forced to 0. The reduce task is started when the
 * schedule starts; the bcast task is started when the reduce completes.
 *
 * @param coll_args  allreduce arguments
 * @param team       TL/UCP team
 * @param task_h     [out] set to the schedule's task on success
 * @return UCC_OK on success, UCC_ERR_NOT_SUPPORTED for in-place operation,
 *         otherwise the status of the first step that failed
 */
ucc_status_t ucc_tl_ucp_allreduce_dbt_init(ucc_base_coll_args_t *coll_args,
                                           ucc_base_team_t      *team,
                                           ucc_coll_task_t     **task_h)
{
    ucc_tl_ucp_team_t   *ucp_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
    ucc_base_coll_args_t sub_args = *coll_args;
    ucc_schedule_t      *sched;
    ucc_coll_task_t     *reduce;
    ucc_coll_task_t     *bcast;
    ucc_status_t         st;

    if (UCC_IS_INPLACE(sub_args.args)) {
        return UCC_ERR_NOT_SUPPORTED;
    }

    st = ucc_tl_ucp_get_schedule(ucp_team, coll_args,
                                 (ucc_tl_ucp_schedule_t **)&sched);
    if (ucc_unlikely(UCC_OK != st)) {
        return st;
    }

    /* both stages use rank 0 as the tree root */
    sub_args.args.root = 0;

    /* stage 1: reduce, triggered by the start of the schedule */
    UCC_CHECK_GOTO(ucc_tl_ucp_reduce_dbt_init(&sub_args, team, &reduce),
                   err_put_schedule, st);
    UCC_CHECK_GOTO(ucc_schedule_add_task(sched, reduce),
                   err_put_schedule, st);
    UCC_CHECK_GOTO(ucc_event_manager_subscribe(&sched->super,
                                               UCC_EVENT_SCHEDULE_STARTED,
                                               reduce,
                                               ucc_task_start_handler),
                   err_put_schedule, st);

    /* stage 2: bcast, triggered by completion of the reduce */
    UCC_CHECK_GOTO(ucc_tl_ucp_bcast_dbt_init(&sub_args, team, &bcast),
                   err_put_schedule, st);
    UCC_CHECK_GOTO(ucc_schedule_add_task(sched, bcast),
                   err_put_schedule, st);
    UCC_CHECK_GOTO(ucc_event_manager_subscribe(reduce, UCC_EVENT_COMPLETED,
                                               bcast,
                                               ucc_task_start_handler),
                   err_put_schedule, st);

    sched->super.post     = ucc_tl_ucp_allreduce_dbt_start;
    sched->super.progress = NULL;
    sched->super.finalize = ucc_tl_ucp_allreduce_dbt_finalize;
    *task_h               = &sched->super;

    return UCC_OK;

err_put_schedule:
    ucc_tl_ucp_put_schedule(sched);
    return st;
}
4 changes: 2 additions & 2 deletions src/components/tl/ucp/bcast/bcast_sag_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ ucc_tl_ucp_bcast_sag_knomial_finalize(ucc_coll_task_t *coll_task)

ucc_status_t
ucc_tl_ucp_bcast_sag_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
size_t count = coll_args->args.src.info.count;
Expand Down
18 changes: 18 additions & 0 deletions src/components/tl/ucp/reduce/reduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ ucc_base_coll_alg_info_t
.name = "knomial",
.desc = "reduce over knomial tree with arbitrary radix "
"(optimized for latency)"},
/* Entry of the reduce algorithm table, so the description names reduce. */
[UCC_TL_UCP_REDUCE_ALG_DBT] =
    {.id   = UCC_TL_UCP_REDUCE_ALG_DBT,
     .name = "dbt",
     .desc = "reduce over double binary tree where a leaf in one tree "
             "will be intermediate in other (optimized for BW)"},
[UCC_TL_UCP_REDUCE_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

Expand Down Expand Up @@ -66,3 +71,16 @@ ucc_status_t ucc_tl_ucp_reduce_init(ucc_tl_ucp_task_t *task)

return status;
}

/**
 * Create a TL/UCP task and initialize it via ucc_tl_ucp_reduce_init().
 *
 * @param coll_args  reduce arguments
 * @param team       TL/UCP team
 * @param task_h     [out] always set to the new task, whatever the
 *                   returned status is
 * @return           status of ucc_tl_ucp_reduce_init()
 */
ucc_status_t ucc_tl_ucp_reduce_knomial_init(ucc_base_coll_args_t *coll_args,
                                            ucc_base_team_t      *team,
                                            ucc_coll_task_t     **task_h)
{
    ucc_tl_ucp_task_t *ucp_task    = ucc_tl_ucp_init_task(coll_args, team);
    ucc_status_t       init_status = ucc_tl_ucp_reduce_init(ucp_task);

    *task_h = &ucp_task->super;
    return init_status;
}
24 changes: 24 additions & 0 deletions src/components/tl/ucp/reduce/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@

enum {
UCC_TL_UCP_REDUCE_ALG_KNOMIAL,
UCC_TL_UCP_REDUCE_ALG_DBT,
UCC_TL_UCP_REDUCE_ALG_LAST
};

extern ucc_base_coll_alg_info_t
ucc_tl_ucp_reduce_algs[UCC_TL_UCP_REDUCE_ALG_LAST + 1];

/* Default algorithm selection string for reduce.
 * NOTE(review): presumably "0-inf" is the message size range and "@0"
 * selects algorithm id 0 (UCC_TL_UCP_REDUCE_ALG_KNOMIAL) - confirm against
 * the selection string parser. */
#define UCC_TL_UCP_REDUCE_DEFAULT_ALG_SELECT_STR \
"reduce:0-inf:@0"

/* A set of convenience macros used to implement sw based progress
of the reduce algorithm that uses kn pattern */
enum {
Expand All @@ -36,12 +40,32 @@ enum {
}; \
} while (0)


/**
 * Translate an algorithm name into its reduce algorithm id.
 *
 * The comparison against ucc_tl_ucp_reduce_algs[].name ignores case.
 *
 * @param str  algorithm name to look up
 * @return     index of the first matching entry, or
 *             UCC_TL_UCP_REDUCE_ALG_LAST when no entry matches
 */
static inline int ucc_tl_ucp_reduce_alg_from_str(const char *str)
{
    int alg;

    for (alg = 0; alg < UCC_TL_UCP_REDUCE_ALG_LAST; alg++) {
        if (strcasecmp(str, ucc_tl_ucp_reduce_algs[alg].name) == 0) {
            return alg;
        }
    }

    return UCC_TL_UCP_REDUCE_ALG_LAST;
}

/* Initializes an already created TL/UCP task for reduce; called by
 * ucc_tl_ucp_reduce_knomial_init(). */
ucc_status_t ucc_tl_ucp_reduce_init(ucc_tl_ucp_task_t *task);

/* Knomial reduce entry point: creates a task with ucc_tl_ucp_init_task(),
 * initializes it with ucc_tl_ucp_reduce_init() and returns it through
 * task_h together with the init status. */
ucc_status_t ucc_tl_ucp_reduce_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

/* Start, progress and finalize routines of the knomial reduce. */
ucc_status_t ucc_tl_ucp_reduce_knomial_start(ucc_coll_task_t *task);

void ucc_tl_ucp_reduce_knomial_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_reduce_knomial_finalize(ucc_coll_task_t *task);

/* Double binary tree reduce entry point; also used as the first stage of
 * the DBT allreduce schedule.
 * NOTE(review): presumably defined in reduce/reduce_dbt.c, which is not
 * visible here - confirm. */
ucc_status_t ucc_tl_ucp_reduce_dbt_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

#endif
Loading
Loading