From 3c91bf5b36978f4e5f170d38a85ee284b78f9d0f Mon Sep 17 00:00:00 2001 From: ferrol aderholdt Date: Wed, 6 Sep 2023 09:29:22 -0700 Subject: [PATCH] REVIEW: address feedback --- src/components/tl/ucp/alltoallv/alltoallv.c | 39 ---------- src/components/tl/ucp/alltoallv/alltoallv.h | 4 -- .../tl/ucp/alltoallv/alltoallv_onesided.c | 72 +++++++++++++------ 3 files changed, 49 insertions(+), 66 deletions(-) diff --git a/src/components/tl/ucp/alltoallv/alltoallv.c b/src/components/tl/ucp/alltoallv/alltoallv.c index d9cdf06648..063cbd22bf 100644 --- a/src/components/tl/ucp/alltoallv/alltoallv.c +++ b/src/components/tl/ucp/alltoallv/alltoallv.c @@ -51,42 +51,3 @@ ucc_status_t ucc_tl_ucp_alltoallv_pairwise_init(ucc_base_coll_args_t *coll_args, out: return status; } - -ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args, - ucc_base_team_t *team, - ucc_coll_task_t **task_h) -{ - ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); - ucc_tl_ucp_task_t *task; - ucc_status_t status; - - ALLTOALLV_TASK_CHECK(coll_args->args, tl_team); - if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER)) { - tl_error(UCC_TL_TEAM_LIB(tl_team), - "global work buffer not provided nor associated with team"); - status = UCC_ERR_NOT_SUPPORTED; - goto out; - } - if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) { - if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS)) { - tl_error(UCC_TL_TEAM_LIB(tl_team), - "non memory mapped buffers are not supported"); - status = UCC_ERR_NOT_SUPPORTED; - goto out; - } - } - if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_ONESIDED_VECTOR)) { - tl_error(UCC_TL_TEAM_LIB(tl_team), - "onesided vector flag must be set for onesided alltoallv"); - status = UCC_ERR_NOT_SUPPORTED; - goto out; - } - - task = ucc_tl_ucp_init_task(coll_args, team); - *task_h = &task->super; - task->super.post = ucc_tl_ucp_alltoallv_onesided_start; - task->super.progress = ucc_tl_ucp_alltoallv_onesided_progress; - status = UCC_OK; -out: - return status; -} diff --git a/src/components/tl/ucp/alltoallv/alltoallv.h b/src/components/tl/ucp/alltoallv/alltoallv.h index bbf96c762f..9256c62326 100644 --- a/src/components/tl/ucp/alltoallv/alltoallv.h +++ b/src/components/tl/ucp/alltoallv/alltoallv.h @@ -38,10 +38,6 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h); -ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *task); - -void ucc_tl_ucp_alltoallv_onesided_progress(ucc_coll_task_t *task); - ucc_status_t ucc_tl_ucp_alltoallv_pairwise_init_common(ucc_tl_ucp_task_t *task); #define ALLTOALLV_CHECK_INPLACE(_args, _team) \ diff --git a/src/components/tl/ucp/alltoallv/alltoallv_onesided.c b/src/components/tl/ucp/alltoallv/alltoallv_onesided.c index de64cfdac9..7794dc31c7 100644 --- a/src/components/tl/ucp/alltoallv/alltoallv_onesided.c +++ b/src/components/tl/ucp/alltoallv/alltoallv_onesided.c @@ -21,46 +21,33 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask) ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info_v.buffer; ucc_rank_t grank = UCC_TL_TEAM_RANK(team); ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); - ucc_rank_t start = (grank + 1) % gsize; long *pSync = TASK_ARGS(task).global_work_buffer; ucc_aint_t *s_disp = TASK_ARGS(task).src.info_v.displacements; ucc_aint_t *d_disp = TASK_ARGS(task).dst.info_v.displacements; - size_t sdt_size = ucc_dt_size(TASK_ARGS(task).src.info_v.datatype); - ucc_rank_t peer; - size_t sd_disp, dd_disp, data_size; + size_t sdt_size = ucc_dt_size(TASK_ARGS(task).src.info_v.datatype); + size_t rdt_size = ucc_dt_size(TASK_ARGS(task).dst.info_v.datatype); + ucc_rank_t peer; + size_t sd_disp, dd_disp, data_size; ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); /* perform puts to each member i using that displacement */ - sd_disp = ucc_coll_args_get_displacement(&TASK_ARGS(task), s_disp, start) * - sdt_size; - dd_disp = ucc_coll_args_get_displacement(&TASK_ARGS(task), d_disp, start) * - sdt_size; - data_size = - ucc_coll_args_get_count(&TASK_ARGS(task), - TASK_ARGS(task).src.info_v.counts, start) * - sdt_size; - UCPCHECK_GOTO(ucc_tl_ucp_put_nb(((void *)src + sd_disp), - ((void *)dest + dd_disp), data_size, start, - team, task), - task, out); - UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, start, team), task, out); - - for (peer = (start + 1) % gsize; peer != start; peer = (peer + 1) % gsize) { + for (peer = (grank + 1) % gsize; task->onesided.put_posted < gsize; + peer = (peer + 1) % gsize) { sd_disp = ucc_coll_args_get_displacement(&TASK_ARGS(task), s_disp, peer) * sdt_size; dd_disp = ucc_coll_args_get_displacement(&TASK_ARGS(task), d_disp, peer) * - sdt_size; + rdt_size; data_size = ucc_coll_args_get_count(&TASK_ARGS(task), TASK_ARGS(task).src.info_v.counts, peer) * sdt_size; - UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + sd_disp), - (void *)(dest + dd_disp), data_size, - peer, team, task), + UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)PTR_OFFSET(src, sd_disp), + (void *)PTR_OFFSET(dest, dd_disp), + data_size, peer, team, task), task, out); UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, team), task, out); } @@ -85,3 +72,42 @@ void ucc_tl_ucp_alltoallv_onesided_progress(ucc_coll_task_t *ctask) pSync[0] = 0; task->super.status = UCC_OK; } + +ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h) +{ + ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); + ucc_tl_ucp_task_t *task; + ucc_status_t status; + + ALLTOALLV_TASK_CHECK(coll_args->args, tl_team); + if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER)) { + tl_error(UCC_TL_TEAM_LIB(tl_team), + "global work buffer not provided nor associated with team"); + status = UCC_ERR_NOT_SUPPORTED; + goto out; + } + if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) { + if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS)) { + tl_error(UCC_TL_TEAM_LIB(tl_team), + "non memory mapped buffers are not supported"); + status = UCC_ERR_NOT_SUPPORTED; + goto out; + } + } + if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_ONESIDED_VECTOR)) { + tl_error(UCC_TL_TEAM_LIB(tl_team), + "onesided vector flag must be set for onesided alltoallv"); + status = UCC_ERR_NOT_SUPPORTED; + goto out; + } + + task = ucc_tl_ucp_init_task(coll_args, team); + *task_h = &task->super; + task->super.post = ucc_tl_ucp_alltoallv_onesided_start; + task->super.progress = ucc_tl_ucp_alltoallv_onesided_progress; + status = UCC_OK; +out: + return status; +}