From a036a5f1bc750036c61de1af3311e2fef5866a44 Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Mon, 8 May 2023 16:28:24 +0400 Subject: [PATCH] TL/UCP: add bruck alltoall (#756) * TL/UCP: add bruck alltoall * REVIEW: fix review comments * REVIEW: fix review comments --- src/Makefile.am | 1 + src/coll_patterns/bruck_alltoall.h | 41 ++++ src/components/tl/ucp/Makefile.am | 3 +- src/components/tl/ucp/alltoall/alltoall.c | 21 +- src/components/tl/ucp/alltoall/alltoall.h | 23 +- .../tl/ucp/alltoall/alltoall_bruck.c | 230 ++++++++++++++++++ .../tl/ucp/alltoallv/alltoallv_hybrid.c | 30 +-- src/components/tl/ucp/tl_ucp_coll.c | 71 +++++- src/components/tl/ucp/tl_ucp_coll.h | 20 +- src/components/tl/ucp/tl_ucp_team.c | 10 +- src/utils/ucc_math.h | 13 + 11 files changed, 413 insertions(+), 50 deletions(-) create mode 100644 src/coll_patterns/bruck_alltoall.h create mode 100644 src/components/tl/ucp/alltoall/alltoall_bruck.c diff --git a/src/Makefile.am b/src/Makefile.am index ed62abd7df..b3fe5ed1c2 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -89,6 +89,7 @@ noinst_HEADERS = \ components/ec/ucc_ec_log.h \ coll_patterns/recursive_knomial.h \ coll_patterns/sra_knomial.h \ + coll_patterns/bruck_alltoall.h \ components/topo/ucc_topo.h \ components/topo/ucc_sbgp.h diff --git a/src/coll_patterns/bruck_alltoall.h b/src/coll_patterns/bruck_alltoall.h new file mode 100644 index 0000000000..e4997a7910 --- /dev/null +++ b/src/coll_patterns/bruck_alltoall.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#ifndef BRUCK_ALLTOALL_H_ +#define BRUCK_ALLTOALL_H_ + +#include "utils/ucc_math.h" + +#define GET_NEXT_BRUCK_NUM(_num, _radix, _pow) \ + ((((_num) + 1) % (_pow))?((_num) + 1):(((_num) + 1) + (_pow) * ((_radix) - 1))) + +#define GET_PREV_BRUCK_NUM(_num, _radix, _pow) \ + (((_num) % (_pow))?((_num) - 1):(((_num) - 1) - (_pow) * ((_radix) - 1))) + +static inline ucc_rank_t get_bruck_step_start(uint32_t pow, uint32_t d) +{ + return pow * d; +} + +static inline ucc_rank_t get_bruck_step_finish(ucc_rank_t n, uint32_t radix, + uint32_t d, uint32_t pow) +{ + return ucc_min(n + pow - 1 - (n - d * pow) % (pow * radix), n); +} + +static inline ucc_rank_t get_bruck_recv_peer(ucc_rank_t trank, ucc_rank_t tsize, + ucc_rank_t step, uint32_t digit) +{ + return (trank - step * digit + tsize * digit) % tsize; +} + +static inline ucc_rank_t get_bruck_send_peer(ucc_rank_t trank, ucc_rank_t tsize, + ucc_rank_t step, uint32_t digit) +{ + return (trank + step * digit) % tsize; +} + +#endif diff --git a/src/components/tl/ucp/Makefile.am b/src/components/tl/ucp/Makefile.am index e8856dee9c..b7f93b881e 100644 --- a/src/components/tl/ucp/Makefile.am +++ b/src/components/tl/ucp/Makefile.am @@ -24,7 +24,8 @@ alltoall = \ alltoall/alltoall.h \ alltoall/alltoall.c \ alltoall/alltoall_onesided.c \ - alltoall/alltoall_pairwise.c + alltoall/alltoall_pairwise.c \ + alltoall/alltoall_bruck.c alltoallv = \ alltoallv/alltoallv.h \ diff --git a/src/components/tl/ucp/alltoall/alltoall.c b/src/components/tl/ucp/alltoall/alltoall.c index b9656a7cb1..faa888dcc0 100644 --- a/src/components/tl/ucp/alltoall/alltoall.c +++ b/src/components/tl/ucp/alltoall/alltoall.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ @@ -8,18 +8,37 @@ #include "tl_ucp.h" #include "alltoall.h" +#define ALLTOALL_MAX_PATTERN_SIZE (sizeof(UCC_TL_UCP_ALLTOALL_DEFAULT_ALG_SELECT_STR_PATTERN) + 32) +#define ALLTOALL_DEFAULT_ALG_SWITCH 129 + ucc_status_t ucc_tl_ucp_alltoall_pairwise_start(ucc_coll_task_t *task); void ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *task); ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *task); void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *task); +char* ucc_tl_ucp_alltoall_score_str_get(ucc_tl_ucp_team_t *team) +{ + int max_size = ALLTOALL_MAX_PATTERN_SIZE; + char *str; + + str = ucc_malloc(max_size * sizeof(char)); + ucc_snprintf_safe(str, max_size, + UCC_TL_UCP_ALLTOALL_DEFAULT_ALG_SELECT_STR_PATTERN, + ALLTOALL_DEFAULT_ALG_SWITCH * UCC_TL_TEAM_SIZE(team)); + return str; +} + ucc_base_coll_alg_info_t ucc_tl_ucp_alltoall_algs[UCC_TL_UCP_ALLTOALL_ALG_LAST + 1] = { [UCC_TL_UCP_ALLTOALL_ALG_PAIRWISE] = {.id = UCC_TL_UCP_ALLTOALL_ALG_PAIRWISE, .name = "pairwise", .desc = "pairwise two-sided implementation"}, + [UCC_TL_UCP_ALLTOALL_ALG_BRUCK] = + {.id = UCC_TL_UCP_ALLTOALL_ALG_BRUCK, + .name = "bruck", + .desc = "Bruck alltoall"}, [UCC_TL_UCP_ALLTOALL_ALG_ONESIDED] = {.id = UCC_TL_UCP_ALLTOALL_ALG_ONESIDED, .name = "onesided", diff --git a/src/components/tl/ucp/alltoall/alltoall.h b/src/components/tl/ucp/alltoall/alltoall.h index db8c28e615..746f3fcc47 100644 --- a/src/components/tl/ucp/alltoall/alltoall.h +++ b/src/components/tl/ucp/alltoall/alltoall.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -12,6 +12,7 @@ enum { UCC_TL_UCP_ALLTOALL_ALG_PAIRWISE, + UCC_TL_UCP_ALLTOALL_ALG_BRUCK, UCC_TL_UCP_ALLTOALL_ALG_ONESIDED, UCC_TL_UCP_ALLTOALL_ALG_LAST }; @@ -19,19 +20,27 @@ enum { extern ucc_base_coll_alg_info_t ucc_tl_ucp_alltoall_algs[UCC_TL_UCP_ALLTOALL_ALG_LAST + 1]; -#define UCC_TL_UCP_ALLTOALL_DEFAULT_ALG_SELECT_STR "alltoall:0-inf:@0" +#define UCC_TL_UCP_ALLTOALL_DEFAULT_ALG_SELECT_STR_PATTERN \ +"alltoall:host:0-%d:@bruck" + +char* ucc_tl_ucp_alltoall_score_str_get(ucc_tl_ucp_team_t *team); ucc_status_t ucc_tl_ucp_alltoall_init(ucc_tl_ucp_task_t *task); +ucc_status_t ucc_tl_ucp_alltoall_pairwise_init_common(ucc_tl_ucp_task_t *task); + ucc_status_t ucc_tl_ucp_alltoall_pairwise_init(ucc_base_coll_args_t *coll_args, - ucc_base_team_t *team, - ucc_coll_task_t **task_h); + ucc_base_team_t *team, + ucc_coll_task_t **task_h); + +ucc_status_t ucc_tl_ucp_alltoall_bruck_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h); -ucc_status_t ucc_tl_ucp_alltoall_pairwise_init_common(ucc_tl_ucp_task_t *task); ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, - ucc_base_team_t * team, - ucc_coll_task_t ** task_h); + ucc_base_team_t *team, + ucc_coll_task_t **task_h); #define ALLTOALL_CHECK_INPLACE(_args, _team) \ do { \ diff --git a/src/components/tl/ucp/alltoall/alltoall_bruck.c b/src/components/tl/ucp/alltoall/alltoall_bruck.c new file mode 100644 index 0000000000..5c30672381 --- /dev/null +++ b/src/components/tl/ucp/alltoall/alltoall_bruck.c @@ -0,0 +1,230 @@ +/** + * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#include "config.h" +#include "tl_ucp.h" +#include "alltoall.h" +#include "tl_ucp_sendrecv.h" +#include "components/mc/ucc_mc.h" +#include "coll_patterns/bruck_alltoall.h" + +#define RADIX 2 + +enum { + PHASE_MERGE, + PHASE_SENDRECV +}; + +static inline int msb_pos_for_level(unsigned int nthbit, ucc_rank_t number) +{ + + int msb_set = -1; + unsigned int i; + + for (i = 0; i < nthbit - 1; i++) { + if (1 & number >> i) { + msb_set = i; + } + } + + return msb_set; +} + +static inline int find_seg_index(ucc_rank_t seg_index, int level, int nsegs_per_rblock) +{ + int block, blockseg; + + if (0 == seg_index) { + return -1; + } + + block = msb_pos_for_level(level, seg_index); + + if (block < 0) { + return -1; + } + + /* remove block bit from seg_index */ + blockseg = ((seg_index >> (block + 1)) << block) | + (seg_index & UCC_MASK(block)); + return block * nsegs_per_rblock + blockseg; +} + +ucc_status_t ucc_tl_ucp_alltoall_bruck_backward_rotation(void *dst, void *src, + ucc_rank_t trank, + ucc_rank_t tsize, + size_t seg_size) +{ + ucc_status_t st; + ucc_rank_t index, level, nsegs_per_rblock; + size_t snd_offset; + int send_buffer_index; + + level = lognum(tsize); + nsegs_per_rblock = tsize / 2; + for (index = 1; index < tsize; index++) { + send_buffer_index = find_seg_index(index, level + 1, nsegs_per_rblock); + ucc_assert(send_buffer_index >= 0); + snd_offset = send_buffer_index * seg_size; + st = ucc_mc_memcpy(PTR_OFFSET(dst, seg_size * + ((trank - index + tsize) % tsize)), + PTR_OFFSET(src, snd_offset), seg_size, + UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_HOST); + if (ucc_unlikely(st != UCC_OK)) { + return st; + } + } + + return UCC_OK; +} + +ucc_status_t ucc_tl_ucp_alltoall_bruck_finalize(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_status_t st, global_st; + + global_st = ucc_mc_free(task->alltoall_bruck.scratch_mc_header); + if (ucc_unlikely(global_st != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), "failed to free scratch buffer"); + } + + st = ucc_tl_ucp_coll_finalize(&task->super); + if (ucc_unlikely(st != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), "failed finalize collective"); + global_st = st; + } + return global_st; +} + +void ucc_tl_ucp_alltoall_bruck_progress(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_rank_t trank = UCC_TL_TEAM_RANK(team); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); + ucc_coll_args_t *args = &TASK_ARGS(task); + void *scratch = task->alltoall_bruck.scratch_mc_header->addr; + void *mergebuf = args->dst.info.buffer; + const ucc_rank_t nrecv_segs = tsize / 2; + const size_t seg_size = ucc_dt_size(args->src.info.datatype) * + args->src.info.count / tsize; + void *data; + ucc_rank_t sendto, recvfrom, step, index; + ucc_rank_t level, snd_count; + int send_buffer_index; + ucc_status_t st; + + if (task->alltoall_bruck.phase == PHASE_SENDRECV) { + goto ALLTOALL_BRUCK_PHASE_SENDRECV; + } + + step = 1 << (task->alltoall_bruck.iteration - 1); + while (step < tsize) { + level = task->alltoall_bruck.iteration - 1; + sendto = get_bruck_send_peer(trank, tsize, step, 1); + recvfrom = get_bruck_recv_peer(trank, tsize, step, 1); + + snd_count = 0; + for (index = get_bruck_step_start(step, 1); + index <= get_bruck_step_finish(tsize - 1, RADIX, 1, step); + index = GET_NEXT_BRUCK_NUM(index, RADIX, step)) { + send_buffer_index = find_seg_index(index, level + 1, nrecv_segs); + if (send_buffer_index == -1) { + data 
= PTR_OFFSET(args->src.info.buffer, + ((index + trank) % tsize) * seg_size); + } else { + data = PTR_OFFSET(scratch, send_buffer_index * seg_size); + } + ucc_mc_memcpy(PTR_OFFSET(mergebuf, seg_size * snd_count), + data, seg_size, UCC_MEMORY_TYPE_HOST, + UCC_MEMORY_TYPE_HOST); + snd_count++; + } + data = PTR_OFFSET(scratch, level * nrecv_segs * seg_size); + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(mergebuf, snd_count * seg_size, + UCC_MEMORY_TYPE_HOST, sendto, team, + task), + task, out); + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(data, snd_count * seg_size, + UCC_MEMORY_TYPE_HOST, recvfrom, team, + task), + task, out); +ALLTOALL_BRUCK_PHASE_SENDRECV: + if (ucc_tl_ucp_test(task) == UCC_INPROGRESS) { + task->alltoall_bruck.phase = PHASE_SENDRECV; + return; + } + task->alltoall_bruck.iteration++; + step = 1 << (task->alltoall_bruck.iteration - 1); + } + + st = ucc_mc_memcpy(PTR_OFFSET(args->dst.info.buffer, trank * seg_size), + PTR_OFFSET(args->src.info.buffer, trank * seg_size), + seg_size, UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_HOST); + if (ucc_unlikely(st != UCC_OK)) { + task->super.status = st; + return; + } + task->super.status = + ucc_tl_ucp_alltoall_bruck_backward_rotation(args->dst.info.buffer, + scratch, trank, tsize, + seg_size); +out: + return; +} + +ucc_status_t ucc_tl_ucp_alltoall_bruck_start(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + + task->alltoall_bruck.iteration = 1; + task->alltoall_bruck.phase = PHASE_MERGE; + ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); + + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); +} + +ucc_status_t ucc_tl_ucp_alltoall_bruck_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h) +{ + ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(tl_team); + ucc_coll_args_t *args = &coll_args->args; + size_t seg_size = ucc_dt_size(args->src.info.datatype) * + args->src.info.count / tsize; + size_t scratch_size; + ucc_tl_ucp_task_t *task; + ucc_status_t status; + + if ((coll_args->args.src.info.mem_type != UCC_MEMORY_TYPE_HOST) || + (coll_args->args.dst.info.mem_type != UCC_MEMORY_TYPE_HOST)) { + status = UCC_ERR_NOT_SUPPORTED; + goto out; + } + ALLTOALL_TASK_CHECK(coll_args->args, tl_team); + + task = ucc_tl_ucp_init_task(coll_args, team); + task->super.post = ucc_tl_ucp_alltoall_bruck_start; + task->super.progress = ucc_tl_ucp_alltoall_bruck_progress; + task->super.finalize = ucc_tl_ucp_alltoall_bruck_finalize; + + scratch_size = lognum(tsize) * ucc_div_round_up(tsize, 2) * seg_size; + status = ucc_mc_alloc(&task->alltoall_bruck.scratch_mc_header, + scratch_size, UCC_MEMORY_TYPE_HOST); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer"); + ucc_tl_ucp_coll_finalize(&task->super); + return status; + } + + *task_h = &task->super; + return UCC_OK; + +out: + return status; +} diff --git a/src/components/tl/ucp/alltoallv/alltoallv_hybrid.c b/src/components/tl/ucp/alltoallv/alltoallv_hybrid.c index a65f19f13b..41289ba79d 100644 --- a/src/components/tl/ucp/alltoallv/alltoallv_hybrid.c +++ b/src/components/tl/ucp/alltoallv/alltoallv_hybrid.c @@ -12,6 +12,7 @@ #include "utils/ucc_coll_utils.h" #include "tl_ucp_sendrecv.h" #include "components/mc/ucc_mc.h" +#include "coll_patterns/bruck_alltoall.h" /* scratch structure @@ -93,35 +94,6 @@ typedef struct ucc_tl_ucp_alltoallv_hybrid_buf_meta { #define 
SET_BRUCK_DIGIT(_seg, _digit) \ ((_seg) = (((_digit) << ALLTOALLV_HYBRID_SEG_DIGIT) + ((_seg) & UCC_MASK(ALLTOALLV_HYBRID_SEG_DIGIT)))) -#define GET_NEXT_BRUCK_NUM(_num, _radix, _pow) \ - ((((_num) + 1) % (_pow))?((_num) + 1):(((_num) + 1) + (_pow) * ((_radix) - 1))) - -#define GET_PREV_BRUCK_NUM(_num, _radix, _pow) \ - (((_num) % (_pow))?((_num) - 1):(((_num) - 1) - (_pow) * ((_radix) - 1))) - -static inline ucc_rank_t get_bruck_step_start(uint32_t pow, uint32_t d) -{ - return pow * d; -} - -static inline ucc_rank_t get_bruck_step_finish(ucc_rank_t n, uint32_t radix, - uint32_t d, uint32_t pow) -{ - return ucc_min(n + pow - 1 - (n - d * pow) % (pow * radix), n); -} - -static inline ucc_rank_t get_bruck_recv_peer(ucc_rank_t trank, ucc_rank_t tsize, - ucc_rank_t step, uint32_t digit) -{ - return (trank - step * digit + tsize * digit) % tsize; -} - -static inline ucc_rank_t get_bruck_send_peer(ucc_rank_t trank, ucc_rank_t tsize, - ucc_rank_t step, uint32_t digit) -{ - return (trank + step * digit) % tsize; -} - static inline ucc_rank_t get_pairwise_send_peer(ucc_rank_t trank, ucc_rank_t tsize, ucc_rank_t step) { diff --git a/src/components/tl/ucp/tl_ucp_coll.c b/src/components/tl/ucp/tl_ucp_coll.c index 5f3c39611a..3306d65c42 100644 --- a/src/components/tl/ucp/tl_ucp_coll.c +++ b/src/components/tl/ucp/tl_ucp_coll.c @@ -24,14 +24,66 @@ #include "fanout/fanout.h" #include "scatterv/scatterv.h" -const char - *ucc_tl_ucp_default_alg_select_str[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR] = { - UCC_TL_UCP_ALLGATHER_DEFAULT_ALG_SELECT_STR, - UCC_TL_UCP_ALLREDUCE_DEFAULT_ALG_SELECT_STR, - UCC_TL_UCP_BCAST_DEFAULT_ALG_SELECT_STR, - UCC_TL_UCP_ALLTOALL_DEFAULT_ALG_SELECT_STR, - UCC_TL_UCP_REDUCE_SCATTER_DEFAULT_ALG_SELECT_STR, - UCC_TL_UCP_REDUCE_SCATTERV_DEFAULT_ALG_SELECT_STR}; +const ucc_tl_ucp_default_alg_desc_t + ucc_tl_ucp_default_alg_descs[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR] = { + { + .select_str = UCC_TL_UCP_ALLGATHER_DEFAULT_ALG_SELECT_STR, + .str_get_fn = NULL + }, + { + .select_str = NULL, + .str_get_fn = ucc_tl_ucp_alltoall_score_str_get + }, + { + .select_str = UCC_TL_UCP_ALLREDUCE_DEFAULT_ALG_SELECT_STR, + .str_get_fn = NULL + }, + { + .select_str = UCC_TL_UCP_BCAST_DEFAULT_ALG_SELECT_STR, + .str_get_fn = NULL + }, + { + .select_str = UCC_TL_UCP_REDUCE_SCATTER_DEFAULT_ALG_SELECT_STR, + .str_get_fn = NULL + }, + { + .select_str = UCC_TL_UCP_REDUCE_SCATTERV_DEFAULT_ALG_SELECT_STR, + .str_get_fn = NULL + }, +}; + +ucc_status_t ucc_tl_ucp_team_default_score_str_alloc(ucc_tl_ucp_team_t *team, + char *default_select_str[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR]) +{ + ucc_status_t st = UCC_OK; + int i; + + for (i = 0; i < UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR; i++) { + if (ucc_tl_ucp_default_alg_descs[i].select_str) { + default_select_str[i] = strdup(ucc_tl_ucp_default_alg_descs[i].select_str); + } else { + default_select_str[i] = ucc_tl_ucp_default_alg_descs[i].str_get_fn(team); + } + if (!default_select_str[i]) { + st = UCC_ERR_NO_MEMORY; + goto exit; + } + + } + +exit: + return st; +} + +void ucc_tl_ucp_team_default_score_str_free( + char *default_select_str[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR]) +{ + int i; + + for (i = 0; i < UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR; i++) { + ucc_free(default_select_str[i]); + } +} void ucc_tl_ucp_send_completion_cb(void *request, ucs_status_t status, void *user_data) @@ -232,6 +284,9 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str, case UCC_TL_UCP_ALLTOALL_ALG_PAIRWISE: *init = ucc_tl_ucp_alltoall_pairwise_init; break; + case 
UCC_TL_UCP_ALLTOALL_ALG_BRUCK: + *init = ucc_tl_ucp_alltoall_bruck_init; + break; case UCC_TL_UCP_ALLTOALL_ALG_ONESIDED: *init = ucc_tl_ucp_alltoall_onesided_init; break; diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index b0c12376e8..5f444b32ac 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -17,8 +17,12 @@ #define UCC_UUNITS_AUTO_RADIX 4 #define UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR 6 -extern const char - *ucc_tl_ucp_default_alg_select_str[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR]; + +ucc_status_t ucc_tl_ucp_team_default_score_str_alloc(ucc_tl_ucp_team_t *team, + char *default_select_str[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR]); + +void ucc_tl_ucp_team_default_score_str_free( + char *default_select_str[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR]); #define CALC_KN_TREE_DIST(_size, _radix, _dist) \ do { \ @@ -71,6 +75,11 @@ extern const char } \ } while (0) +typedef char* (*ucc_tl_ucp_score_str_get_fn_t)(ucc_tl_ucp_team_t *team); +typedef struct ucc_tl_ucp_default_alg_desc { + char *select_str; + ucc_tl_ucp_score_str_get_fn_t str_get_fn; +} ucc_tl_ucp_default_alg_desc_t; enum ucc_tl_ucp_task_flags { /*indicates whether subset field of tl_ucp_task is set*/ @@ -207,6 +216,11 @@ typedef struct ucc_tl_ucp_task { ucc_rank_t num2send; ucc_rank_t num2recv; } alltoallv_hybrid; + struct { + ucc_mc_buffer_header_t *scratch_mc_header; + ucc_rank_t iteration; + int phase; + } alltoall_bruck; }; } ucc_tl_ucp_task_t; @@ -346,7 +360,7 @@ static inline ucc_status_t ucc_tl_ucp_test(ucc_tl_ucp_task_t *task) return UCC_INPROGRESS; } -#define UCC_TL_UCP_TASK_RECV_COMPLETE(_task) \ +#define UCC_TL_UCP_TASK_RECV_COMPLETE(_task) \ (((_task)->tagged.recv_posted == (_task)->tagged.recv_completed)) static inline ucc_status_t ucc_tl_ucp_test_recv(ucc_tl_ucp_task_t *task) diff --git a/src/components/tl/ucp/tl_ucp_team.c b/src/components/tl/ucp/tl_ucp_team.c index 3640c99d01..97e9ad4da3 100644 --- a/src/components/tl/ucp/tl_ucp_team.c +++ b/src/components/tl/ucp/tl_ucp_team.c @@ -205,6 +205,8 @@ ucc_status_t ucc_tl_ucp_team_get_scores(ucc_base_team_t *tl_team, ucc_tl_coll_plugin_iface_t *tlcp; ucc_status_t status; unsigned i; + char *ucc_tl_ucp_default_alg_select_str + [UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR]; for (i = 0; i < UCC_MEMORY_TYPE_LAST; i++) { if (tl_ctx->ucp_memory_types & UCC_BIT(ucc_memtype_to_ucs[i])) { @@ -223,7 +225,11 @@ ucc_status_t ucc_tl_ucp_team_get_scores(ucc_base_team_t *tl_team, if (UCC_OK != status) { return status; } - + status = ucc_tl_ucp_team_default_score_str_alloc(team, + ucc_tl_ucp_default_alg_select_str); + if (UCC_OK != status) { + return status; + } for (i = 0; i < UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR; i++) { status = ucc_coll_score_update_from_str( ucc_tl_ucp_default_alg_select_str[i], score, UCC_TL_TEAM_SIZE(team), @@ -273,9 +279,11 @@ ucc_status_t ucc_tl_ucp_team_get_scores(ucc_base_team_t *tl_team, goto err; } } + ucc_tl_ucp_team_default_score_str_free(ucc_tl_ucp_default_alg_select_str); *score_p = score; return UCC_OK; err: + ucc_tl_ucp_team_default_score_str_free(ucc_tl_ucp_default_alg_select_str); ucc_coll_score_free(score); return status; } diff --git a/src/utils/ucc_math.h b/src/utils/ucc_math.h index ed0ff400b8..8720c9ae6b 100644 --- a/src/utils/ucc_math.h +++ b/src/utils/ucc_math.h @@ -105,4 +105,17 @@ static inline void float32tobfloat16(float float_val, void *bfloat16_ptr) #define ucc_align_up_pow2(_n, _alignment) \ ucc_align_down_pow2((_n) + (_alignment) - 1, _alignment) +/* compute the log2 of n, 
rounded up */ +static inline int lognum(int n) +{ + int count = 1; + int lognum = 0; + + while (count < n) { + count = count << 1; + lognum++; + } + return lognum; +} + #endif
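
For reference, the communication pattern this patch introduces can be traced with a small standalone sketch. It copies the peer/index formulas from coll_patterns/bruck_alltoall.h into plain ints (no UCC headers needed), and prints, for an assumed 8-rank team with radix 2 as seen from rank 0, which peers are contacted at each Bruck step and which source blocks get packed for that step. The team size and rank are arbitrary illustration values, not anything taken from the patch.

/* Standalone sketch, not part of the patch: reimplements the bruck_alltoall.h
 * helpers with plain ints to show the step/peer/block enumeration. */
#include <stdio.h>

#define GET_NEXT_BRUCK_NUM(_num, _radix, _pow) \
    ((((_num) + 1) % (_pow)) ? ((_num) + 1) : (((_num) + 1) + (_pow) * ((_radix) - 1)))

static int bruck_step_start(int pow, int d)  { return pow * d; }

static int bruck_step_finish(int n, int radix, int d, int pow)
{
    int v = n + pow - 1 - (n - d * pow) % (pow * radix);
    return v < n ? v : n;
}

static int bruck_send_peer(int trank, int tsize, int step, int digit)
{
    return (trank + step * digit) % tsize;
}

static int bruck_recv_peer(int trank, int tsize, int step, int digit)
{
    return (trank - step * digit + tsize * digit) % tsize;
}

int main(void)
{
    const int tsize = 8, radix = 2, trank = 0; /* assumed example team */
    int step, idx;

    for (step = 1; step < tsize; step *= radix) {
        printf("step %d: send to %d, recv from %d, blocks:", step,
               bruck_send_peer(trank, tsize, step, 1),
               bruck_recv_peer(trank, tsize, step, 1));
        /* same index walk as the merge loop in alltoall_bruck_progress */
        for (idx = bruck_step_start(step, 1);
             idx <= bruck_step_finish(tsize - 1, radix, 1, step);
             idx = GET_NEXT_BRUCK_NUM(idx, radix, step)) {
            printf(" %d", idx);
        }
        printf("\n");
    }
    return 0;
}

For an 8-rank team this prints blocks {1,3,5,7}, {2,3,6,7}, {4,5,6,7} for steps 1, 2, 4: at step k each rank forwards the blocks whose index has bit k set, which is the classic radix-2 Bruck schedule. Assuming the usual UCC score-string syntax (the same format as the "alltoall:host:0-%d:@bruck" pattern added in alltoall.h), the algorithm can presumably also be forced at runtime with UCC_TL_UCP_TUNE=alltoall:@bruck; by default the generated score string only prefers bruck for host buffers up to ALLTOALL_DEFAULT_ALG_SWITCH (129) bytes times the team size.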