From 28455f6ac8804ba66ea034812dbb60619d571b1c Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Fri, 8 Dec 2023 13:22:41 +0100 Subject: [PATCH] TL/UCP: use ring allgather when reordering needed (#879) --- src/components/tl/sharp/tl_sharp_context.c | 2 +- src/components/tl/ucp/allgather/allgather.c | 7 +++++++ src/components/tl/ucp/allgather/allgather_ring.c | 2 +- src/ucc/api/ucc.h | 2 +- src/utils/ucc_coll_utils.c | 12 ++++++++++++ src/utils/ucc_coll_utils.h | 2 ++ 6 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/components/tl/sharp/tl_sharp_context.c b/src/components/tl/sharp/tl_sharp_context.c index a0fbe61b79..72461066b3 100644 --- a/src/components/tl/sharp/tl_sharp_context.c +++ b/src/components/tl/sharp/tl_sharp_context.c @@ -434,7 +434,7 @@ ucc_status_t ucc_tl_sharp_context_create_epilog(ucc_base_context_t *context) if (lib->cfg.use_internal_oob) { sharp_ctx->oob_ctx.subset = set; } else { - sharp_ctx->oob_ctx.oob = &UCC_TL_CTX_OOB(sharp_ctx); + sharp_ctx->oob_ctx.oob = &UCC_TL_CTX_OOB(sharp_ctx); } status = ucc_topo_init(set, core_ctx->topo, &topo); diff --git a/src/components/tl/ucp/allgather/allgather.c b/src/components/tl/ucp/allgather/allgather.c index 90b06e99ee..926b732e55 100644 --- a/src/components/tl/ucp/allgather/allgather.c +++ b/src/components/tl/ucp/allgather/allgather.c @@ -38,7 +38,14 @@ char *ucc_tl_ucp_allgather_score_str_get(ucc_tl_ucp_team_t *team) ? 
UCC_TL_UCP_ALLGATHER_ALG_RING : UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR; char *str = ucc_malloc(max_size * sizeof(char)); + ucc_sbgp_t *sbgp; + if (team->cfg.use_reordering) { + sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED); + if (!ucc_ep_map_is_identity(&sbgp->map)) { + algo_num = UCC_TL_UCP_ALLGATHER_ALG_RING; + } + } ucc_snprintf_safe(str, max_size, UCC_TL_UCP_ALLGATHER_DEFAULT_ALG_SELECT_STR, algo_num); return str; diff --git a/src/components/tl/ucp/allgather/allgather_ring.c b/src/components/tl/ucp/allgather/allgather_ring.c index 93d7b95fc4..07178aea25 100644 --- a/src/components/tl/ucp/allgather/allgather_ring.c +++ b/src/components/tl/ucp/allgather/allgather_ring.c @@ -108,7 +108,7 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task) ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task) { - ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); ucc_sbgp_t *sbgp; if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) { diff --git a/src/ucc/api/ucc.h b/src/ucc/api/ucc.h index 02e5e11540..a269dfb940 100644 --- a/src/ucc/api/ucc.h +++ b/src/ucc/api/ucc.h @@ -1337,7 +1337,7 @@ struct ucc_ep_map_cb { * @ingroup UCC_TEAM_DT */ typedef enum { - UCC_EP_MAP_FULL = 1, /*!< The ep range of the team spans all eps from a context. */ + UCC_EP_MAP_FULL = 1, /*!< The ep range of the team spans all eps from a context. */ UCC_EP_MAP_STRIDED = 2, /*!< The ep range of the team can be described by the 2 values: start, stride.*/ UCC_EP_MAP_ARRAY = 3, /*!< The ep range is given as an array of intergers that map the ep in the team to the team_context rank. 
*/ diff --git a/src/utils/ucc_coll_utils.c b/src/utils/ucc_coll_utils.c index 7d449fadce..75a49400e2 100644 --- a/src/utils/ucc_coll_utils.c +++ b/src/utils/ucc_coll_utils.c @@ -644,6 +644,18 @@ ucc_ep_map_t ucc_ep_map_create_reverse(ucc_rank_t size) return map; } +int ucc_ep_map_is_identity(const ucc_ep_map_t *map) +{ + if ((map->type == UCC_EP_MAP_FULL) || + ((map->type == UCC_EP_MAP_STRIDED) && + (map->strided.start == 0) && + (map->strided.stride == 1))) { + return 1; + } else { + return 0; + } +} + static inline int ucc_ep_map_is_reverse(ucc_ep_map_t *map, int reversed_reordered_flag) { diff --git a/src/utils/ucc_coll_utils.h b/src/utils/ucc_coll_utils.h index ead7fe4081..ad7939837e 100644 --- a/src/utils/ucc_coll_utils.h +++ b/src/utils/ucc_coll_utils.h @@ -248,6 +248,8 @@ ucc_status_t ucc_ep_map_create_nested(ucc_ep_map_t *base_map, ucc_ep_map_t *sub_map, ucc_ep_map_t *out); +int ucc_ep_map_is_identity(const ucc_ep_map_t *map); /* 1 if identity, else 0 */ + void ucc_ep_map_destroy_nested(ucc_ep_map_t *out); void ucc_ep_map_destroy(ucc_ep_map_t *map);