From 6862aee4c365fcc2acf6ed101e3b0965c2775f2e Mon Sep 17 00:00:00 2001 From: Nick Sarkauskas Date: Fri, 26 Jan 2024 13:51:48 -0800 Subject: [PATCH] Urom cl support for sliding window allreduce, alltoall, and alltoallv --- src/components/cl/urom/allreduce/allreduce.c | 4 +- src/components/cl/urom/alltoall/alltoall.c | 69 +++ src/components/cl/urom/alltoall/alltoall.h | 4 +- src/components/cl/urom/alltoallv/alltoallv.c | 449 +++++++++---------- src/components/cl/urom/alltoallv/alltoallv.h | 22 +- src/components/cl/urom/cl_urom_coll.c | 6 + 6 files changed, 298 insertions(+), 256 deletions(-) diff --git a/src/components/cl/urom/allreduce/allreduce.c b/src/components/cl/urom/allreduce/allreduce.c index 6b7a2badc9..80546d4abd 100644 --- a/src/components/cl/urom/allreduce/allreduce.c +++ b/src/components/cl/urom/allreduce/allreduce.c @@ -5,7 +5,6 @@ */ #include "allreduce.h" -#include "../allreduce/allreduce.h" ucc_base_coll_alg_info_t ucc_cl_urom_allreduce_algs[UCC_CL_UROM_ALLREDUCE_ALG_LAST + 1] = { @@ -107,8 +106,7 @@ static ucc_status_t ucc_cl_urom_allreduce_full_start(ucc_coll_task_t *task) .ucc.cmd_type = UROM_WORKER_CMD_UCC_COLL, .ucc.coll_cmd.coll_args = coll_args, .ucc.coll_cmd.team = cl_team->teams[0], - //.ucc.coll_cmd.use_xgvmi = 0, - //.ucc.coll_cmd.use_sliding_window_allreduce = 1, + .ucc.coll_cmd.use_xgvmi = 1, }; ucc_memory_type_t prev_src, prev_dst; ucc_cl_urom_schedule_t *schedule = diff --git a/src/components/cl/urom/alltoall/alltoall.c b/src/components/cl/urom/alltoall/alltoall.c index 33d7d16755..45e30107f3 100644 --- a/src/components/cl/urom/alltoall/alltoall.c +++ b/src/components/cl/urom/alltoall/alltoall.c @@ -15,6 +15,45 @@ ucc_base_coll_alg_info_t [UCC_CL_UROM_ALLTOALL_ALG_LAST] = { .id = 0, .name = NULL, .desc = NULL}}; +static int buffer_export_ucc(ucp_context_h ucp_context, void *buf, size_t len, + struct export_buf *ebuf) +{ + ucs_status_t ucs_status; + ucp_mem_map_params_t params; + ucp_memh_pack_params_t pack_params; + + ebuf->ucp_context = ucp_context; + + params.field_mask = + UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH; + params.address = buf; + params.length = len; + + ucs_status = ucp_mem_map(ucp_context, ¶ms, &ebuf->memh); + assert(ucs_status == UCS_OK); + + pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS; + pack_params.flags = UCP_MEMH_PACK_FLAG_EXPORT; + + ucs_status = ucp_memh_pack(ebuf->memh, &pack_params, &ebuf->packed_memh, + &ebuf->packed_memh_len); + if (ucs_status != UCS_OK) { + printf("ucp_memh_pack() returned error: %s\n", + ucs_status_string(ucs_status)); + ebuf->packed_memh = NULL; + ebuf->packed_memh_len = 0; + } + ucs_status = ucp_rkey_pack(ucp_context, ebuf->memh, &ebuf->packed_key, + &ebuf->packed_key_len); + if (UCS_OK != ucs_status) { + printf("ucp_rkey_pack() returned error: %s\n", + ucs_status_string(ucs_status)); + return UROM_ERR_NO_RESOURCE; + } + + return 0; +} + ucc_status_t ucc_cl_urom_alltoall_triggered_post_setup(ucc_coll_task_t *task) { return UCC_OK; @@ -58,6 +97,8 @@ static ucc_status_t ucc_cl_urom_alltoall_full_start(ucc_coll_task_t *task) ucc_cl_urom_lib_t *cl_lib = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t); ucc_coll_args_t *coll_args = &task->bargs.args; urom_status_t urom_status; + int ucp_index = cl_lib->tl_ucp_index; + ucc_tl_ucp_context_t *tl_ctx = ucc_derived_of(ctx->super.tl_ctxs[ucp_index], ucc_tl_ucp_context_t); urom_worker_cmd_t coll_cmd = { .cmd_type = UROM_WORKER_CMD_UCC, .ucc.dpu_worker_id = UCC_CL_TEAM_RANK(cl_team), @@ -65,13 +106,29 @@ static ucc_status_t ucc_cl_urom_alltoall_full_start(ucc_coll_task_t *task) .ucc.coll_cmd.coll_args = coll_args, .ucc.coll_cmd.team = cl_team->teams[0], //.ucc.coll_cmd.use_xgvmi = ctx->xgvmi_enabled, + .ucc.coll_cmd.use_xgvmi = 1, }; ucc_memory_type_t prev_src, prev_dst; + ucc_cl_urom_schedule_t *schedule = + ucc_derived_of(task, ucc_cl_urom_schedule_t); + struct export_buf *src_ebuf = &schedule->src_ebuf; + struct export_buf *dst_ebuf = &schedule->dst_ebuf; + prev_src = coll_args->src.info.mem_type; prev_dst = coll_args->dst.info.mem_type; coll_args->src.info.mem_type = UCC_MEMORY_TYPE_HOST; coll_args->dst.info.mem_type = UCC_MEMORY_TYPE_HOST; + + buffer_export_ucc(tl_ctx->worker.ucp_context, coll_args->src.info.buffer, coll_args->src.info.count * dt_size(coll_args->src.info.datatype), src_ebuf); + buffer_export_ucc(tl_ctx->worker.ucp_context, coll_args->dst.info.buffer, coll_args->dst.info.count * dt_size(coll_args->dst.info.datatype), dst_ebuf); + + coll_cmd.ucc.coll_cmd.src_memh_packed = src_ebuf->packed_memh; + coll_cmd.ucc.coll_cmd.src_memh_packed_len = src_ebuf->packed_memh_len; + + coll_cmd.ucc.coll_cmd.dst_memh_packed = dst_ebuf->packed_memh; + coll_cmd.ucc.coll_cmd.dst_memh_packed_len = dst_ebuf->packed_memh_len; + urom_status = urom_worker_push_cmdq(cl_lib->urom_ctx.urom_worker, 0, &coll_cmd); if (UROM_OK != urom_status) { cl_debug(&cl_lib->super, "failed to push collective to urom"); @@ -110,6 +167,18 @@ static ucc_status_t ucc_cl_urom_alltoall_full_finalize(ucc_coll_task_t *task) ucc_derived_of(task, ucc_cl_urom_schedule_t); ucc_status_t status; + ucc_cl_urom_team_t *cl_team = ucc_derived_of(task->team, ucc_cl_urom_team_t); + ucc_cl_urom_context_t *ctx = UCC_CL_UROM_TEAM_CTX(cl_team); + ucc_cl_urom_lib_t *cl_lib = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t); + int ucp_index = cl_lib->tl_ucp_index; + ucc_tl_ucp_context_t *tl_ctx = ucc_derived_of(ctx->super.tl_ctxs[ucp_index], ucc_tl_ucp_context_t); + + struct export_buf *src_ebuf = &schedule->src_ebuf; + struct export_buf *dst_ebuf = &schedule->dst_ebuf; + + ucp_mem_unmap(tl_ctx->worker.ucp_context, src_ebuf->memh); + ucp_mem_unmap(tl_ctx->worker.ucp_context, dst_ebuf->memh); + status = ucc_schedule_finalize(task); ucc_cl_urom_put_schedule(&schedule->super.super); return status; diff --git a/src/components/cl/urom/alltoall/alltoall.h b/src/components/cl/urom/alltoall/alltoall.h index 19743f7ed9..aade7cca54 100644 --- a/src/components/cl/urom/alltoall/alltoall.h +++ b/src/components/cl/urom/alltoall/alltoall.h @@ -7,6 +7,7 @@ #ifndef UROM_ALLTOALL_H_ #define UROM_ALLTOALL_H_ #include "../cl_urom_coll.h" +#include "../../../tl/ucp/tl_ucp.h" enum { @@ -17,9 +18,6 @@ enum extern ucc_base_coll_alg_info_t ucc_cl_urom_alltoall_algs[UCC_CL_UROM_ALLTOALL_ALG_LAST + 1]; -ucc_status_t ucc_cl_urom_alltoall_init(ucc_base_coll_args_t *coll_args, - ucc_base_team_t * team, - ucc_coll_task_t ** task); static inline int ucc_cl_urom_alltoall_alg_from_str(const char *str) { diff --git a/src/components/cl/urom/alltoallv/alltoallv.c b/src/components/cl/urom/alltoallv/alltoallv.c index 5715faa47c..093fc00aee 100644 --- a/src/components/cl/urom/alltoallv/alltoallv.c +++ b/src/components/cl/urom/alltoallv/alltoallv.c @@ -1,282 +1,255 @@ /** * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * Copyright (c) Meta Platforms, Inc. and affiliates. 2022. * * See file LICENSE for terms. */ #include "alltoallv.h" -#include "../cl_hier_coll.h" -#include "core/ucc_team.h" ucc_base_coll_alg_info_t - ucc_cl_hier_alltoallv_algs[UCC_CL_HIER_ALLTOALLV_ALG_LAST + 1] = { - [UCC_CL_HIER_ALLTOALLV_ALG_NODE_SPLIT] = - {.id = UCC_CL_HIER_ALLTOALLV_ALG_NODE_SPLIT, - .name = "node_split", - .desc = "splitting alltoallv into two concurrent a2av calls" - " withing the node and outside of it"}, - [UCC_CL_HIER_ALLTOALLV_ALG_LAST] = { + ucc_cl_urom_alltoallv_algs[UCC_CL_UROM_ALLTOALLV_ALG_LAST + 1] = { + [UCC_CL_UROM_ALLTOALLV_ALG_FULL] = + {.id = UCC_CL_UROM_ALLTOALLV_ALG_FULL, + .name = "urom_full_offload", + .desc = "full offload of alltoallv"}, + [UCC_CL_UROM_ALLTOALLV_ALG_LAST] = { .id = 0, .name = NULL, .desc = NULL}}; -static ucc_status_t ucc_cl_hier_alltoallv_start(ucc_coll_task_t *task) +static int buffer_export_ucc(ucp_context_h ucp_context, void *buf, size_t len, + struct export_buf *ebuf) { - UCC_CL_HIER_PROFILE_REQUEST_EVENT(task, "cl_hier_alltoallv_start", 0); - return ucc_schedule_start(task); + ucs_status_t ucs_status; + ucp_mem_map_params_t params; + ucp_memh_pack_params_t pack_params; + + ebuf->ucp_context = ucp_context; + + params.field_mask = + UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH; + params.address = buf; + params.length = len; + + ucs_status = ucp_mem_map(ucp_context, ¶ms, &ebuf->memh); + assert(ucs_status == UCS_OK); + + pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS; + pack_params.flags = UCP_MEMH_PACK_FLAG_EXPORT; + + ucs_status = ucp_memh_pack(ebuf->memh, &pack_params, &ebuf->packed_memh, + &ebuf->packed_memh_len); + if (ucs_status != UCS_OK) { + printf("ucp_memh_pack() returned error: %s\n", + ucs_status_string(ucs_status)); + ebuf->packed_memh = NULL; + ebuf->packed_memh_len = 0; + } + ucs_status = ucp_rkey_pack(ucp_context, ebuf->memh, &ebuf->packed_key, + &ebuf->packed_key_len); + if (UCS_OK != ucs_status) { + printf("ucp_rkey_pack() returned error: %s\n", + ucs_status_string(ucs_status)); + return UROM_ERR_NO_RESOURCE; + } + + return 0; } -static ucc_status_t ucc_cl_hier_alltoallv_finalize(ucc_coll_task_t *task) +ucc_status_t ucc_cl_urom_alltoallv_triggered_post_setup(ucc_coll_task_t *task) { - ucc_cl_hier_schedule_t *schedule = - ucc_derived_of(task, ucc_cl_hier_schedule_t); - ucc_status_t status; + return UCC_OK; +} - UCC_CL_HIER_PROFILE_REQUEST_EVENT(task, "cl_hier_alltoallv_finalize", 0); - ucc_assert(schedule->super.super.n_tasks == 1 || - schedule->super.super.n_tasks == 2); - if (schedule->scratch) { - ucc_mc_free(schedule->scratch); +static size_t dt_size(ucc_datatype_t ucc_dt) +{ + size_t size_mod = 8; + + switch(ucc_dt) { + case UCC_DT_INT8: + case UCC_DT_UINT8: + size_mod = sizeof(char); + break; + case UCC_DT_INT32: + case UCC_DT_UINT32: + case UCC_DT_FLOAT32: + size_mod = sizeof(int); + break; + case UCC_DT_INT64: + case UCC_DT_UINT64: + case UCC_DT_FLOAT64: + size_mod = sizeof(uint64_t); + break; + case UCC_DT_INT128: + case UCC_DT_UINT128: + case UCC_DT_FLOAT128: + size_mod = sizeof(__int128_t); + break; + default: + break; } - status = ucc_schedule_finalize(task); - ucc_cl_hier_put_schedule(&schedule->super.super); - return status; + + return size_mod; } -#define SET_FULL_COUNTS(_type, _sbgp, _coll_args, _team, _node_thresh, \ - _sdt_size, _rdt_size, _sc_full, _sd_full, _rc_full, \ - _rd_full) \ - do { \ - int _i, _is_local; \ - _type _scount, _rcount; \ - \ - for (_i = 0; _i < (_sbgp)->group_size; _i++) { \ - _scount = ((_type *)(_coll_args)->args.src.info_v.counts)[_i]; \ - _rcount = ((_type *)(_coll_args)->args.dst.info_v.counts)[_i]; \ - _is_local = \ - ucc_rank_on_local_node(_i, (_team)->params.team->topo); \ - if ((_scount * _sdt_size > (_node_thresh)) && _is_local) { \ - ((_type *)_sc_full)[_i] = 0; \ - } else { \ - ((_type *)_sc_full)[_i] = _scount; \ - ((_type *)_sd_full)[_i] = \ - ((_type *)(_coll_args) \ - ->args.src.info_v.displacements)[_i]; \ - } \ - if ((_rcount * _rdt_size > (_node_thresh)) && _is_local) { \ - ((_type *)_rc_full)[_i] = 0; \ - } else { \ - ((_type *)_rc_full)[_i] = _rcount; \ - ((_type *)_rd_full)[_i] = \ - ((_type *)(_coll_args) \ - ->args.dst.info_v.displacements)[_i]; \ - } \ - } \ - } while (0) - -#define SET_NODE_COUNTS(_type, _sbgp, _coll_args, _node_thresh, _sdt_size, \ - _rdt_size, _sc_node, _sd_node, _rc_node, _rd_node) \ - do { \ - int _i; \ - _type _scount, _rcount; \ - \ - for (_i = 0; _i < (_sbgp)->group_size; _i++) { \ - ucc_rank_t r = ucc_ep_map_eval((_sbgp)->map, _i); \ - _scount = ((_type *)(_coll_args)->args.src.info_v.counts)[r]; \ - _rcount = ((_type *)(_coll_args)->args.dst.info_v.counts)[r]; \ - if (_scount * _sdt_size <= (_node_thresh)) { \ - ((_type *)_sc_node)[_i] = 0; \ - } else { \ - ((_type *)_sc_node)[_i] = _scount; \ - ((_type *)_sd_node)[_i] = \ - ((_type *)(_coll_args)->args.src.info_v.displacements)[r]; \ - } \ - if (_rcount * _rdt_size <= (_node_thresh)) { \ - ((_type *)_rc_node)[_i] = 0; \ - } else { \ - ((_type *)_rc_node)[_i] = _rcount; \ - ((_type *)_rd_node)[_i] = \ - ((_type *)(_coll_args)->args.dst.info_v.displacements)[r]; \ - } \ - } \ - } while (0) - -ucc_status_t ucc_cl_hier_alltoallv_triggered_post_setup(ucc_coll_task_t *task) +static ucc_status_t ucc_cl_urom_alltoallv_full_start(ucc_coll_task_t *task) { - ucc_cl_hier_schedule_t *schedule = - ucc_derived_of(task, ucc_cl_hier_schedule_t); - ucc_status_t status = UCC_OK; - int n_tasks = schedule->super.super.n_tasks; - int i = 0; - - for (i = 0; i < n_tasks; ++i) { - ucc_coll_task_t *sub_task = schedule->super.super.tasks[i]; - if (sub_task->triggered_post_setup != NULL) { - sub_task->ee = task->ee; - sub_task->triggered_post_setup(sub_task); + ucc_cl_urom_team_t *cl_team = ucc_derived_of(task->team, ucc_cl_urom_team_t); + ucc_cl_urom_context_t *ctx = UCC_CL_UROM_TEAM_CTX(cl_team); + ucc_cl_urom_lib_t *cl_lib = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t); + ucc_coll_args_t *coll_args = &task->bargs.args; + urom_status_t urom_status; + int ucp_index = cl_lib->tl_ucp_index; + ucc_tl_ucp_context_t *tl_ctx = ucc_derived_of(ctx->super.tl_ctxs[ucp_index], ucc_tl_ucp_context_t); + urom_worker_cmd_t coll_cmd = { + .cmd_type = UROM_WORKER_CMD_UCC, + .ucc.dpu_worker_id = UCC_CL_TEAM_RANK(cl_team), + .ucc.cmd_type = UROM_WORKER_CMD_UCC_COLL, + .ucc.coll_cmd.coll_args = coll_args, + .ucc.coll_cmd.team = cl_team->teams[0], + //.ucc.coll_cmd.use_xgvmi = ctx->xgvmi_enabled, + .ucc.coll_cmd.use_xgvmi = 1, + }; + ucc_memory_type_t prev_src, prev_dst; + int i; + + ucc_cl_urom_schedule_t *schedule = + ucc_derived_of(task, ucc_cl_urom_schedule_t); + struct export_buf *src_ebuf = &schedule->src_ebuf; + struct export_buf *dst_ebuf = &schedule->dst_ebuf; + + size_t src_count = 0; + size_t dst_count = 0; + size_t size_mod_src = dt_size(coll_args->src.info_v.datatype); + size_t size_mod_dst = dt_size(coll_args->dst.info_v.datatype); + + // get the total count of the src and dst bufs + for (i = 0; i < UCC_CL_TEAM_SIZE(cl_team); i++) { + if ((coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS) && + (coll_args->flags & UCC_COLL_ARGS_FLAG_COUNT_64BIT)) { + uint64_t count = coll_args->src.info_v.counts[i] * size_mod_src; + src_count += count; + count = coll_args->dst.info_v.counts[i] * size_mod_dst; + dst_count += count; + } else { + uint32_t count = coll_args->src.info_v.counts[i] * size_mod_src; + src_count += count; + count = coll_args->dst.info_v.counts[i] * size_mod_dst; + dst_count += count; } } + + prev_src = coll_args->src.info_v.mem_type; + prev_dst = coll_args->dst.info_v.mem_type; + coll_args->src.info_v.mem_type = UCC_MEMORY_TYPE_HOST; + coll_args->dst.info_v.mem_type = UCC_MEMORY_TYPE_HOST; + + buffer_export_ucc(tl_ctx->worker.ucp_context, coll_args->src.info_v.buffer, src_count * dt_size(coll_args->src.info_v.datatype), src_ebuf); + buffer_export_ucc(tl_ctx->worker.ucp_context, coll_args->dst.info_v.buffer, dst_count * dt_size(coll_args->dst.info_v.datatype), dst_ebuf); + + coll_cmd.ucc.coll_cmd.src_memh_packed = src_ebuf->packed_memh; + coll_cmd.ucc.coll_cmd.src_memh_packed_len = src_ebuf->packed_memh_len; + + coll_cmd.ucc.coll_cmd.dst_memh_packed = dst_ebuf->packed_memh; + coll_cmd.ucc.coll_cmd.dst_memh_packed_len = dst_ebuf->packed_memh_len; + + urom_status = urom_worker_push_cmdq(cl_lib->urom_ctx.urom_worker, 0, &coll_cmd); + if (UROM_OK != urom_status) { + cl_debug(&cl_lib->super, "failed to push collective to urom"); + return UCC_ERR_NO_MESSAGE; + } + coll_args->src.info_v.mem_type = prev_src; + coll_args->dst.info_v.mem_type = prev_dst; + + task->status = UCC_INPROGRESS; + cl_debug(&cl_lib->super, "pushed the collective to urom"); + return ucc_progress_queue_enqueue(ctx->super.super.ucc_context->pq, task); +} + +static ucc_status_t ucc_cl_urom_alltoallv_full_finalize(ucc_coll_task_t *task) +{ + ucc_cl_urom_schedule_t *schedule = + ucc_derived_of(task, ucc_cl_urom_schedule_t); + ucc_status_t status; + + ucc_cl_urom_team_t *cl_team = ucc_derived_of(task->team, ucc_cl_urom_team_t); + ucc_cl_urom_context_t *ctx = UCC_CL_UROM_TEAM_CTX(cl_team); + ucc_cl_urom_lib_t *cl_lib = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t); + int ucp_index = cl_lib->tl_ucp_index; + ucc_tl_ucp_context_t *tl_ctx = ucc_derived_of(ctx->super.tl_ctxs[ucp_index], ucc_tl_ucp_context_t); + + struct export_buf *src_ebuf = &schedule->src_ebuf; + struct export_buf *dst_ebuf = &schedule->dst_ebuf; + + ucp_mem_unmap(tl_ctx->worker.ucp_context, src_ebuf->memh); + ucp_mem_unmap(tl_ctx->worker.ucp_context, dst_ebuf->memh); + + status = ucc_schedule_finalize(task); + ucc_cl_urom_put_schedule(&schedule->super.super); return status; } -UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_alltoallv_init, - (coll_args, team, task), - ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, - ucc_coll_task_t **task) +static void ucc_cl_urom_alltoallv_full_progress(ucc_coll_task_t *ctask) { - ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t); - ucc_cl_hier_lib_t *cl_lib = UCC_CL_HIER_TEAM_LIB(cl_team); - int full_only = 0; - ucc_cl_hier_schedule_t *cl_schedule; - ucc_schedule_t *schedule; - ucc_status_t status; - ucc_base_coll_args_t args; - ucc_coll_task_t *task_node, *task_full; - int c64, d64; - void *sc_full, *sd_full, *rc_full, *rd_full; - void *sc_node, *sd_node, *rc_node, *rd_node; - ucc_rank_t full_size, node_size; - size_t sdt_size, rdt_size; - ucc_sbgp_t *sbgp; - size_t elem_size; - - if (UCC_IS_INPLACE(coll_args->args)) { - cl_debug(team->context->lib, "inplace alltoallv is not supported"); - return UCC_ERR_NOT_SUPPORTED; + ucc_cl_urom_team_t *cl_team = ucc_derived_of(ctask->team, ucc_cl_urom_team_t); + ucc_cl_urom_context_t *ctx = UCC_CL_UROM_TEAM_CTX(cl_team); + ucc_cl_urom_lib_t *cl_lib = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t); + urom_status_t urom_status = 0; + urom_worker_notify_t *notif; + + urom_status = urom_worker_pop_notifyq(cl_lib->urom_ctx.urom_worker, 0, ¬if); + if (UROM_ERR_QUEUE_EMPTY == urom_status) { + return; } - if (!SBGP_ENABLED(cl_team, FULL)) { - cl_debug(team->context->lib, "alltoallv requires FULL sbgp"); - return UCC_ERR_NOT_SUPPORTED; + if (urom_status < 0) { + cl_error(cl_lib, "Error in UROM"); + ctask->status = UCC_ERR_NO_MESSAGE; + return; } - c64 = UCC_COLL_ARGS_COUNT64(&coll_args->args); - d64 = UCC_COLL_ARGS_DISPL64(&coll_args->args); - - if (c64 ^ d64) { - cl_debug(team->context->lib, - "mixed 64 bit count/displ mode is not supported"); - return UCC_ERR_NOT_SUPPORTED; + if (notif->notify_type != UROM_WORKER_NOTIFY_UCC) { + cl_debug(cl_lib, "WRONG NOTIFICATION (%ld != %d)", notif->notify_type, UROM_WORKER_NOTIFY_UCC); + return; } - cl_schedule = ucc_cl_hier_get_schedule(cl_team); + cl_debug(&cl_lib->super, "completed the collective from urom"); + + ctask->status = (ucc_status_t) notif->ucc.status; +} + +ucc_status_t ucc_cl_urom_alltoallv_full_init( + ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, + ucc_coll_task_t **task) +{ + ucc_cl_urom_team_t *cl_team = ucc_derived_of(team, ucc_cl_urom_team_t); + ucc_cl_urom_context_t *ctx = UCC_CL_UROM_TEAM_CTX(cl_team); + ucc_cl_urom_lib_t *cl_lib = ucc_derived_of(ctx->super.super.lib, ucc_cl_urom_lib_t); + + ucc_cl_urom_schedule_t *cl_schedule; + ucc_base_coll_args_t args; + ucc_schedule_t *schedule; + ucc_status_t status; + + cl_schedule = ucc_cl_urom_get_schedule(cl_team); if (ucc_unlikely(!cl_schedule)) { return UCC_ERR_NO_MEMORY; } schedule = &cl_schedule->super.super; - memcpy(&args, coll_args, sizeof(args)); - UCC_CHECK_GOTO(ucc_schedule_init(schedule, &args, team), error, status); - - full_size = cl_team->sbgps[UCC_HIER_SBGP_FULL].sbgp->group_size; - node_size = cl_team->sbgps[UCC_HIER_SBGP_NODE].sbgp->group_size; - elem_size = c64 ? 8 : 4; - - if (!SBGP_ENABLED(cl_team, NODE)) { - full_only = 1; - UCC_CHECK_GOTO(ucc_coll_init( - cl_team->sbgps[UCC_HIER_SBGP_FULL].score_map, - &args, &task_full), - error, status); - goto full; - } - status = ucc_mc_alloc(&cl_schedule->scratch, - elem_size * (full_size + node_size) * 4, - UCC_MEMORY_TYPE_HOST); - if (ucc_unlikely(UCC_OK != status)) { - cl_error(team->context->lib, - "failed to allocate %zd bytes for full counts", - elem_size * (full_size + node_size) * 4); - goto error; - } - - sc_full = cl_schedule->scratch->addr; - sd_full = PTR_OFFSET(sc_full, full_size * elem_size); - rc_full = PTR_OFFSET(sc_full, full_size * elem_size * 2); - rd_full = PTR_OFFSET(sc_full, full_size * elem_size * 3); - - sc_node = PTR_OFFSET(sc_full, full_size * elem_size * 4); - sd_node = PTR_OFFSET(sc_node, node_size * elem_size); - rc_node = PTR_OFFSET(sc_node, node_size * elem_size * 2); - rd_node = PTR_OFFSET(sc_node, node_size * elem_size * 3); - - /* Duplicate FULL a2av info and alloc task */ - sbgp = cl_team->sbgps[UCC_HIER_SBGP_FULL].sbgp; - ucc_assert(sbgp->group_size == team->params.size); - sdt_size = ucc_dt_size(coll_args->args.src.info_v.datatype); - rdt_size = ucc_dt_size(coll_args->args.dst.info_v.datatype); - - if (c64) { - SET_FULL_COUNTS(uint64_t, sbgp, coll_args, team, - cl_lib->cfg.a2av_node_thresh, sdt_size, rdt_size, - sc_full, sd_full, rc_full, rd_full); - } else { - SET_FULL_COUNTS(uint32_t, sbgp, coll_args, team, - cl_lib->cfg.a2av_node_thresh, sdt_size, rdt_size, - sc_full, sd_full, rc_full, rd_full); - } - args.args.src.info_v.counts = (ucc_count_t *)sc_full; - args.args.dst.info_v.counts = (ucc_count_t *)rc_full; - args.args.src.info_v.displacements = (ucc_aint_t *)sd_full; - args.args.dst.info_v.displacements = (ucc_aint_t *)rd_full; - - UCC_CHECK_GOTO(ucc_coll_init(cl_team->sbgps[UCC_HIER_SBGP_FULL].score_map, - &args, &task_full), - err_init_1, status); - - /* Setup NODE a2av */ - sbgp = cl_team->sbgps[UCC_HIER_SBGP_NODE].sbgp; - - if (c64) { - SET_NODE_COUNTS(uint64_t, sbgp, coll_args, - cl_lib->cfg.a2av_node_thresh, sdt_size, rdt_size, - sc_node, sd_node, rc_node, rd_node); - } else { - SET_NODE_COUNTS(uint32_t, sbgp, coll_args, - cl_lib->cfg.a2av_node_thresh, sdt_size, rdt_size, - sc_node, sd_node, rc_node, rd_node); + memcpy(&args, coll_args, sizeof(args)); + status = ucc_schedule_init(schedule, &args, team); + if (UCC_OK != status) { + ucc_cl_urom_put_schedule(schedule); + return status; } - args.args.src.info_v.counts = (ucc_count_t *)sc_node; - args.args.dst.info_v.counts = (ucc_count_t *)rc_node; - args.args.src.info_v.displacements = (ucc_aint_t *)sd_node; - args.args.dst.info_v.displacements = (ucc_aint_t *)rd_node; - UCC_CHECK_GOTO(ucc_coll_init(cl_team->sbgps[UCC_HIER_SBGP_NODE].score_map, - &args, &task_node), - err_init_2, status); - - UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, task_node), err_init_3, - status); - UCC_CHECK_GOTO(ucc_task_subscribe_dep(&schedule->super, task_node, - UCC_EVENT_SCHEDULE_STARTED), - err_init_3, status); -full: - UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, task_full), err_init_3, - status); - - UCC_CHECK_GOTO(ucc_task_subscribe_dep(&schedule->super, task_full, - UCC_EVENT_SCHEDULE_STARTED), - err_init_3, status); - - schedule->super.post = ucc_cl_hier_alltoallv_start; - schedule->super.progress = NULL; - schedule->super.finalize = ucc_cl_hier_alltoallv_finalize; + schedule->super.post = ucc_cl_urom_alltoallv_full_start; + schedule->super.progress = ucc_cl_urom_alltoallv_full_progress; + schedule->super.finalize = ucc_cl_urom_alltoallv_full_finalize; schedule->super.triggered_post = ucc_triggered_post; schedule->super.triggered_post_setup = - ucc_cl_hier_alltoallv_triggered_post_setup; + ucc_cl_urom_alltoallv_triggered_post_setup; + *task = &schedule->super; + cl_debug(cl_lib, "urom coll init'd"); return UCC_OK; - -err_init_3: - if (!full_only) { - ucc_collective_finalize(&task_node->super); - } -err_init_2: - ucc_collective_finalize(&task_full->super); -err_init_1: - if (!full_only) { - ucc_mc_free(cl_schedule->scratch); - } -error: - ucc_cl_hier_put_schedule(schedule); - return status; } diff --git a/src/components/cl/urom/alltoallv/alltoallv.h b/src/components/cl/urom/alltoallv/alltoallv.h index 7ee910e11f..7fa5246e26 100644 --- a/src/components/cl/urom/alltoallv/alltoallv.h +++ b/src/components/cl/urom/alltoallv/alltoallv.h @@ -4,28 +4,26 @@ * See file LICENSE for terms. */ -#ifndef ALLTOALLV_H_ -#define ALLTOALLV_H_ -#include "../cl_hier.h" +#ifndef UROM_ALLTOALLV_H_ +#define UROM_ALLTOALLV_H_ +#include "../cl_urom_coll.h" +#include "../../../tl/ucp/tl_ucp.h" enum { - UCC_CL_HIER_ALLTOALLV_ALG_NODE_SPLIT, - UCC_CL_HIER_ALLTOALLV_ALG_LAST, + UCC_CL_UROM_ALLTOALLV_ALG_FULL, + UCC_CL_UROM_ALLTOALLV_ALG_LAST, }; extern ucc_base_coll_alg_info_t - ucc_cl_hier_alltoallv_algs[UCC_CL_HIER_ALLTOALLV_ALG_LAST + 1]; + ucc_cl_urom_alltoallv_algs[UCC_CL_UROM_ALLTOALLV_ALG_LAST + 1]; -ucc_status_t ucc_cl_hier_alltoallv_init(ucc_base_coll_args_t *coll_args, - ucc_base_team_t *team, - ucc_coll_task_t **task); -static inline int ucc_cl_hier_alltoallv_alg_from_str(const char *str) +static inline int ucc_cl_urom_alltoallv_alg_from_str(const char *str) { int i; - for (i = 0; i < UCC_CL_HIER_ALLTOALLV_ALG_LAST; i++) { - if (0 == strcasecmp(str, ucc_cl_hier_alltoallv_algs[i].name)) { + for (i = 0; i < UCC_CL_UROM_ALLTOALLV_ALG_LAST; i++) { + if (0 == strcasecmp(str, ucc_cl_urom_alltoallv_algs[i].name)) { break; } } diff --git a/src/components/cl/urom/cl_urom_coll.c b/src/components/cl/urom/cl_urom_coll.c index 39fccaddc8..1350410f0c 100644 --- a/src/components/cl/urom/cl_urom_coll.c +++ b/src/components/cl/urom/cl_urom_coll.c @@ -16,6 +16,10 @@ ucc_status_t ucc_cl_urom_alltoall_full_init( ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task); +ucc_status_t ucc_cl_urom_alltoallv_full_init( + ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, + ucc_coll_task_t **task); + ucc_status_t ucc_cl_urom_allreduce_full_init( ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task); @@ -53,6 +57,8 @@ ucc_status_t ucc_cl_urom_coll_init(ucc_base_coll_args_t *coll_args, switch (coll_args->args.coll_type) { case UCC_COLL_TYPE_ALLTOALL: return ucc_cl_urom_alltoall_full_init(coll_args, team, task); + case UCC_COLL_TYPE_ALLTOALLV: + return ucc_cl_urom_alltoallv_full_init(coll_args, team, task); case UCC_COLL_TYPE_ALLREDUCE: return ucc_cl_urom_allreduce_full_init(coll_args, team, task); default: