diff --git a/src/coll_patterns/double_binary_tree.h b/src/coll_patterns/double_binary_tree.h
index 28e9809c48..baab72936a 100644
--- a/src/coll_patterns/double_binary_tree.h
+++ b/src/coll_patterns/double_binary_tree.h
@@ -18,6 +18,7 @@ typedef struct ucc_dbt_single_tree {
     ucc_rank_t root;
     ucc_rank_t parent;
     ucc_rank_t children[2];
+    int        n_children;
     int        height;
     int        recv;
 } ucc_dbt_single_tree_t;
@@ -86,6 +87,21 @@ static inline void get_children(ucc_rank_t size, ucc_rank_t rank, int height,
     *r_c = get_right_child(size, rank, height, root);
 }
 
+static inline int get_n_children(ucc_rank_t l_c, ucc_rank_t r_c)
+{
+    int n_children = 0;
+
+    if (l_c != UCC_RANK_INVALID) {
+        n_children++;
+    }
+
+    if (r_c != UCC_RANK_INVALID) {
+        n_children++;
+    }
+
+    return n_children;
+}
+
 static inline ucc_rank_t get_parent(int vsize, int vrank, int height, int troot)
 {
     if (vrank == troot) {
@@ -121,6 +137,8 @@ static inline void ucc_dbt_build_t2_mirror(ucc_dbt_single_tree_t t1,
     t.children[RIGHT_CHILD] = (t1.children[LEFT_CHILD] == UCC_RANK_INVALID)
                                   ? UCC_RANK_INVALID
                                   : size - 1 - t1.children[LEFT_CHILD];
+    t.n_children            = get_n_children(t.children[LEFT_CHILD],
+                                             t.children[RIGHT_CHILD]);
     t.recv                  = 0;
 
     *t2 = t;
@@ -144,6 +162,8 @@ static inline void ucc_dbt_build_t2_shift(ucc_dbt_single_tree_t t1,
     t.children[RIGHT_CHILD] = (t1.children[RIGHT_CHILD] == UCC_RANK_INVALID)
                                   ? UCC_RANK_INVALID
                                   : (t1.children[RIGHT_CHILD] + 1) % size;
+    t.n_children            = get_n_children(t.children[LEFT_CHILD],
+                                             t.children[RIGHT_CHILD]);
     t.recv                  = 0;
 
     *t2 = t;
@@ -158,12 +178,14 @@ static inline void ucc_dbt_build_t1(ucc_rank_t rank, ucc_rank_t size,
     get_children(size, rank, height, root, &t1->children[LEFT_CHILD],
                  &t1->children[RIGHT_CHILD]);
-    t1->height = height;
-    t1->parent = parent;
-    t1->size   = size;
-    t1->rank   = rank;
-    t1->root   = root;
-    t1->recv   = 0;
+    t1->n_children = get_n_children(t1->children[LEFT_CHILD],
+                                    t1->children[RIGHT_CHILD]);
+    t1->height     = height;
+    t1->parent     = parent;
+    t1->size       = size;
+    t1->rank       = rank;
+    t1->root       = root;
+    t1->recv       = 0;
 }
 
 static inline ucc_rank_t ucc_dbt_convert_rank_for_shift(ucc_rank_t rank,
diff --git a/src/components/tl/ucp/Makefile.am b/src/components/tl/ucp/Makefile.am
index badc741d99..bf8e40aa6c 100644
--- a/src/components/tl/ucp/Makefile.am
+++ b/src/components/tl/ucp/Makefile.am
@@ -39,7 +39,8 @@ allreduce =                    \
 	allreduce/allreduce.h          \
 	allreduce/allreduce.c          \
 	allreduce/allreduce_knomial.c  \
-	allreduce/allreduce_sra_knomial.c
+	allreduce/allreduce_sra_knomial.c \
+	allreduce/allreduce_dbt.c
 
 barrier =             \
 	barrier/barrier.h \
@@ -74,7 +75,8 @@ gatherv =              \
 reduce =                    \
 	reduce/reduce.h         \
 	reduce/reduce.c         \
-	reduce/reduce_knomial.c
+	reduce/reduce_knomial.c \
+	reduce/reduce_dbt.c
 
 reduce_scatter =                    \
 	reduce_scatter/reduce_scatter.h \
diff --git a/src/components/tl/ucp/allreduce/allreduce.c b/src/components/tl/ucp/allreduce/allreduce.c
index 90e9d6bf48..4cda3765b8 100644
--- a/src/components/tl/ucp/allreduce/allreduce.c
+++ b/src/components/tl/ucp/allreduce/allreduce.c
@@ -20,6 +20,11 @@ ucc_base_coll_alg_info_t
          .name = "sra_knomial",
          .desc = "recursive knomial scatter-reduce followed by knomial "
                  "allgather (optimized for BW)"},
+    [UCC_TL_UCP_ALLREDUCE_ALG_DBT] =
+        {.id   = UCC_TL_UCP_ALLREDUCE_ALG_DBT,
+         .name = "dbt",
+         .desc = "allreduce over double binary tree where a leaf in one tree "
+                 "will be an intermediate node in the other (optimized for BW)"},
     [UCC_TL_UCP_ALLREDUCE_ALG_LAST] = {
         .id = 0, .name = NULL, .desc = NULL}};
diff --git a/src/components/tl/ucp/allreduce/allreduce.h
b/src/components/tl/ucp/allreduce/allreduce.h
index 250bebc981..8eb75fb999 100644
--- a/src/components/tl/ucp/allreduce/allreduce.h
+++ b/src/components/tl/ucp/allreduce/allreduce.h
@@ -11,6 +11,7 @@
 enum {
     UCC_TL_UCP_ALLREDUCE_ALG_KNOMIAL,
     UCC_TL_UCP_ALLREDUCE_ALG_SRA_KNOMIAL,
+    UCC_TL_UCP_ALLREDUCE_ALG_DBT,
     UCC_TL_UCP_ALLREDUCE_ALG_LAST
 };
 
@@ -36,8 +37,8 @@ ucc_status_t ucc_tl_ucp_allreduce_init(ucc_tl_ucp_task_t *task);
     CHECK_SAME_MEMTYPE((_args), (_team));
 
 ucc_status_t ucc_tl_ucp_allreduce_knomial_init(ucc_base_coll_args_t *coll_args,
-                                               ucc_base_team_t *     team,
-                                               ucc_coll_task_t **    task_h);
+                                               ucc_base_team_t *team,
+                                               ucc_coll_task_t **task_h);
 
 ucc_status_t ucc_tl_ucp_allreduce_knomial_init_common(ucc_tl_ucp_task_t *task);
 
@@ -48,13 +49,21 @@ void ucc_tl_ucp_allreduce_knomial_progress(ucc_coll_task_t *task);
 
 ucc_status_t ucc_tl_ucp_allreduce_knomial_finalize(ucc_coll_task_t *task);
 
 ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_init(ucc_base_coll_args_t *coll_args,
-                                                   ucc_base_team_t *     team,
-                                                   ucc_coll_task_t **    task_h);
+                                                   ucc_base_team_t *team,
+                                                   ucc_coll_task_t **task_h);
 
 ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_start(ucc_coll_task_t *task);
 
 ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_progress(ucc_coll_task_t *task);
 
+ucc_status_t ucc_tl_ucp_allreduce_dbt_init(ucc_base_coll_args_t *coll_args,
+                                           ucc_base_team_t *team,
+                                           ucc_coll_task_t **task_h);
+
+ucc_status_t ucc_tl_ucp_allreduce_dbt_start(ucc_coll_task_t *task);
+
+ucc_status_t ucc_tl_ucp_allreduce_dbt_progress(ucc_coll_task_t *task);
+
 static inline int ucc_tl_ucp_allreduce_alg_from_str(const char *str)
 {
     int i;
diff --git a/src/components/tl/ucp/allreduce/allreduce_dbt.c b/src/components/tl/ucp/allreduce/allreduce_dbt.c
new file mode 100644
index 0000000000..709f4e5f43
--- /dev/null
+++ b/src/components/tl/ucp/allreduce/allreduce_dbt.c
@@ -0,0 +1,94 @@
+/**
+ * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "config.h"
+#include "tl_ucp.h"
+#include "allreduce.h"
+#include "../reduce/reduce.h"
+#include "../bcast/bcast.h"
+
+ucc_status_t ucc_tl_ucp_allreduce_dbt_start(ucc_coll_task_t *coll_task)
+{
+    ucc_schedule_t  *schedule = ucc_derived_of(coll_task, ucc_schedule_t);
+    ucc_coll_args_t *args     = &schedule->super.bargs.args;
+    ucc_coll_task_t *reduce_task, *bcast_task;
+
+    reduce_task = schedule->tasks[0];
+    reduce_task->bargs.args.src.info.buffer = args->src.info.buffer;
+    reduce_task->bargs.args.dst.info.buffer = args->dst.info.buffer;
+    reduce_task->bargs.args.src.info.count  = args->src.info.count;
+    reduce_task->bargs.args.dst.info.count  = args->dst.info.count;
+
+    bcast_task = schedule->tasks[1];
+    bcast_task->bargs.args.src.info.buffer = args->dst.info.buffer;
+    bcast_task->bargs.args.src.info.count  = args->dst.info.count;
+
+    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_dbt_start", 0);
+    return ucc_schedule_start(coll_task);
+}
+
+ucc_status_t ucc_tl_ucp_allreduce_dbt_finalize(ucc_coll_task_t *coll_task)
+{
+    ucc_schedule_t *schedule = ucc_derived_of(coll_task, ucc_schedule_t);
+    ucc_status_t    status;
+
+    UCC_TL_UCP_PROFILE_REQUEST_EVENT(schedule, "ucp_allreduce_dbt_done", 0);
+    status = ucc_schedule_finalize(coll_task);
+    ucc_tl_ucp_put_schedule(schedule);
+    return status;
+}
+
+ucc_status_t ucc_tl_ucp_allreduce_dbt_init(ucc_base_coll_args_t *coll_args,
+                                           ucc_base_team_t *team,
+                                           ucc_coll_task_t **task_h)
+{
+    ucc_tl_ucp_team_t    *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
+    ucc_base_coll_args_t  args    = *coll_args;
+    ucc_schedule_t       *schedule;
+    ucc_coll_task_t      *reduce_task, *bcast_task;
+    ucc_status_t          status;
+
+    if (UCC_IS_INPLACE(args.args)) {
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    status = ucc_tl_ucp_get_schedule(tl_team, coll_args,
+                                     (ucc_tl_ucp_schedule_t **)&schedule);
+    if (ucc_unlikely(UCC_OK != status)) {
+        return status;
+    }
+
+    args.args.root = 0;
+    UCC_CHECK_GOTO(ucc_tl_ucp_reduce_dbt_init(&args, team, &reduce_task),
+                   out, status);
+    UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, reduce_task),
+                   out, status);
+    UCC_CHECK_GOTO(ucc_event_manager_subscribe(&schedule->super,
+                                               UCC_EVENT_SCHEDULE_STARTED,
+                                               reduce_task,
+                                               ucc_task_start_handler),
+                   out, status);
+
+    UCC_CHECK_GOTO(ucc_tl_ucp_bcast_dbt_init(&args, team, &bcast_task),
+                   out, status);
+    UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, bcast_task),
+                   out, status);
+    UCC_CHECK_GOTO(ucc_event_manager_subscribe(reduce_task, UCC_EVENT_COMPLETED,
+                                               bcast_task,
+                                               ucc_task_start_handler),
+                   out, status);
+
+    schedule->super.post     = ucc_tl_ucp_allreduce_dbt_start;
+    schedule->super.progress = NULL;
+    schedule->super.finalize = ucc_tl_ucp_allreduce_dbt_finalize;
+    *task_h                  = &schedule->super;
+
+    return UCC_OK;
+
+out:
+    ucc_tl_ucp_put_schedule(schedule);
+    return status;
+}
diff --git a/src/components/tl/ucp/bcast/bcast_sag_knomial.c b/src/components/tl/ucp/bcast/bcast_sag_knomial.c
index 1fa56a7367..3f4a6919f6 100644
--- a/src/components/tl/ucp/bcast/bcast_sag_knomial.c
+++ b/src/components/tl/ucp/bcast/bcast_sag_knomial.c
@@ -70,8 +70,8 @@ ucc_tl_ucp_bcast_sag_knomial_finalize(ucc_coll_task_t *coll_task)
 
 ucc_status_t
 ucc_tl_ucp_bcast_sag_knomial_init(ucc_base_coll_args_t *coll_args,
-                                  ucc_base_team_t *     team,
-                                  ucc_coll_task_t **    task_h)
+                                  ucc_base_team_t *team,
+                                  ucc_coll_task_t **task_h)
 {
     ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
     size_t             count   = coll_args->args.src.info.count;
diff --git a/src/components/tl/ucp/reduce/reduce.c
b/src/components/tl/ucp/reduce/reduce.c
index 82a9380083..039f9f393b 100644
--- a/src/components/tl/ucp/reduce/reduce.c
+++ b/src/components/tl/ucp/reduce/reduce.c
@@ -13,6 +13,11 @@ ucc_base_coll_alg_info_t
          .name = "knomial",
          .desc = "reduce over knomial tree with arbitrary radix "
                  "(optimized for latency)"},
+    [UCC_TL_UCP_REDUCE_ALG_DBT] =
+        {.id   = UCC_TL_UCP_REDUCE_ALG_DBT,
+         .name = "dbt",
+         .desc = "reduce over double binary tree where a leaf in one tree "
+                 "will be an intermediate node in the other (optimized for BW)"},
     [UCC_TL_UCP_REDUCE_ALG_LAST] = {
         .id = 0, .name = NULL, .desc = NULL}};
 
@@ -66,3 +71,16 @@ ucc_status_t ucc_tl_ucp_reduce_init(ucc_tl_ucp_task_t *task)
 
     return status;
 }
+
+ucc_status_t ucc_tl_ucp_reduce_knomial_init(ucc_base_coll_args_t *coll_args,
+                                            ucc_base_team_t *team,
+                                            ucc_coll_task_t **task_h)
+{
+    ucc_tl_ucp_task_t *task;
+    ucc_status_t       status;
+
+    task    = ucc_tl_ucp_init_task(coll_args, team);
+    status  = ucc_tl_ucp_reduce_init(task);
+    *task_h = &task->super;
+    return status;
+}
diff --git a/src/components/tl/ucp/reduce/reduce.h b/src/components/tl/ucp/reduce/reduce.h
index e26c4fdf23..98bc183ff3 100644
--- a/src/components/tl/ucp/reduce/reduce.h
+++ b/src/components/tl/ucp/reduce/reduce.h
@@ -9,12 +9,16 @@
 enum {
     UCC_TL_UCP_REDUCE_ALG_KNOMIAL,
+    UCC_TL_UCP_REDUCE_ALG_DBT,
     UCC_TL_UCP_REDUCE_ALG_LAST
 };
 
 extern ucc_base_coll_alg_info_t
     ucc_tl_ucp_reduce_algs[UCC_TL_UCP_REDUCE_ALG_LAST + 1];
 
+#define UCC_TL_UCP_REDUCE_DEFAULT_ALG_SELECT_STR \
+    "reduce:0-inf:@0"
+
 /* A set of convenience macros used to implement sw based progress
    of the reduce algorithm that uses kn pattern */
 enum {
@@ -36,12 +40,32 @@ enum {
     };                                                                         \
 } while (0)
 
+
+static inline int ucc_tl_ucp_reduce_alg_from_str(const char *str)
+{
+    int i;
+    for (i = 0; i < UCC_TL_UCP_REDUCE_ALG_LAST; i++) {
+        if (0 == strcasecmp(str, ucc_tl_ucp_reduce_algs[i].name)) {
+            break;
+        }
+    }
+    return i;
+}
+
 ucc_status_t ucc_tl_ucp_reduce_init(ucc_tl_ucp_task_t *task);
 
+ucc_status_t ucc_tl_ucp_reduce_knomial_init(ucc_base_coll_args_t *coll_args,
+                                            ucc_base_team_t *team,
+                                            ucc_coll_task_t **task_h);
+
 ucc_status_t ucc_tl_ucp_reduce_knomial_start(ucc_coll_task_t *task);
 
 void ucc_tl_ucp_reduce_knomial_progress(ucc_coll_task_t *task);
 
 ucc_status_t ucc_tl_ucp_reduce_knomial_finalize(ucc_coll_task_t *task);
 
+ucc_status_t ucc_tl_ucp_reduce_dbt_init(ucc_base_coll_args_t *coll_args,
+                                        ucc_base_team_t *team,
+                                        ucc_coll_task_t **task_h);
+
 #endif
diff --git a/src/components/tl/ucp/reduce/reduce_dbt.c b/src/components/tl/ucp/reduce/reduce_dbt.c
new file mode 100644
index 0000000000..cfa5d2ff22
--- /dev/null
+++ b/src/components/tl/ucp/reduce/reduce_dbt.c
@@ -0,0 +1,358 @@
+/**
+ * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "config.h"
+#include "tl_ucp.h"
+#include "reduce.h"
+#include "core/ucc_progress_queue.h"
+#include "tl_ucp_sendrecv.h"
+#include "utils/ucc_dt_reduce.h"
+
+enum {
+    RECV,
+    REDUCE,
+    TEST,
+    TEST_ROOT,
+};
+
+#define UCC_REDUCE_DBT_CHECK_STATE(_p)                                         \
+    case _p:                                                                   \
+        goto _p;
+
+#define UCC_REDUCE_DBT_GOTO_STATE(_state)                                      \
+    do {                                                                       \
+        switch (_state) {                                                      \
+            UCC_REDUCE_DBT_CHECK_STATE(REDUCE);                                \
+            UCC_REDUCE_DBT_CHECK_STATE(TEST);                                  \
+            UCC_REDUCE_DBT_CHECK_STATE(TEST_ROOT);                             \
+        };                                                                     \
+    } while (0)
+
+static void recv_completion_common(void *request, ucs_status_t status,
+                                   const ucp_tag_recv_info_t *info, /* NOLINT */
+                                   void *user_data)
+{
+    ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data;
+    if (ucc_unlikely(UCS_OK != status)) {
+        tl_error(UCC_TASK_LIB(task), "failure in recv completion %s",
+                 ucs_status_string(status));
+        task->super.status = ucs_status_to_ucc_status(status);
+    }
+    task->tagged.recv_completed++;
+    if (request) {
+        ucp_request_free(request);
+    }
+}
+
+static void recv_completion_1(void *request, ucs_status_t status,
+                              const ucp_tag_recv_info_t *info, /* NOLINT */
+                              void *user_data)
+{
+    ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data;
+
+    task->reduce_dbt.trees[0].recv++;
+    recv_completion_common(request, status, info, user_data);
+}
+
+static void recv_completion_2(void *request, ucs_status_t status,
+                              const ucp_tag_recv_info_t *info, /* NOLINT */
+                              void *user_data)
+{
+    ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data;
+
+    task->reduce_dbt.trees[1].recv++;
+    recv_completion_common(request, status, info, user_data);
+}
+
+static inline void single_tree_reduce(ucc_tl_ucp_task_t *task, void *sbuf,
+                                      void *rbuf, int n_children, size_t count,
+                                      size_t data_size, ucc_datatype_t dt,
+                                      ucc_coll_args_t *args, int is_avg)
+{
+    ucc_status_t status;
+
+    status = ucc_dt_reduce_strided(
+        sbuf, rbuf, rbuf,
+        n_children, count, data_size,
+        dt, args,
+        is_avg ? UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0,
+        AVG_ALPHA(task), task->reduce_dbt.executor,
+        &task->reduce_dbt.etask);
+
+    if (ucc_unlikely(UCC_OK != status)) {
+        tl_error(UCC_TASK_LIB(task),
+                 "failed to perform dt reduction");
+        task->super.status = status;
+        return;
+    }
+    EXEC_TASK_WAIT(task->reduce_dbt.etask);
+}
+
+void ucc_tl_ucp_reduce_dbt_progress(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_ucp_task_t     *task      = ucc_derived_of(coll_task,
+                                                      ucc_tl_ucp_task_t);
+    ucc_tl_ucp_team_t     *team      = TASK_TEAM(task);
+    ucc_coll_args_t       *args      = &TASK_ARGS(task);
+    ucc_dbt_single_tree_t *trees     = task->reduce_dbt.trees;
+    ucc_rank_t             rank      = UCC_TL_TEAM_RANK(team);
+    ucc_rank_t             coll_root = (ucc_rank_t)args->root;
+    int                    is_root   = rank == coll_root;
+    ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1,
+                                         recv_completion_2};
+    void              *sbuf[2], *rbuf[2];
+    uint32_t           i, j, k;
+    ucc_memory_type_t  mtype;
+    ucc_datatype_t     dt;
+    size_t             count, data_size, data_size_t1;
+    size_t             counts[2];
+    int                avg_pre_op, avg_post_op;
+
+    if (is_root) {
+        mtype = args->dst.info.mem_type;
+        dt    = args->dst.info.datatype;
+        count = args->dst.info.count;
+    } else {
+        mtype = args->src.info.mem_type;
+        dt    = args->src.info.datatype;
+        count = args->src.info.count;
+    }
+
+    counts[0]    = (count % 2) ? count / 2 + 1 : count / 2;
+    counts[1]    = count / 2;
+    data_size    = count * ucc_dt_size(dt);
+    data_size_t1 = counts[0] * ucc_dt_size(dt);
+    avg_pre_op   = ((args->op == UCC_OP_AVG) &&
+                    UCC_TL_UCP_TEAM_LIB(team)->cfg.reduce_avg_pre_op);
+    avg_post_op  = ((args->op == UCC_OP_AVG) &&
+                    !UCC_TL_UCP_TEAM_LIB(team)->cfg.reduce_avg_pre_op);
+
+    rbuf[0] = task->reduce_dbt.scratch;
+    rbuf[1] = PTR_OFFSET(rbuf[0], data_size_t1 * 2);
+    sbuf[0] = avg_pre_op ? PTR_OFFSET(rbuf[0], data_size * 2)
+                         : args->src.info.buffer;
+    sbuf[1] = PTR_OFFSET(sbuf[0], data_size_t1);
+
+    UCC_REDUCE_DBT_GOTO_STATE(task->reduce_dbt.state);
+    for (i = 0; i < 2; i++) {
+        j = 0;
+        for (k = 0; k < 2; k++) {
+            if (trees[i].children[k] != UCC_RANK_INVALID) {
+                UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(
+                                  PTR_OFFSET(rbuf[i],
+                                             counts[i] * ucc_dt_size(dt) * j),
+                                  counts[i] * ucc_dt_size(dt), mtype,
+                                  trees[i].children[k], team, task, cb[i],
+                                  (void *)task),
+                              task, out);
+                j++;
+            }
+
+        }
+    }
+    task->reduce_dbt.state = REDUCE;
+
+REDUCE:
+    for (i = 0; i < 2; i++) {
+        if (trees[i].recv == trees[i].n_children &&
+            !task->reduce_dbt.reduction_comp[i]) {
+            if (trees[i].n_children > 0) {
+                single_tree_reduce(task, sbuf[i], rbuf[i], trees[i].n_children,
+                                   counts[i], counts[i] * ucc_dt_size(dt), dt,
+                                   args, avg_post_op && trees[i].root == rank);
+            }
+            task->reduce_dbt.reduction_comp[i] = 1;
+        }
+    }
+
+    for (i = 0; i < 2; i++) {
+        if (rank != trees[i].root && task->reduce_dbt.reduction_comp[i] &&
+            !task->reduce_dbt.send_comp[i]) {
+            UCPCHECK_GOTO(ucc_tl_ucp_send_nb((trees[i].n_children > 0) ? rbuf[i]
+                                                                       : sbuf[i],
+                                             counts[i] * ucc_dt_size(dt),
+                                             mtype, trees[i].parent, team,
+                                             task),
+                          task, out);
+            task->reduce_dbt.send_comp[i] = 1;
+        }
+    }
+
+    if (!task->reduce_dbt.reduction_comp[0] ||
+        !task->reduce_dbt.reduction_comp[1]) {
+        return;
+    }
+TEST:
+    if (UCC_INPROGRESS == ucc_tl_ucp_test_send(task)) {
+        task->reduce_dbt.state = TEST;
+        return;
+    }
+
+    /* tree roots send to coll root */
+    for (i = 0; i < 2; i++) {
+        if (rank == trees[i].root && !is_root) {
+            UCPCHECK_GOTO(ucc_tl_ucp_send_nb(rbuf[i],
+                                             counts[i] * ucc_dt_size(dt),
+                                             mtype, coll_root, team, task),
+                          task, out);
+        }
+    }
+
+    task->reduce_dbt.reduction_comp[0] = trees[0].recv;
+    task->reduce_dbt.reduction_comp[1] = trees[1].recv;
+
+    for (i = 0; i < 2; i++) {
+        if (is_root && rank != trees[i].root) {
+            UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(args->dst.info.buffer,
+                                             i * counts[0] * ucc_dt_size(dt)),
+                                             counts[i] * ucc_dt_size(dt),
+                                             mtype, trees[i].root, team, task,
+                                             cb[i], (void *)task),
+                          task, out);
+            task->reduce_dbt.reduction_comp[i]++;
+        }
+    }
+
+TEST_ROOT:
+    if (UCC_INPROGRESS == ucc_tl_ucp_test_send(task) ||
+        task->reduce_dbt.reduction_comp[0] != trees[0].recv ||
+        task->reduce_dbt.reduction_comp[1] != trees[1].recv) {
+        task->reduce_dbt.state = TEST_ROOT;
+        return;
+    }
+
+    for (i = 0; i < 2; i++) {
+        if (is_root && rank == trees[i].root) {
+            UCPCHECK_GOTO(ucc_mc_memcpy(PTR_OFFSET(args->dst.info.buffer,
+                                        i * counts[0] * ucc_dt_size(dt)),
+                                        rbuf[i], counts[i] * ucc_dt_size(dt),
+                                        mtype, mtype), task, out);
+        }
+    }
+
+    task->super.status = UCC_OK;
+    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_reduce_dbt_done", 0);
+out:
+    return;
+}
+
+ucc_status_t ucc_tl_ucp_reduce_dbt_start(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_ucp_task_t *task       = ucc_derived_of(coll_task,
+                                                   ucc_tl_ucp_task_t);
+    ucc_tl_ucp_team_t *team       = TASK_TEAM(task);
+    ucc_coll_args_t   *args       = &TASK_ARGS(task);
+    ucc_rank_t         rank       = UCC_TL_TEAM_RANK(team);
+    ucc_rank_t         team_size  = UCC_TL_TEAM_SIZE(team);
+    int                avg_pre_op =
+        UCC_TL_UCP_TEAM_LIB(TASK_TEAM(task))->cfg.reduce_avg_pre_op;
+    ucc_datatype_t     dt;
+    size_t             count, data_size;
+    ucc_status_t       status;
+
+    task->reduce_dbt.trees[0].recv     = 0;
+    task->reduce_dbt.trees[1].recv     = 0;
+    task->reduce_dbt.reduction_comp[0] = 0;
+    task->reduce_dbt.reduction_comp[1] = 0;
+    task->reduce_dbt.send_comp[0]      = 0;
+    task->reduce_dbt.send_comp[1]      = 0;
+
+    ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
+
+    if (args->root == rank) {
+        count = args->dst.info.count;
+        dt    = args->dst.info.datatype;
+    } else {
+        count = args->src.info.count;
+        dt    = args->src.info.datatype;
+    }
+    data_size = count * ucc_dt_size(dt);
+
+    status = ucc_coll_task_get_executor(&task->super,
+                                        &task->reduce_dbt.executor);
+    if (ucc_unlikely(status != UCC_OK)) {
+        return status;
+    }
+
+    if (UCC_IS_INPLACE(*args) && (rank == args->root)) {
+        args->src.info.buffer = args->dst.info.buffer;
+    }
+
+    if (avg_pre_op && args->op == UCC_OP_AVG) {
+        /* In case of avg_pre_op, each process must divide itself by team_size */
+        status =
+            ucc_dt_reduce(args->src.info.buffer, args->src.info.buffer,
+                          PTR_OFFSET(task->reduce_dbt.scratch, data_size * 2),
+                          count, dt, args, UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA,
+                          1.0 / (double)(team_size * 2),
+                          task->reduce_dbt.executor, &task->reduce_dbt.etask);
+        if (ucc_unlikely(UCC_OK != status)) {
+            tl_error(UCC_TASK_LIB(task),
+                     "failed to perform dt reduction");
+            return status;
+        }
+        EXEC_TASK_WAIT(task->reduce_dbt.etask, status);
+    }
+
+    task->reduce_dbt.state = RECV;
+    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_reduce_dbt_start", 0);
+    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
+}
+
+ucc_status_t ucc_tl_ucp_reduce_dbt_finalize(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
+
+    if (task->reduce_dbt.scratch_mc_header) {
+        ucc_mc_free(task->reduce_dbt.scratch_mc_header);
+    }
+
+    return ucc_tl_ucp_coll_finalize(coll_task);
+}
+
+ucc_status_t ucc_tl_ucp_reduce_dbt_init(ucc_base_coll_args_t *coll_args,
+                                        ucc_base_team_t *team,
+                                        ucc_coll_task_t **task_h)
+{
+    ucc_tl_ucp_team_t *tl_team;
+    ucc_tl_ucp_task_t *task;
+    ucc_rank_t         rank, size;
+    ucc_memory_type_t  mtype;
+    ucc_datatype_t     dt;
+    size_t             count;
+    size_t             data_size;
+    ucc_status_t       status;
+
+    task                  = ucc_tl_ucp_init_task(coll_args, team);
+    task->super.flags    |= UCC_COLL_TASK_FLAG_EXECUTOR;
+    task->super.post      = ucc_tl_ucp_reduce_dbt_start;
+    task->super.progress  = ucc_tl_ucp_reduce_dbt_progress;
+    task->super.finalize  = ucc_tl_ucp_reduce_dbt_finalize;
+    tl_team               = TASK_TEAM(task);
+    rank                  = UCC_TL_TEAM_RANK(tl_team);
+    size                  = UCC_TL_TEAM_SIZE(tl_team);
+    ucc_dbt_build_trees(rank, size, &task->reduce_dbt.trees[0],
+                        &task->reduce_dbt.trees[1]);
+
+    if (coll_args->args.root == rank) {
+        count = coll_args->args.dst.info.count;
+        dt    = coll_args->args.dst.info.datatype;
+        mtype = coll_args->args.dst.info.mem_type;
+    } else {
+        count = coll_args->args.src.info.count;
+        dt    = coll_args->args.src.info.datatype;
+        mtype = coll_args->args.src.info.mem_type;
+    }
+    data_size = count * ucc_dt_size(dt);
+    task->reduce_dbt.scratch_mc_header = NULL;
+    status = ucc_mc_alloc(&task->reduce_dbt.scratch_mc_header, 3 * data_size,
+                          mtype);
+    if (ucc_unlikely(status != UCC_OK)) {
+        return status;
+    }
+    task->reduce_dbt.scratch = task->reduce_dbt.scratch_mc_header->addr;
+    *task_h = &task->super;
+    return UCC_OK;
+}
diff --git a/src/components/tl/ucp/tl_ucp_coll.c b/src/components/tl/ucp/tl_ucp_coll.c
index a1ba843d9f..23c254b00e 100644
--- a/src/components/tl/ucp/tl_ucp_coll.c
+++ b/src/components/tl/ucp/tl_ucp_coll.c
@@ -42,6 +42,10 @@ const ucc_tl_ucp_default_alg_desc_t
         .select_str = UCC_TL_UCP_BCAST_DEFAULT_ALG_SELECT_STR,
         .str_get_fn = NULL
     },
+    {
+        .select_str = UCC_TL_UCP_REDUCE_DEFAULT_ALG_SELECT_STR,
+        .str_get_fn = NULL
+    },
     {
         .select_str = UCC_TL_UCP_REDUCE_SCATTER_DEFAULT_ALG_SELECT_STR,
         .str_get_fn = NULL
@@ -223,6 +227,8 @@ static inline int alg_id_from_str(ucc_coll_type_t coll_type, const char *str)
         return ucc_tl_ucp_alltoallv_alg_from_str(str);
     case UCC_COLL_TYPE_BCAST:
         return ucc_tl_ucp_bcast_alg_from_str(str);
+    case UCC_COLL_TYPE_REDUCE:
+        return ucc_tl_ucp_reduce_alg_from_str(str);
     case UCC_COLL_TYPE_REDUCE_SCATTER:
         return ucc_tl_ucp_reduce_scatter_alg_from_str(str);
     case UCC_COLL_TYPE_REDUCE_SCATTERV:
@@ -239,6 +245,7 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
                                        ucc_base_coll_init_fn_t *init)
 {
     ucc_status_t status = UCC_OK;
+
     if (alg_id_str) {
         alg_id = alg_id_from_str(coll_type, alg_id_str);
     }
@@ -268,6 +275,9 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
         case UCC_TL_UCP_ALLREDUCE_ALG_SRA_KNOMIAL:
             *init = ucc_tl_ucp_allreduce_sra_knomial_init;
             break;
+        case UCC_TL_UCP_ALLREDUCE_ALG_DBT:
+            *init = ucc_tl_ucp_allreduce_dbt_init;
+            break;
         default:
             status = UCC_ERR_INVALID_PARAM;
             break;
@@ -321,6 +331,19 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
             break;
         };
         break;
+    case UCC_COLL_TYPE_REDUCE:
+        switch (alg_id) {
+        case UCC_TL_UCP_REDUCE_ALG_KNOMIAL:
+            *init = ucc_tl_ucp_reduce_knomial_init;
+            break;
+        case UCC_TL_UCP_REDUCE_ALG_DBT:
+            *init = ucc_tl_ucp_reduce_dbt_init;
+            break;
+        default:
+            status = UCC_ERR_INVALID_PARAM;
+            break;
+        };
+        break;
     case UCC_COLL_TYPE_REDUCE_SCATTER:
         switch (alg_id) {
         case UCC_TL_UCP_REDUCE_SCATTER_ALG_RING:
diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h
index 52b87c5e2c..6ab2c661dd 100644
--- a/src/components/tl/ucp/tl_ucp_coll.h
+++ b/src/components/tl/ucp/tl_ucp_coll.h
@@ -17,7 +17,7 @@
 #include "tl_ucp_tag.h"
 
 #define UCC_UUNITS_AUTO_RADIX 4
-#define UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR 7
+#define UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR 8
 
 ucc_status_t ucc_tl_ucp_team_default_score_str_alloc(ucc_tl_ucp_team_t *team,
     char *default_select_str[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR]);
@@ -201,6 +201,16 @@ typedef struct ucc_tl_ucp_task {
             ucc_ee_executor_task_t *etask;
             ucc_ee_executor_t      *executor;
         } reduce_kn;
+        struct {
+            int                     state;
+            ucc_dbt_single_tree_t   trees[2];
+            int                     reduction_comp[2];
+            int                     send_comp[2];
+            void                   *scratch;
+            ucc_mc_buffer_header_t *scratch_mc_header;
+            ucc_ee_executor_task_t *etask;
+            ucc_ee_executor_t      *executor;
+        } reduce_dbt;
         struct {
             ucc_rank_t              dist;
             ucc_rank_t              max_dist;
diff --git a/test/gtest/coll/test_allreduce.cc b/test/gtest/coll/test_allreduce.cc
index dba5ecc8b6..7e718cefaa 100644
--- a/test/gtest/coll/test_allreduce.cc
+++ b/test/gtest/coll/test_allreduce.cc
@@ -327,6 +327,43 @@ TYPED_TEST(test_allreduce_alg, sra_knomial_pipelined) {
     }
 }
 
+TYPED_TEST(test_allreduce_alg, dbt) {
+    int           n_procs = 15;
+    ucc_job_env_t env     = {{"UCC_CL_BASIC_TUNE", "inf"},
+                             {"UCC_TL_UCP_TUNE", "allreduce:@dbt:inf"}};
+    UccJob        job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL, env);
+    UccTeam_h     team    = job.create_team(n_procs);
+    int           repeat  = 3;
+    UccCollCtxVec ctxs;
+    std::vector<ucc_memory_type_t> mt = {UCC_MEMORY_TYPE_HOST};
+
+    if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA)) {
+        mt.push_back(UCC_MEMORY_TYPE_CUDA);
+    }
+    if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA_MANAGED)) {
+        mt.push_back(UCC_MEMORY_TYPE_CUDA_MANAGED);
+    }
+
+    for (auto count : {65536, 123567}) {
+        for (auto inplace : {TEST_NO_INPLACE, TEST_INPLACE}) {
+            for (auto m : mt) {
+                SET_MEM_TYPE(m);
+                this->set_inplace(inplace);
+                this->data_init(n_procs, TypeParam::dt, count, ctxs, true);
+                UccReq req(team, ctxs);
+
+                for (auto i = 0; i < repeat; i++) {
+                    req.start();
+                    req.wait();
+                    EXPECT_EQ(true, this->data_validate(ctxs));
+                    this->reset(ctxs);
+                }
+                this->data_fini(ctxs);
+            }
+        }
+    }
+}
+
 TYPED_TEST(test_allreduce_alg, rab) {
     int n_procs = 15;
     ucc_job_env_t env = {{"UCC_CL_HIER_TUNE", "allreduce:@rab:0-inf:inf"},
diff --git a/test/gtest/coll/test_reduce.cc b/test/gtest/coll/test_reduce.cc
index 393e97decc..0f8bfc034f 100644
--- a/test/gtest/coll/test_reduce.cc
+++ b/test/gtest/coll/test_reduce.cc
@@ -23,17 +23,9 @@ class test_reduce : public UccCollArgs, public testing::Test {
             ucc_coll_args_t *coll = (ucc_coll_args_t*)
                     calloc(1, sizeof(ucc_coll_args_t));
 
-            ctxs[r] = (gtest_ucc_coll_ctx_t*)calloc(1,
-                      sizeof(gtest_ucc_coll_ctx_t));
-            ctxs[r]->args = coll;
-
-            coll->coll_type = UCC_COLL_TYPE_REDUCE;
-            coll->op = T::redop;
-            coll->root = root;
-            coll->src.info.mem_type = mem_type;
-            coll->src.info.count = (ucc_count_t)count;
-            coll->src.info.datatype = dt;
-
+            ctxs[r] = (gtest_ucc_coll_ctx_t*)calloc(1,
+                      sizeof(gtest_ucc_coll_ctx_t));
+            ctxs[r]->args = coll;
             ctxs[r]->init_buf = ucc_malloc(ucc_dt_size(dt) * count, "init buf");
             EXPECT_NE(ctxs[r]->init_buf, nullptr);
@@ -48,6 +40,21 @@ class test_reduce : public UccCollArgs, public testing::Test {
                 ptr[i] = (typename T::type)((i + r + 1) % 8);
             }
 
+            coll->coll_type = UCC_COLL_TYPE_REDUCE;
+            coll->op        = T::redop;
+            coll->root      = root;
+            if (r != root || !inplace) {
+                coll->src.info.mem_type = mem_type;
+                coll->src.info.count    = (ucc_count_t)count;
+                coll->src.info.datatype = dt;
+                UCC_CHECK(ucc_mc_alloc(&ctxs[r]->src_mc_header,
+                          ucc_dt_size(dt) * count, mem_type));
+                coll->src.info.buffer = ctxs[r]->src_mc_header->addr;
+                UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer,
+                          ctxs[r]->init_buf,
+                          ucc_dt_size(dt) * count, mem_type,
+                          UCC_MEMORY_TYPE_HOST));
+            }
             if (r == root) {
                 coll->dst.info.mem_type = mem_type;
                 coll->dst.info.count    = (ucc_count_t)count;
@@ -65,15 +72,6 @@ class test_reduce : public UccCollArgs, public testing::Test {
                 coll->mask  |= UCC_COLL_ARGS_FIELD_FLAGS;
                 coll->flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
             }
-            if (r != root || !inplace) {
-                UCC_CHECK(ucc_mc_alloc(&ctxs[r]->src_mc_header,
-                          ucc_dt_size(dt) * count, mem_type));
-                coll->src.info.buffer = ctxs[r]->src_mc_header->addr;
-                UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer,
-                          ctxs[r]->init_buf,
-                          ucc_dt_size(dt) * count, mem_type,
-                          UCC_MEMORY_TYPE_HOST));
-            }
             if (persistent) {
                 coll->mask  |= UCC_COLL_ARGS_FIELD_FLAGS;
                 coll->flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;
@@ -282,42 +280,58 @@ TYPED_TEST(test_reduce_cuda, multiple_inplace_managed) {
 template<typename T> class test_reduce_avg_order : public test_reduce<T> {
 };
+template<typename T> class test_reduce_dbt : public test_reduce<T> {
+};
+
+#define TEST_DECLARE_WITH_ENV(_env, _n_procs)                                  \
+    {                                                                          \
+        UccJob        job(_n_procs, UccJob::UCC_JOB_CTX_GLOBAL, _env);         \
+        UccTeam_h     team   = job.create_team(_n_procs);                      \
+        int           repeat = 3;                                              \
+        UccCollCtxVec ctxs;                                                    \
+        std::vector<ucc_memory_type_t> mt = {UCC_MEMORY_TYPE_HOST};            \
+        if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA)) {                \
+            mt.push_back(UCC_MEMORY_TYPE_CUDA);                                \
+        }                                                                      \
+        if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA_MANAGED)) {        \
+            mt.push_back(UCC_MEMORY_TYPE_CUDA_MANAGED);                        \
+        }                                                                      \
+        for (auto count : {5, 256, 65536}) {                                   \
+            for (auto inplace : {TEST_NO_INPLACE, TEST_INPLACE}) {             \
+                for (auto m : mt) {                                            \
+                    CHECK_TYPE_OP_SKIP(TypeParam::dt, TypeParam::redop, m);    \
+                    SET_MEM_TYPE(m);                                           \
+                    this->set_inplace(inplace);                                \
+                    this->data_init(_n_procs, TypeParam::dt, count, ctxs,      \
+                                    true);                                     \
+                    UccReq req(team, ctxs);                                    \
+                    CHECK_REQ_NOT_SUPPORTED_SKIP(req, this->data_fini(ctxs));  \
+                    for (auto i = 0; i < repeat; i++) {                        \
+                        req.start();                                           \
+                        req.wait();                                            \
+                        EXPECT_EQ(true, this->data_validate(ctxs));            \
+                        this->reset(ctxs);                                     \
+                    }                                                          \
+                    this->data_fini(ctxs);                                     \
+                }                                                              \
+            }                                                                  \
+        }                                                                      \
+    }
+
 TYPED_TEST_CASE(test_reduce_avg_order, CollReduceTypeOpsAvg);
+TYPED_TEST_CASE(test_reduce_dbt, CollReduceTypeOpsHost);
 
-TYPED_TEST(test_reduce_avg_order, avg_post_op)
-{
-    int n_procs = 15;
-    ucc_job_env_t env = {{"UCC_TL_UCP_REDUCE_AVG_PRE_OP", "0"}};
-    UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL, env);
-    UccTeam_h team = job.create_team(n_procs);
-    int repeat = 3;
-    UccCollCtxVec ctxs;
-    std::vector<ucc_memory_type_t> mt = {UCC_MEMORY_TYPE_HOST};
+ucc_job_env_t post_op_env    = {{"UCC_TL_UCP_REDUCE_AVG_PRE_OP", "0"}};
+ucc_job_env_t reduce_dbt_env = {{"UCC_TL_UCP_TUNE", "reduce:@dbt:0-inf:inf"},
+                                {"UCC_CLS", "basic"}};
 
-    if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA)) {
-        mt.push_back(UCC_MEMORY_TYPE_CUDA);
-    }
-    if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA_MANAGED)) {
-        mt.push_back(UCC_MEMORY_TYPE_CUDA_MANAGED);
-    }
+TYPED_TEST(test_reduce_avg_order, avg_post_op) {
+    TEST_DECLARE_WITH_ENV(post_op_env, 15);
+}
 
-    for (auto count : {4, 256, 65536}) {
-        for (auto inplace : {TEST_NO_INPLACE, TEST_INPLACE}) {
-            for (auto m : mt) {
-                CHECK_TYPE_OP_SKIP(TypeParam::dt, TypeParam::redop, m);
-                SET_MEM_TYPE(m);
-                this->set_inplace(inplace);
-                this->data_init(n_procs, TypeParam::dt, count, ctxs, true);
-                UccReq req(team, ctxs);
-                CHECK_REQ_NOT_SUPPORTED_SKIP(req, this->data_fini(ctxs));
-                for (auto i = 0; i < repeat; i++) {
-                    req.start();
-                    req.wait();
-                    EXPECT_EQ(true, this->data_validate(ctxs));
-                    this->reset(ctxs);
-                }
-                this->data_fini(ctxs);
-            }
-        }
-    }
+TYPED_TEST(test_reduce_dbt, reduce_dbt_shift) {
+    TEST_DECLARE_WITH_ENV(reduce_dbt_env, 15);
+}
+
+TYPED_TEST(test_reduce_dbt, reduce_dbt_mirror) {
+    TEST_DECLARE_WITH_ENV(reduce_dbt_env, 16);
 }
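
Reviewer note, not part of the patch: the "dbt" algorithms rest on two small pieces of arithmetic that are easy to check in isolation. reduce_dbt.c splits the buffer into two nearly equal halves, one reduced over each tree, and ucc_dbt_build_t2_mirror derives the second tree by mapping vertex v of T1 to size - 1 - v, which is what turns a leaf of T1 into an interior vertex of T2 so both halves keep the links busy. Below is a minimal standalone sketch of both; the typedef and the UCC_RANK_INVALID value are local stand-ins for the UCC headers, not the real definitions.

```c
#include <stdio.h>
#include <stdint.h>

typedef uint32_t ucc_rank_t;              /* stand-in for the UCC typedef */
#define UCC_RANK_INVALID ((ucc_rank_t)-1) /* assumed sentinel value */

/* Same split as reduce_dbt.c: T1 reduces the first (larger for odd counts)
 * half of the buffer, T2 the second half, concurrently. */
static void split_counts(size_t count, size_t counts[2])
{
    counts[0] = (count % 2) ? count / 2 + 1 : count / 2;
    counts[1] = count / 2;
}

/* Mirror mapping from ucc_dbt_build_t2_mirror: vertex v of T1 becomes
 * size - 1 - v in T2; an absent child stays absent. */
static ucc_rank_t mirror(ucc_rank_t size, ucc_rank_t v)
{
    return (v == UCC_RANK_INVALID) ? UCC_RANK_INVALID : size - 1 - v;
}

int main(void)
{
    size_t counts[2];

    split_counts(123567, counts); /* odd count used by the gtest */
    printf("halves: %zu + %zu\n", counts[0], counts[1]); /* 61784 + 61783 */
    printf("rank 3 mirrored in a 16-rank team: %u\n", mirror(16, 3)); /* 12 */
    return 0;
}
```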