diff --git a/src/coll_patterns/double_binary_tree.h b/src/coll_patterns/double_binary_tree.h index 0040321f6f..edb3199264 100644 --- a/src/coll_patterns/double_binary_tree.h +++ b/src/coll_patterns/double_binary_tree.h @@ -18,6 +18,7 @@ typedef struct ucc_dbt_single_tree { ucc_rank_t root; ucc_rank_t parent; ucc_rank_t children[2]; + int n_children; int height; int recv; } ucc_dbt_single_tree_t; @@ -86,6 +87,19 @@ static inline void get_children(ucc_rank_t size, ucc_rank_t rank, int height, *r_c = get_right_child(size, rank, height, root); } +static inline int get_n_children(ucc_rank_t l_c, ucc_rank_t r_c) +{ + int n_children = 0; + + if (l_c != -1) { + n_children++; + } + if (r_c != -1) { + n_children++; + } + return n_children; +} + static inline int get_parent(int vsize, int vrank, int height, int troot) { if (vrank == troot) { @@ -118,6 +132,8 @@ static inline void ucc_dbt_build_t2_mirror(ucc_dbt_single_tree_t t1, size - 1 - t1.children[RIGHT_CHILD]; t.children[RIGHT_CHILD] = (t1.children[LEFT_CHILD] == -1) ? -1 : size - 1 - t1.children[LEFT_CHILD]; + t.n_children = get_n_children(t.children[LEFT_CHILD], + t.children[RIGHT_CHILD]); t.recv = 0; *t2 = t; @@ -138,6 +154,8 @@ static inline void ucc_dbt_build_t2_shift(ucc_dbt_single_tree_t t1, (t1.children[LEFT_CHILD] + 1) % size; t.children[RIGHT_CHILD] = (t1.children[RIGHT_CHILD] == -1) ? -1 : (t1.children[RIGHT_CHILD] + 1) % size; + t.n_children = get_n_children(t.children[LEFT_CHILD], + t.children[RIGHT_CHILD]); t.recv = 0; *t2 = t; @@ -152,12 +170,14 @@ static inline void ucc_dbt_build_t1(ucc_rank_t rank, ucc_rank_t size, get_children(size, rank, height, root, &t1->children[LEFT_CHILD], &t1->children[RIGHT_CHILD]); - t1->height = height; - t1->parent = parent; - t1->size = size; - t1->rank = rank; - t1->root = root; - t1->recv = 0; + t1->n_children = get_n_children(t1->children[LEFT_CHILD], + t1->children[RIGHT_CHILD]); + t1->height = height; + t1->parent = parent; + t1->size = size; + t1->rank = rank; + t1->root = root; + t1->recv = 0; } static inline ucc_rank_t ucc_dbt_convert_rank_for_shift(ucc_rank_t rank, diff --git a/src/components/tl/ucp/Makefile.am b/src/components/tl/ucp/Makefile.am index b35578aa9f..f5b90f211b 100644 --- a/src/components/tl/ucp/Makefile.am +++ b/src/components/tl/ucp/Makefile.am @@ -73,7 +73,8 @@ gatherv = \ reduce = \ reduce/reduce.h \ reduce/reduce.c \ - reduce/reduce_knomial.c + reduce/reduce_knomial.c \ + reduce/reduce_dbt.c reduce_scatter = \ reduce_scatter/reduce_scatter.h \ diff --git a/src/components/tl/ucp/reduce/reduce.c b/src/components/tl/ucp/reduce/reduce.c index 82a9380083..039f9f393b 100644 --- a/src/components/tl/ucp/reduce/reduce.c +++ b/src/components/tl/ucp/reduce/reduce.c @@ -13,6 +13,11 @@ ucc_base_coll_alg_info_t .name = "knomial", .desc = "reduce over knomial tree with arbitrary radix " "(optimized for latency)"}, + [UCC_TL_UCP_REDUCE_ALG_DBT] = + {.id = UCC_TL_UCP_REDUCE_ALG_DBT, + .name = "dbt", + .desc = "bcast over double binary tree where a leaf in one tree " + "will be intermediate in other (optimized for BW)"}, [UCC_TL_UCP_REDUCE_ALG_LAST] = { .id = 0, .name = NULL, .desc = NULL}}; @@ -66,3 +71,16 @@ ucc_status_t ucc_tl_ucp_reduce_init(ucc_tl_ucp_task_t *task) return status; } + +ucc_status_t ucc_tl_ucp_reduce_knomial_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h) +{ + ucc_tl_ucp_task_t *task; + ucc_status_t status; + + task = ucc_tl_ucp_init_task(coll_args, team); + status = ucc_tl_ucp_reduce_init(task); + *task_h = &task->super; + return status; +} diff --git a/src/components/tl/ucp/reduce/reduce.h b/src/components/tl/ucp/reduce/reduce.h index e26c4fdf23..98bc183ff3 100644 --- a/src/components/tl/ucp/reduce/reduce.h +++ b/src/components/tl/ucp/reduce/reduce.h @@ -9,12 +9,16 @@ enum { UCC_TL_UCP_REDUCE_ALG_KNOMIAL, + UCC_TL_UCP_REDUCE_ALG_DBT, UCC_TL_UCP_REDUCE_ALG_LAST }; extern ucc_base_coll_alg_info_t ucc_tl_ucp_reduce_algs[UCC_TL_UCP_REDUCE_ALG_LAST + 1]; +#define UCC_TL_UCP_REDUCE_DEFAULT_ALG_SELECT_STR \ + "reduce:0-inf:@0" + /* A set of convenience macros used to implement sw based progress of the reduce algorithm that uses kn pattern */ enum { @@ -36,12 +40,32 @@ enum { }; \ } while (0) + +static inline int ucc_tl_ucp_reduce_alg_from_str(const char *str) +{ + int i; + for (i = 0; i < UCC_TL_UCP_REDUCE_ALG_LAST; i++) { + if (0 == strcasecmp(str, ucc_tl_ucp_reduce_algs[i].name)) { + break; + } + } + return i; +} + ucc_status_t ucc_tl_ucp_reduce_init(ucc_tl_ucp_task_t *task); +ucc_status_t ucc_tl_ucp_reduce_knomial_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h); + ucc_status_t ucc_tl_ucp_reduce_knomial_start(ucc_coll_task_t *task); void ucc_tl_ucp_reduce_knomial_progress(ucc_coll_task_t *task); ucc_status_t ucc_tl_ucp_reduce_knomial_finalize(ucc_coll_task_t *task); +ucc_status_t ucc_tl_ucp_reduce_dbt_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h); + #endif diff --git a/src/components/tl/ucp/reduce/reduce_dbt.c b/src/components/tl/ucp/reduce/reduce_dbt.c new file mode 100644 index 0000000000..b3d1c13e7a --- /dev/null +++ b/src/components/tl/ucp/reduce/reduce_dbt.c @@ -0,0 +1,370 @@ +/** + * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "config.h" +#include "tl_ucp.h" +#include "reduce.h" +#include "core/ucc_progress_queue.h" +#include "tl_ucp_sendrecv.h" +#include "utils/ucc_dt_reduce.h" + +enum { + RECV, + REDUCE, + TEST, + TEST_ROOT, +}; + +#define UCC_REDUCE_DBT_CHECK_STATE(_p) \ + case _p: \ + goto _p; + +#define UCC_REDUCE_DBT_GOTO_STATE(_state) \ + do { \ + switch (_state) { \ + UCC_REDUCE_DBT_CHECK_STATE(REDUCE); \ + UCC_REDUCE_DBT_CHECK_STATE(TEST); \ + UCC_REDUCE_DBT_CHECK_STATE(TEST_ROOT); \ + }; \ + } while (0) + +static void recv_completion_common(void *request, ucs_status_t status, + const ucp_tag_recv_info_t *info, /* NOLINT */ + void *user_data) +{ + ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data; + if (ucc_unlikely(UCS_OK != status)) { + tl_error(UCC_TASK_LIB(task), "failure in recv completion %s", + ucs_status_string(status)); + task->super.status = ucs_status_to_ucc_status(status); + } + task->tagged.recv_completed++; + if (request) { + ucp_request_free(request); + } +} + +static void recv_completion_1(void *request, ucs_status_t status, + const ucp_tag_recv_info_t *info, /* NOLINT */ + void *user_data) +{ + ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data; + + task->reduce_dbt.t1.recv++; + recv_completion_common(request, status, info, user_data); +} + +static void recv_completion_2(void *request, ucs_status_t status, + const ucp_tag_recv_info_t *info, /* NOLINT */ + void *user_data) +{ + ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data; + + task->reduce_dbt.t2.recv++; + recv_completion_common(request, status, info, user_data); +} + +static inline void single_tree_reduce(ucc_tl_ucp_task_t *task, void *sbuf, void *rbuf, int n_children, size_t count, size_t data_size, ucc_datatype_t dt, ucc_coll_args_t *args, int is_avg) +{ + ucc_status_t status; + + status = ucc_dt_reduce_strided( + sbuf,rbuf, rbuf, + n_children, count, data_size, + dt, args, + is_avg ? UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0, + AVG_ALPHA(task), task->reduce_dbt.executor, + &task->reduce_dbt.etask); + + if (ucc_unlikely(UCC_OK != status)) { + tl_error(UCC_TASK_LIB(task), + "failed to perform dt reduction"); + task->super.status = status; + return; + } + EXEC_TASK_WAIT(task->reduce_dbt.etask); +} + +void ucc_tl_ucp_reduce_dbt_progress(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = + ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_rank_t rank = UCC_TL_TEAM_RANK(team); + ucc_dbt_single_tree_t t1 = task->reduce_dbt.t1; + ucc_dbt_single_tree_t t2 = task->reduce_dbt.t2; + ucc_memory_type_t mtype = args->src.info.mem_type; + ucc_datatype_t dt = args->src.info.datatype; + size_t count = args->src.info.count; + size_t count_t1 = (count % 2) ? count / 2 + 1 + : count / 2; + size_t data_size = count * ucc_dt_size(dt); + size_t data_size_t1 = count_t1 * ucc_dt_size(dt); + size_t data_size_t2 = count / 2 * ucc_dt_size(dt); + ucc_rank_t coll_root = (ucc_rank_t)args->root; + ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1, + recv_completion_2}; + int avg_pre_op = ((args->op == UCC_OP_AVG) && + UCC_TL_UCP_TEAM_LIB(TASK_TEAM(task))->cfg.reduce_avg_pre_op); + int avg_post_op = ((args->op == UCC_OP_AVG) && + !UCC_TL_UCP_TEAM_LIB(TASK_TEAM(task))->cfg.reduce_avg_pre_op); + void *t1_rbuf = task->reduce_dbt.scratch; + void *t2_rbuf = PTR_OFFSET(t1_rbuf, + data_size_t1 * 2); + void *t1_sbuf = avg_pre_op ? + PTR_OFFSET(t1_rbuf, + data_size * 2) : + args->src.info.buffer; + void *t2_sbuf = PTR_OFFSET(t1_sbuf, + data_size_t1); + uint32_t i, j; + + UCC_REDUCE_DBT_GOTO_STATE(task->reduce_dbt.state); + j = 0; + for (i = 0; i < 2; i++) { + if (t1.children[i] != -1) { + UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(t1_rbuf, + data_size_t1 * j), + data_size_t1, mtype, + t1.children[i], team, task, cb[0], + (void *)task), + task, out); + j++; + } + } + + j = 0; + for (i = 0; i < 2; i++) { + if (t2.children[i] != -1) { + UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(t2_rbuf, + data_size_t2 * j), + data_size_t2, mtype, + t2.children[i], team, task, cb[1], + (void *)task), + task, out); + j++; + } + } + task->reduce_dbt.state = REDUCE; + +REDUCE: + if (t1.recv == t1.n_children && !task->reduce_dbt.t1_reduction_comp) { + if (t1.n_children > 0) { + single_tree_reduce(task, t1_sbuf, t1_rbuf, t1.n_children, count_t1, + data_size_t1, dt, args, + avg_post_op && t1.root == rank); + } + task->reduce_dbt.t1_reduction_comp = 1; + } + if (t2.recv == t2.n_children && !task->reduce_dbt.t2_reduction_comp) { + if (t2.n_children > 0) { + single_tree_reduce(task, t2_sbuf, t2_rbuf, t2.n_children, + count / 2, data_size_t2, dt, args, + avg_post_op && t2.root == rank); + } + task->reduce_dbt.t2_reduction_comp = 1; + } + + if (rank != t1.root && task->reduce_dbt.t1_reduction_comp && + !task->reduce_dbt.t1_send_comp) { + UCPCHECK_GOTO(ucc_tl_ucp_send_nb((t1.n_children > 0) ? t1_rbuf + : t1_sbuf, + data_size_t1, mtype, t1.parent, team, + task), + task, out); + task->reduce_dbt.t1_send_comp = 1; + } + + if (rank != t2.root && task->reduce_dbt.t2_reduction_comp && + !task->reduce_dbt.t2_send_comp) { + UCPCHECK_GOTO(ucc_tl_ucp_send_nb((t2.n_children > 0) ? t2_rbuf + : t2_sbuf, + data_size_t2, mtype, t2.parent, team, + task), + task, out); + task->reduce_dbt.t2_send_comp = 1; + } + + if (!task->reduce_dbt.t1_reduction_comp || + !task->reduce_dbt.t2_reduction_comp) { + return; + } +TEST: + if (UCC_INPROGRESS == ucc_tl_ucp_test_send(task)) { + task->reduce_dbt.state = TEST; + return; + } + + if (rank == t1.root && rank != coll_root) { + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(t1_rbuf, data_size_t1, mtype, + coll_root, team, task), + task, out); + } + if (rank == t2.root && rank != coll_root) { + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(t2_rbuf, data_size_t2, mtype, + coll_root, team, task), + task, out); + } + + task->reduce_dbt.t1_reduction_comp = t1.recv; + task->reduce_dbt.t2_reduction_comp = t2.recv; + if (rank == coll_root && rank != t1.root) { + UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(args->dst.info.buffer, data_size_t1, + mtype, t1.root, team, task, cb[0], + (void *)task), + task, out); + task->reduce_dbt.t1_reduction_comp++; + } + if (rank == coll_root && rank != t2.root) { + UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(args->dst.info.buffer, + data_size_t1), + data_size_t2, mtype, t2.root, team, + task, cb[1], (void *)task), + task, out); + task->reduce_dbt.t2_reduction_comp++; + } +TEST_ROOT: + if (UCC_INPROGRESS == ucc_tl_ucp_test_send(task) || + task->reduce_dbt.t1_reduction_comp != t1.recv || + task->reduce_dbt.t2_reduction_comp != t2.recv) { + task->reduce_dbt.state = TEST_ROOT; + return; + } + if (rank == coll_root && rank == t1.root) { + UCPCHECK_GOTO(ucc_mc_memcpy(args->dst.info.buffer, t1_rbuf, + data_size_t1, mtype, mtype), task, out); + } + if (rank == coll_root && rank == t2.root) { + UCPCHECK_GOTO(ucc_mc_memcpy(PTR_OFFSET(args->dst.info.buffer, + data_size_t1), t2_rbuf, + data_size_t2, mtype, mtype), + task, out); + } + task->super.status = UCC_OK; + UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_reduce_dbt_done", 0); + +out: + return; +} + +ucc_status_t ucc_tl_ucp_reduce_dbt_start(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, + ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_rank_t rank = UCC_TL_TEAM_RANK(team); + ucc_rank_t team_size = UCC_TL_TEAM_SIZE(team); + int avg_pre_op = + UCC_TL_UCP_TEAM_LIB(TASK_TEAM(task))->cfg.reduce_avg_pre_op; + ucc_datatype_t dt; + size_t count, data_size; + ucc_status_t status; + + task->reduce_dbt.t1.recv = 0; + task->reduce_dbt.t2.recv = 0; + task->reduce_dbt.t1_reduction_comp = 0; + task->reduce_dbt.t2_reduction_comp = 0; + task->reduce_dbt.t1_send_comp = 0; + task->reduce_dbt.t2_send_comp = 0; + ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); + + if (TASK_ARGS(task).root == rank) { + count = TASK_ARGS(task).dst.info.count; + dt = TASK_ARGS(task).dst.info.datatype; + } else { + count = TASK_ARGS(task).src.info.count; + dt = TASK_ARGS(task).src.info.datatype; + } + data_size = count * ucc_dt_size(dt); + + status = ucc_coll_task_get_executor(&task->super, + &task->reduce_dbt.executor); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + + if (UCC_IS_INPLACE(*args) && (rank == args->root)) { + args->src.info.buffer = args->dst.info.buffer; + } + + if (avg_pre_op && args->op == UCC_OP_AVG) { + /* In case of avg_pre_op, each process must divide itself by team_size */ + status = + ucc_dt_reduce(args->src.info.buffer, args->src.info.buffer, + PTR_OFFSET(task->reduce_dbt.scratch, data_size * 2), + count, dt, args, UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA, + 1.0 / (double)(team_size * 2), + task->reduce_dbt.executor, &task->reduce_dbt.etask); + if (ucc_unlikely(UCC_OK != status)) { + tl_error(UCC_TASK_LIB(task), + "failed to perform dt reduction"); + return status; + } + EXEC_TASK_WAIT(task->reduce_dbt.etask, status); + } + + task->reduce_dbt.state = RECV; + UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_reduce_dbt_start", 0); + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); +} + +ucc_status_t ucc_tl_ucp_reduce_dbt_finalize(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + + if (task->reduce_dbt.scratch_mc_header) { + ucc_mc_free(task->reduce_dbt.scratch_mc_header); + } + + return ucc_tl_ucp_coll_finalize(coll_task); +} + +ucc_status_t ucc_tl_ucp_reduce_dbt_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h) +{ + ucc_tl_ucp_team_t *tl_team; + ucc_tl_ucp_task_t *task; + ucc_rank_t rank, size; + ucc_memory_type_t mtype; + ucc_datatype_t dt; + size_t count; + size_t data_size; + ucc_status_t status; + + task = ucc_tl_ucp_init_task(coll_args, team); + task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; + task->super.post = ucc_tl_ucp_reduce_dbt_start; + task->super.progress = ucc_tl_ucp_reduce_dbt_progress; + task->super.finalize = ucc_tl_ucp_reduce_dbt_finalize; + tl_team = TASK_TEAM(task); + rank = UCC_TL_TEAM_RANK(tl_team); + size = UCC_TL_TEAM_SIZE(tl_team); + ucc_dbt_build_trees(rank, size, &task->reduce_dbt.t1, + &task->reduce_dbt.t2); + + if (coll_args->args.root == rank) { + count = coll_args->args.dst.info.count; + dt = coll_args->args.dst.info.datatype; + mtype = coll_args->args.dst.info.mem_type; + } else { + count = coll_args->args.src.info.count; + dt = coll_args->args.src.info.datatype; + mtype = coll_args->args.src.info.mem_type; + } + data_size = count * ucc_dt_size(dt); + task->reduce_dbt.scratch_mc_header = NULL; + status = ucc_mc_alloc(&task->reduce_dbt.scratch_mc_header, 3 * data_size, + mtype); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + task->reduce_dbt.scratch = task->reduce_dbt.scratch_mc_header->addr; + *task_h = &task->super; + return UCC_OK; +} diff --git a/src/components/tl/ucp/tl_ucp_coll.c b/src/components/tl/ucp/tl_ucp_coll.c index 872f064d16..280286b00c 100644 --- a/src/components/tl/ucp/tl_ucp_coll.c +++ b/src/components/tl/ucp/tl_ucp_coll.c @@ -42,6 +42,10 @@ const ucc_tl_ucp_default_alg_desc_t .select_str = UCC_TL_UCP_BCAST_DEFAULT_ALG_SELECT_STR, .str_get_fn = NULL }, + { + .select_str = UCC_TL_UCP_REDUCE_DEFAULT_ALG_SELECT_STR, + .str_get_fn = NULL + }, { .select_str = UCC_TL_UCP_REDUCE_SCATTER_DEFAULT_ALG_SELECT_STR, .str_get_fn = NULL @@ -223,6 +227,8 @@ static inline int alg_id_from_str(ucc_coll_type_t coll_type, const char *str) return ucc_tl_ucp_alltoallv_alg_from_str(str); case UCC_COLL_TYPE_BCAST: return ucc_tl_ucp_bcast_alg_from_str(str); + case UCC_COLL_TYPE_REDUCE: + return ucc_tl_ucp_reduce_alg_from_str(str); case UCC_COLL_TYPE_REDUCE_SCATTER: return ucc_tl_ucp_reduce_scatter_alg_from_str(str); case UCC_COLL_TYPE_REDUCE_SCATTERV: @@ -318,6 +324,19 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str, break; }; break; + case UCC_COLL_TYPE_REDUCE: + switch (alg_id) { + case UCC_TL_UCP_REDUCE_ALG_KNOMIAL: + *init = ucc_tl_ucp_reduce_knomial_init; + break; + case UCC_TL_UCP_REDUCE_ALG_DBT: + *init = ucc_tl_ucp_reduce_dbt_init; + break; + default: + status = UCC_ERR_INVALID_PARAM; + break; + }; + break; case UCC_COLL_TYPE_REDUCE_SCATTER: switch (alg_id) { case UCC_TL_UCP_REDUCE_SCATTER_ALG_RING: diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index 16b932e70b..e631813e00 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -17,7 +17,7 @@ #include "tl_ucp_tag.h" #define UCC_UUNITS_AUTO_RADIX 4 -#define UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR 7 +#define UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR 8 ucc_status_t ucc_tl_ucp_team_default_score_str_alloc(ucc_tl_ucp_team_t *team, char *default_select_str[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR]); @@ -200,6 +200,19 @@ typedef struct ucc_tl_ucp_task { ucc_ee_executor_task_t *etask; ucc_ee_executor_t *executor; } reduce_kn; + struct { + ucc_dbt_single_tree_t t1; + ucc_dbt_single_tree_t t2; + int state; + int t1_reduction_comp; + int t2_reduction_comp; + int t1_send_comp; + int t2_send_comp; + void *scratch; + ucc_mc_buffer_header_t *scratch_mc_header; + ucc_ee_executor_task_t *etask; + ucc_ee_executor_t *executor; + } reduce_dbt; struct { ucc_rank_t dist; ucc_rank_t max_dist; diff --git a/test/gtest/coll/test_reduce.cc b/test/gtest/coll/test_reduce.cc index 393e97decc..e2713ca6b5 100644 --- a/test/gtest/coll/test_reduce.cc +++ b/test/gtest/coll/test_reduce.cc @@ -282,42 +282,57 @@ TYPED_TEST(test_reduce_cuda, multiple_inplace_managed) { template class test_reduce_avg_order : public test_reduce { }; -TYPED_TEST_CASE(test_reduce_avg_order, CollReduceTypeOpsAvg); +template class test_reduce_dbt : public test_reduce { +}; -TYPED_TEST(test_reduce_avg_order, avg_post_op) -{ - int n_procs = 15; - ucc_job_env_t env = {{"UCC_TL_UCP_REDUCE_AVG_PRE_OP", "0"}}; - UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL, env); - UccTeam_h team = job.create_team(n_procs); - int repeat = 3; - UccCollCtxVec ctxs; - std::vector mt = {UCC_MEMORY_TYPE_HOST}; - - if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA)) { - mt.push_back(UCC_MEMORY_TYPE_CUDA); - } - if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA_MANAGED)) { - mt.push_back(UCC_MEMORY_TYPE_CUDA_MANAGED); +#define TEST_DECLARE_WITH_ENV(_env, _n_procs) \ + { \ + UccJob job(_n_procs, UccJob::UCC_JOB_CTX_GLOBAL, _env); \ + UccTeam_h team = job.create_team(_n_procs); \ + int repeat = 3; \ + UccCollCtxVec ctxs; \ + std::vector mt = {UCC_MEMORY_TYPE_HOST}; \ + if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA)) { \ + mt.push_back(UCC_MEMORY_TYPE_CUDA); \ + } \ + if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA_MANAGED)) { \ + mt.push_back(UCC_MEMORY_TYPE_CUDA_MANAGED); \ + } \ + for (auto count : {5, 256, 65536}) { \ + for (auto inplace : {TEST_NO_INPLACE, TEST_INPLACE}) { \ + for (auto m : mt) { \ + CHECK_TYPE_OP_SKIP(TypeParam::dt, TypeParam::redop, m); \ + SET_MEM_TYPE(m); \ + this->set_inplace(inplace); \ + this->data_init(_n_procs, TypeParam::dt, count, ctxs, true); \ + UccReq req(team, ctxs); \ + CHECK_REQ_NOT_SUPPORTED_SKIP(req, this->data_fini(ctxs)); \ + for (auto i = 0; i < repeat; i++) { \ + req.start(); \ + req.wait(); \ + EXPECT_EQ(true, this->data_validate(ctxs)); \ + this->reset(ctxs); \ + } \ + this->data_fini(ctxs); \ + } \ + } \ + } \ } - for (auto count : {4, 256, 65536}) { - for (auto inplace : {TEST_NO_INPLACE, TEST_INPLACE}) { - for (auto m : mt) { - CHECK_TYPE_OP_SKIP(TypeParam::dt, TypeParam::redop, m); - SET_MEM_TYPE(m); - this->set_inplace(inplace); - this->data_init(n_procs, TypeParam::dt, count, ctxs, true); - UccReq req(team, ctxs); - CHECK_REQ_NOT_SUPPORTED_SKIP(req, this->data_fini(ctxs)); - for (auto i = 0; i < repeat; i++) { - req.start(); - req.wait(); - EXPECT_EQ(true, this->data_validate(ctxs)); - this->reset(ctxs); - } - this->data_fini(ctxs); - } - } - } +TYPED_TEST_CASE(test_reduce_avg_order, CollReduceTypeOpsAvg); +TYPED_TEST_CASE(test_reduce_dbt, CollReduceTypeOpsHost); + +ucc_job_env_t post_op_env = {{"UCC_TL_UCP_REDUCE_AVG_PRE_OP", "0"}}; +ucc_job_env_t reduce_dbt_env = {{"UCC_TL_UCP_TUNE", "reduce:@dbt:0-inf:inf"}}; + +TYPED_TEST(test_reduce_avg_order, avg_post_op) { + TEST_DECLARE_WITH_ENV(post_op_env, 15); +} + +TYPED_TEST(test_reduce_dbt, reduce_dbt_shift) { + TEST_DECLARE_WITH_ENV(reduce_dbt_env, 15); +} + +TYPED_TEST(test_reduce_dbt, reduce_dbt_mirror) { + TEST_DECLARE_WITH_ENV(reduce_dbt_env, 16); }