From c1ac65b3401da4795c4e313d80f516a463c50e55 Mon Sep 17 00:00:00 2001 From: Nick Sarkauskas Date: Tue, 25 Jun 2024 14:35:16 -0700 Subject: [PATCH] TL/UCP: Refactor, add alltoall --- src/components/tl/ucp/Makefile.am | 1 + src/components/tl/ucp/allgather/allgather.h | 4 +- .../tl/ucp/allgather/allgather_xgvmi.c | 496 +----------------- src/components/tl/ucp/alltoall/alltoall.c | 4 + src/components/tl/ucp/alltoall/alltoall.h | 4 + .../tl/ucp/alltoall/alltoall_xgvmi.c | 73 +++ src/components/tl/ucp/tl_ucp_coll.c | 6 +- src/components/tl/ucp/tl_ucp_coll.h | 3 +- src/components/tl/ucp/tl_ucp_dpu_offload.c | 488 +++++++++++++++++ src/components/tl/ucp/tl_ucp_dpu_offload.h | 17 + 10 files changed, 606 insertions(+), 490 deletions(-) create mode 100644 src/components/tl/ucp/alltoall/alltoall_xgvmi.c diff --git a/src/components/tl/ucp/Makefile.am b/src/components/tl/ucp/Makefile.am index 525ec47d95..9ffbca2501 100644 --- a/src/components/tl/ucp/Makefile.am +++ b/src/components/tl/ucp/Makefile.am @@ -29,6 +29,7 @@ alltoall = \ alltoall/alltoall.c \ alltoall/alltoall_onesided.c \ alltoall/alltoall_pairwise.c \ + alltoall/alltoall_xgvmi.c \ alltoall/alltoall_bruck.c alltoallv = \ diff --git a/src/components/tl/ucp/allgather/allgather.h b/src/components/tl/ucp/allgather/allgather.h index f6e90c3d22..7dea71d139 100644 --- a/src/components/tl/ucp/allgather/allgather.h +++ b/src/components/tl/ucp/allgather/allgather.h @@ -76,9 +76,7 @@ ucc_status_t ucc_tl_ucp_allgather_sparbit_init(ucc_base_coll_args_t *coll_args, ucc_coll_task_t **task_h); /* XGVMI */ -ucc_status_t ucc_tl_ucp_allgather_xgvmi_init(ucc_base_coll_args_t *coll_args, - ucc_base_team_t *team, - ucc_coll_task_t **task_h); +void ucc_tl_ucp_dpu_xgvmi_rdma_progress_allgather(ucc_coll_task_t *coll_task); /* Uses allgather_kn_radix from config */ ucc_status_t ucc_tl_ucp_allgather_knomial_init(ucc_base_coll_args_t *coll_args, diff --git a/src/components/tl/ucp/allgather/allgather_xgvmi.c b/src/components/tl/ucp/allgather/allgather_xgvmi.c index 7ba0e6ac8a..d7b90eb792 100644 --- a/src/components/tl/ucp/allgather/allgather_xgvmi.c +++ b/src/components/tl/ucp/allgather/allgather_xgvmi.c @@ -4,371 +4,11 @@ * See file LICENSE for terms. 
*/ -#include "allgather.h" -#include "../barrier/barrier.h" #include "tl_ucp_ep.h" +#include "tl_ucp_coll.h" #include "tl_ucp_dpu_offload.h" -ucc_status_t -ucc_tl_ucp_allgather_xgvmi_task_init(ucc_base_coll_args_t *coll_args, - ucc_base_team_t *team, - ucc_tl_ucp_task_t *task) -{ - void *src_buf = coll_args->args.src.info.buffer; - void *dst_buf = coll_args->args.dst.info.buffer; - ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); - ucc_rank_t team_size = UCC_TL_TEAM_SIZE(tl_team); - int inplace = UCC_IS_INPLACE(coll_args->args); - ucc_tl_ucp_allreduce_sw_global_work_buf_info_t - *gwbi_p = NULL; - size_t allgather_size = - sizeof(ucc_tl_ucp_allreduce_sw_host_allgather_t); - ucc_tl_ucp_allreduce_sw_host_allgather_t - *allgather_data; - ucc_rank_t i; - void *buffer; - void *ptr; - size_t bufs_sz, allgather_data_sz, rbufs_sz, dst_rkeys_sz, - dst_ebuf_sz, sbufs_sz, src_rkeys_sz, src_ebuf_sz; - - ucc_assert(team_size > 0); - - bufs_sz = sizeof(ucc_tl_ucp_dpu_offload_buf_info_t); - allgather_data_sz = allgather_size * (team_size + 1); - rbufs_sz = sizeof(void *) * team_size; - dst_rkeys_sz = sizeof(ucp_rkey_h) * team_size; - dst_ebuf_sz = sizeof(struct ucc_tl_ucp_allreduce_sw_export_buf); - - if (!inplace) { - sbufs_sz = sizeof(void *) * team_size; - src_rkeys_sz = sizeof(ucp_rkey_h) * team_size; - src_ebuf_sz = sizeof(struct ucc_tl_ucp_allreduce_sw_export_buf); - } else { - sbufs_sz = 0; - src_rkeys_sz = 0; - src_ebuf_sz = 0; - } - - buffer = ucc_malloc(bufs_sz + allgather_data_sz + rbufs_sz + - dst_rkeys_sz + dst_ebuf_sz + sbufs_sz + - src_rkeys_sz + src_ebuf_sz); - if (buffer == NULL) { - tl_error(UCC_TL_TEAM_LIB(tl_team), "error while allocating task"); - return UCC_ERR_NO_RESOURCE; - } - - ptr = buffer; - - task->allgather_xgvmi.bufs = ptr; - - ptr = allgather_data = PTR_OFFSET(ptr, bufs_sz); - task->allgather_xgvmi.allgather_data = allgather_data; - - gwbi_p = coll_args->args.global_work_buffer; - task->super.bargs.args.global_work_buffer = gwbi_p; - - ptr = task->allgather_xgvmi.bufs->rbufs = PTR_OFFSET(ptr, allgather_data_sz); - ptr = task->allgather_xgvmi.bufs->dst_rkeys = PTR_OFFSET(ptr, rbufs_sz); - for (i = 0; i < team_size; i++) { - task->allgather_xgvmi.bufs->dst_rkeys[i] = NULL; - } - - ptr = task->allgather_xgvmi.bufs->dst_ebuf = PTR_OFFSET(ptr, dst_rkeys_sz); - task->allgather_xgvmi.bufs->dst_ebuf->memh = NULL; - - allgather_data->dst_buf = dst_buf; - - task->allgather_xgvmi.allgather_data = allgather_data; - task->allgather_xgvmi.allgather_task = NULL; - - if (!inplace) { - allgather_data->src_buf = src_buf; - - ptr = task->allgather_xgvmi.bufs->sbufs = PTR_OFFSET(ptr, dst_ebuf_sz); - ptr = task->allgather_xgvmi.bufs->src_rkeys = PTR_OFFSET(ptr, sbufs_sz); - for (i = 0; i < team_size; i++) { - task->allgather_xgvmi.bufs->src_rkeys[i] = NULL; - } - - task->allgather_xgvmi.bufs->src_ebuf = PTR_OFFSET(ptr, src_rkeys_sz); - task->allgather_xgvmi.bufs->src_ebuf->memh = NULL; - } else { - task->allgather_xgvmi.bufs->src_ebuf = NULL; - } - - return UCC_OK; -} - -ucc_status_t ucc_tl_ucp_allgather_xgvmi_allgather_info_finalize( - ucc_tl_ucp_task_t *sw_task) -{ - ucs_status_t ucs_status = UCS_OK; - ucc_base_team_t *base_team = sw_task->super.team; - ucc_tl_ucp_team_t *tl_team = ucc_derived_of(base_team, ucc_tl_ucp_team_t); - ucc_rank_t team_size = UCC_TL_TEAM_SIZE(tl_team); - void *recvbuf = sw_task->allgather_xgvmi. 
- allgather_task->bargs.args.dst.info.buffer; - ucc_tl_ucp_allreduce_sw_host_allgather_t *all_host_allgather = recvbuf; - ucc_status_t status = UCC_OK; - int inplace = UCC_IS_INPLACE(TASK_ARGS(sw_task)); - ucc_rank_t i; - ucp_ep_h ep; - ucp_rkey_h src_unpacked, dst_unpacked; - - ucc_assert(team_size > 0); - - for (i = 0; i < team_size; i++) { - status = ucc_tl_ucp_get_ep(tl_team, i, &ep); - if (ucc_unlikely(UCC_OK != status)) { - return status; - } - - ucs_status = ucp_ep_rkey_unpack( - ep, all_host_allgather[i].packed_dst_key, &dst_unpacked); - if (UCS_OK != ucs_status) { - tl_error(UCC_TL_TEAM_LIB(tl_team), "dst rkey unpack failed"); - return ucs_status_to_ucc_status(ucs_status); - } - - sw_task->allgather_xgvmi.bufs->rbufs[i] = - all_host_allgather[i].dst_buf; - sw_task->allgather_xgvmi.bufs->dst_rkeys[i] = dst_unpacked; - - if (!inplace) { - ucs_status = ucp_ep_rkey_unpack( - ep, all_host_allgather[i].packed_src_key, &src_unpacked); - if (UCS_OK != ucs_status) { - tl_error(UCC_TL_TEAM_LIB(tl_team), "src rkey unpack failed"); - return ucs_status_to_ucc_status(ucs_status); - } - - sw_task->allgather_xgvmi.bufs->sbufs[i] = - all_host_allgather[i].src_buf; - sw_task->allgather_xgvmi.bufs->src_rkeys[i] = src_unpacked; - } else { - sw_task->allgather_xgvmi.bufs->sbufs = - sw_task->allgather_xgvmi.bufs->rbufs; - sw_task->allgather_xgvmi.bufs->src_rkeys = - sw_task->allgather_xgvmi.bufs->dst_rkeys; - } - } - - return status; -} - -void -ucc_tl_ucp_allgather_xgvmi_free_task(ucc_coll_task_t *coll_task) -{ - ucc_base_team_t *team = coll_task->team; - ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - int inplace = UCC_IS_INPLACE(coll_task->bargs.args); - ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(tl_team); - - if (task->allgather_xgvmi.bufs) { - if (!inplace) { - if (task->allgather_xgvmi.bufs->src_ebuf->memh != NULL) { - ucp_mem_unmap(tl_ctx->worker.ucp_context, - task->allgather_xgvmi.bufs->src_ebuf->memh); - task->allgather_xgvmi.bufs->src_ebuf->memh = NULL; - } - } - - if (task->allgather_xgvmi.bufs->dst_ebuf->memh != NULL) { - ucp_mem_unmap(tl_ctx->worker.ucp_context, - task->allgather_xgvmi.bufs->dst_ebuf->memh); - } - ucc_free(task->allgather_xgvmi.bufs); - } -} - -ucc_status_t -ucc_tl_ucp_allgather_xgvmi_start(ucc_coll_task_t *coll_task) -{ - ucc_base_coll_args_t *coll_args = &coll_task->bargs; - ucc_schedule_t *schedule = ucc_derived_of(coll_task, - ucc_schedule_t); - ucc_base_team_t *base_team = schedule->super.team; - ucc_tl_ucp_team_t *team = ucc_derived_of(base_team, - ucc_tl_ucp_team_t); - ucc_rank_t team_size = UCC_TL_TEAM_SIZE(team); - ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(team); - int inplace = UCC_IS_INPLACE(coll_args->args); - ucc_status_t status = UCC_OK; - ucc_tl_ucp_allreduce_sw_global_work_buf_info_t - *gwbi_p = coll_args->args.global_work_buffer; - ucc_tl_ucp_task_t *rdma_task = ucc_derived_of(schedule->tasks[0], - ucc_tl_ucp_task_t); - ucc_tl_ucp_allreduce_sw_host_allgather_t *allgather_data; - - allgather_data = rdma_task->allgather_xgvmi.allgather_data; - - rdma_task->allgather_xgvmi.gets_posted = 0; - rdma_task->allgather_xgvmi.gets_completed = 0; - memset(rdma_task->allgather_xgvmi.requests, 0, - team_size * sizeof(sizeof(ucs_status_ptr_t))); - - // Register the src buf - if (!inplace) { - status = ucc_tl_ucp_allreduce_sliding_window_register( - tl_ctx->worker.ucp_context, team, - rdma_task->allgather_xgvmi.bufs->src_ebuf, - gwbi_p->packed_src_memh); - 
if (status != UCC_OK) { - tl_error(UCC_TASK_LIB(rdma_task), "failed to register src memh: %s", - ucc_status_string(status)); - goto out; - } - ucc_assert( - rdma_task->allgather_xgvmi.bufs->src_ebuf->packed_key_len - <= ALLREDUCE_PACKED_KEY_MAX_LEN); - memcpy(allgather_data->packed_src_key, - rdma_task->allgather_xgvmi.bufs->src_ebuf->packed_key, - rdma_task->allgather_xgvmi.bufs->src_ebuf->packed_key_len); - } - - // Register the dst buf - status = ucc_tl_ucp_allreduce_sliding_window_register( - tl_ctx->worker.ucp_context, team, - rdma_task->allgather_xgvmi.bufs->dst_ebuf, - gwbi_p->packed_dst_memh); - if (status != UCC_OK) { - tl_error(UCC_TASK_LIB(rdma_task), "failed to register dst memh: %s", - ucc_status_string(status)); - goto out; - } - ucc_assert( - rdma_task->allgather_xgvmi.bufs->dst_ebuf->packed_key_len - <= ALLREDUCE_PACKED_KEY_MAX_LEN); - memcpy(allgather_data->packed_dst_key, - rdma_task->allgather_xgvmi.bufs->dst_ebuf->packed_key, - rdma_task->allgather_xgvmi.bufs->dst_ebuf->packed_key_len); - - UCC_CHECK_GOTO(ucc_tl_ucp_allgather_ring_start( - rdma_task->allgather_xgvmi.allgather_task), - out, status); - - return ucc_schedule_start(coll_task); - -out: - tl_error(UCC_TASK_LIB(rdma_task), "failed to start allgather sliding window: %s", - ucc_status_string(status)); - return status; -} - -ucc_status_t -ucc_tl_ucp_allgather_xgvmi_finalize(ucc_coll_task_t *coll_task) -{ - ucc_schedule_t *schedule = ucc_derived_of(coll_task, ucc_schedule_t); - ucc_status_t status; - - status = ucc_schedule_finalize(coll_task); - ucc_tl_ucp_put_schedule(schedule); - - return status; -} - -ucc_status_t -ucc_tl_ucp_allgather_xgvmi_rdma_task_post( - ucc_coll_task_t *coll_task) -{ - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, - ucc_tl_ucp_task_t); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - - ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); - - return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); -} - -static inline void ucc_tl_ucp_allgather_xgvmi_free_rkeys( - ucc_coll_task_t *coll_task) -{ - ucc_base_team_t *team = coll_task->team; - ucc_rank_t team_size = (ucc_rank_t)team->params.size; - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - int inplace = UCC_IS_INPLACE(coll_task->bargs.args); - ucc_rank_t i; - - for (i = 0; i < team_size; i++) { - if (!inplace && task->allgather_xgvmi.bufs->src_rkeys[i] != NULL) { - ucp_rkey_destroy(task->allgather_xgvmi.bufs->src_rkeys[i]); - } - if (task->allgather_xgvmi.bufs->dst_rkeys[i] != NULL) { - ucp_rkey_destroy(task->allgather_xgvmi.bufs->dst_rkeys[i]); - } - } -} - -static ucc_status_t -ucc_tl_ucp_allgather_xgvmi_rdma_task_finalize( - ucc_coll_task_t *coll_task) -{ - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - ucc_status_t st = UCC_OK; - - ucc_tl_ucp_allgather_xgvmi_free_rkeys(coll_task); - ucc_tl_ucp_allgather_xgvmi_free_task(coll_task); - - st = ucc_tl_ucp_coll_finalize(coll_task); - - if (ucc_unlikely(st != UCC_OK)) { - tl_error(UCC_TASK_LIB(task), "failed to finalize collective"); - } - - return st; -} - -static inline ucc_status_t -ucc_tl_ucp_allgather_xgvmi_req_test(ucs_status_ptr_t request, - ucc_tl_ucp_task_t *task) -{ - if (request == NULL) { - return UCC_OK; - } else if (UCS_PTR_IS_ERR(request)) { - tl_error(UCC_TASK_LIB(task), "unable to complete UCX request=%p: %d", - request, UCS_PTR_STATUS(request)); - return ucs_status_to_ucc_status(UCS_PTR_STATUS(request)); - } else { - return ucs_status_to_ucc_status(ucp_request_check_status(request)); - } -} - -static inline 
void ucc_tl_ucp_allgather_xgvmi_key_exchange_progress( - ucc_coll_task_t *coll_task) -{ - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - ucc_coll_task_t *allgather_task = - task->allgather_xgvmi.allgather_task; - ucc_status_t status = allgather_task->super.status; - - if (status < 0) { - goto err; - } - if (UCC_INPROGRESS == status) { - ucc_tl_ucp_allgather_ring_progress(allgather_task); - return; - } - ucc_assert(status == UCC_OK); - - // copy from allgather recvbuf into xgvmi task - UCC_CHECK_GOTO( - ucc_tl_ucp_allgather_xgvmi_allgather_info_finalize(task), - err, status); - -out: - ucc_tl_ucp_coll_finalize(allgather_task); - task->allgather_xgvmi.allgather_task = NULL; - return; -err: - ucc_tl_ucp_allgather_xgvmi_free_task(coll_task); - tl_error(coll_task->team->context->lib, - "key exchange failure: %s", - ucc_status_string(status)); - goto out; -} - -void ucc_tl_ucp_allgather_xgvmi_rdma_progress(ucc_coll_task_t *coll_task) +void ucc_tl_ucp_dpu_xgvmi_rdma_progress_allgather(ucc_coll_task_t *coll_task) { ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); ucc_rank_t size = (ucc_rank_t)task->subset.map.ep_num; @@ -378,15 +18,15 @@ void ucc_tl_ucp_allgather_xgvmi_rdma_progress(ucc_coll_task_t *coll_task) uint32_t host_team_size = size; ucc_base_team_t *base_team = coll_task->team; ucc_tl_ucp_team_t *tl_team = ucc_derived_of(base_team, ucc_tl_ucp_team_t); - ucc_coll_task_t *allgather_task = task->allgather_xgvmi.allgather_task; + ucc_coll_task_t *allgather_task = task->dpu_xgvmi.allgather_task; ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(tl_team); ucp_request_param_t req_param = {0}; int i = 0; ucc_rank_t rank = UCC_TL_TEAM_RANK(tl_team); size_t data_size = (count * dt_size); - ucs_status_ptr_t *requests = task->allgather_xgvmi.requests; - int *posted = &task->allgather_xgvmi.gets_posted; - int *completed = &task->allgather_xgvmi.gets_completed; + ucs_status_ptr_t *requests = task->dpu_xgvmi.requests; + int *posted = &task->dpu_xgvmi.gets_posted; + int *completed = &task->dpu_xgvmi.gets_completed; void *src_addr; void *dst_addr; ucp_rkey_h rkey; @@ -394,7 +34,7 @@ void ucc_tl_ucp_allgather_xgvmi_rdma_progress(ucc_coll_task_t *coll_task) ucc_rank_t offset; if (allgather_task != NULL) { - ucc_tl_ucp_allgather_xgvmi_key_exchange_progress(coll_task); + ucc_tl_ucp_dpu_xgvmi_key_exchange_progress(coll_task); return; } @@ -402,10 +42,10 @@ void ucc_tl_ucp_allgather_xgvmi_rdma_progress(ucc_coll_task_t *coll_task) for (i = *posted; i < host_team_size; i++) { offset = (i + rank) % host_team_size; - req_param.memh = task->allgather_xgvmi.bufs->dst_ebuf->memh; - src_addr = task->allgather_xgvmi.bufs->sbufs[offset]; - dst_addr = PTR_OFFSET(task->allgather_xgvmi.bufs->rbufs[rank], offset * data_size); - rkey = task->allgather_xgvmi.bufs->src_rkeys[offset]; + req_param.memh = task->dpu_xgvmi.bufs->dst_ebuf->memh; + src_addr = task->dpu_xgvmi.bufs->sbufs[offset]; + dst_addr = PTR_OFFSET(task->dpu_xgvmi.bufs->rbufs[rank], offset * data_size); + rkey = task->dpu_xgvmi.bufs->src_rkeys[offset]; ucc_tl_ucp_get_ep(tl_team, offset, &ep); requests[i] = ucp_get_nbx( @@ -419,7 +59,7 @@ void ucc_tl_ucp_allgather_xgvmi_rdma_progress(ucc_coll_task_t *coll_task) ucp_worker_progress(tl_ctx->worker.ucp_worker); for (i = *completed; i < *posted; i++) { - if (ucc_tl_ucp_allgather_xgvmi_req_test(requests[i], task) == UCC_OK) { + if (ucc_tl_ucp_dpu_xgvmi_req_test(requests[i], task) == UCC_OK) { if (requests[i]) ucp_request_free(requests[i]); *completed += 1; } else { @@ 
-431,115 +71,3 @@ void ucc_tl_ucp_allgather_xgvmi_rdma_progress(ucc_coll_task_t *coll_task) task->super.status = UCC_OK; } } - -ucc_status_t -ucc_tl_ucp_allgather_xgvmi_init(ucc_base_coll_args_t *coll_args, - ucc_base_team_t *team, - ucc_coll_task_t **task_h) -{ - ucc_schedule_t *schedule = NULL; - ucc_status_t status = UCC_OK; - ucc_tl_ucp_team_t *tl_team = - ucc_derived_of(team, ucc_tl_ucp_team_t); - size_t allgather_size = - sizeof(ucc_tl_ucp_allreduce_sw_host_allgather_t); - ucc_rank_t size = UCC_TL_TEAM_SIZE(tl_team); - ucc_base_coll_args_t bargs = { - .mask = 0, - .args = { - .coll_type = UCC_COLL_TYPE_ALLGATHER, - .mask = 0, - .src.info = {.buffer = NULL, - .count = allgather_size, - .datatype = UCC_DT_UINT8, - .mem_type = UCC_MEMORY_TYPE_HOST}, - .dst.info = {.buffer = NULL, - .count = allgather_size * size, - .datatype = UCC_DT_UINT8, - .mem_type = UCC_MEMORY_TYPE_HOST} - } - }; - ucc_base_coll_args_t barrier_coll_args = { - .team = team->params.team, - .args.coll_type = UCC_COLL_TYPE_BARRIER, - }; - ucc_tl_ucp_allreduce_sw_host_allgather_t *allgather_data; - ucc_tl_ucp_task_t *rdma_task; - ucc_coll_task_t *barrier_task; - - status = ucc_tl_ucp_get_schedule(tl_team, coll_args, - (ucc_tl_ucp_schedule_t **)&schedule); - if (ucc_unlikely(UCC_OK != status)) { - return status; - } - - *task_h = &schedule->super; - schedule->super.post = ucc_tl_ucp_allgather_xgvmi_start; - schedule->super.progress = NULL; - schedule->super.finalize = ucc_tl_ucp_allgather_xgvmi_finalize; - - schedule->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; - - rdma_task = ucc_tl_ucp_init_task(coll_args, team); - if (ucc_unlikely(!rdma_task)) { - tl_error(UCC_TL_TEAM_LIB(tl_team), "Couldnt allocate task"); - return UCC_ERR_NO_MEMORY; - } - - status = ucc_tl_ucp_allgather_xgvmi_task_init(coll_args, team, - rdma_task); - if (status != UCC_OK) { - tl_error(UCC_TL_TEAM_LIB(tl_team), "failed to init task: %s", - ucc_status_string(status)); - goto out; - } - - allgather_data = rdma_task->allgather_xgvmi.allgather_data; - bargs.args.src.info.buffer = allgather_data; - bargs.args.dst.info.buffer = PTR_OFFSET(allgather_data, allgather_size); - - rdma_task->super.post = ucc_tl_ucp_allgather_xgvmi_rdma_task_post; - rdma_task->super.progress = ucc_tl_ucp_allgather_xgvmi_rdma_progress; - rdma_task->super.finalize = ucc_tl_ucp_allgather_xgvmi_rdma_task_finalize; - - rdma_task->allgather_xgvmi.requests = ucc_malloc(sizeof(ucs_status_ptr_t) * size); - - UCC_CHECK_GOTO(ucc_tl_ucp_allgather_ring_init(&bargs, team, - &rdma_task->allgather_xgvmi.allgather_task), - free_rdma_task, status); - - status = ucc_tl_ucp_coll_init(&barrier_coll_args, team, - &barrier_task); - if (status < 0) { - tl_error(UCC_TL_TEAM_LIB(tl_team), - "failure during sliding window barrier init: %s", - ucc_status_string(status)); - goto free_allgather_task; - } - - UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, &rdma_task->super), out, status); - UCC_CHECK_GOTO(ucc_event_manager_subscribe(&schedule->super, - UCC_EVENT_SCHEDULE_STARTED, - &rdma_task->super, - ucc_task_start_handler), - free_barrier_task, status); - - UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, barrier_task), out, status); - UCC_CHECK_GOTO(ucc_event_manager_subscribe(&rdma_task->super, - UCC_EVENT_COMPLETED, - barrier_task, - ucc_task_start_handler), - free_barrier_task, status); - - return status; - -free_barrier_task: - ucc_tl_ucp_coll_finalize(barrier_task); -free_allgather_task: - ucc_tl_ucp_coll_finalize(rdma_task->allgather_xgvmi.allgather_task); -free_rdma_task: - 
ucc_tl_ucp_allgather_xgvmi_free_task(&rdma_task->super); -out: - ucc_tl_ucp_put_schedule(schedule); - return status; -} diff --git a/src/components/tl/ucp/alltoall/alltoall.c b/src/components/tl/ucp/alltoall/alltoall.c index 3803d96426..c65f75e031 100644 --- a/src/components/tl/ucp/alltoall/alltoall.c +++ b/src/components/tl/ucp/alltoall/alltoall.c @@ -43,6 +43,10 @@ ucc_base_coll_alg_info_t {.id = UCC_TL_UCP_ALLTOALL_ALG_ONESIDED, .name = "onesided", .desc = "naive, linear one-sided implementation"}, + [UCC_TL_UCP_ALLTOALL_ALG_XGVMI] = + {.id = UCC_TL_UCP_ALLTOALL_ALG_XGVMI, + .name = "xgvmi", + .desc = "xgvmi-based implementation"}, [UCC_TL_UCP_ALLTOALL_ALG_LAST] = {.id = 0, .name = NULL, .desc = NULL}}; ucc_status_t ucc_tl_ucp_alltoall_init(ucc_tl_ucp_task_t *task) diff --git a/src/components/tl/ucp/alltoall/alltoall.h b/src/components/tl/ucp/alltoall/alltoall.h index 746f3fcc47..be1ccf2b1d 100644 --- a/src/components/tl/ucp/alltoall/alltoall.h +++ b/src/components/tl/ucp/alltoall/alltoall.h @@ -14,6 +14,7 @@ enum { UCC_TL_UCP_ALLTOALL_ALG_PAIRWISE, UCC_TL_UCP_ALLTOALL_ALG_BRUCK, UCC_TL_UCP_ALLTOALL_ALG_ONESIDED, + UCC_TL_UCP_ALLTOALL_ALG_XGVMI, UCC_TL_UCP_ALLTOALL_ALG_LAST }; @@ -42,6 +43,9 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h); +/* XGVMI */ +void ucc_tl_ucp_dpu_xgvmi_rdma_progress_alltoall(ucc_coll_task_t *coll_task); + #define ALLTOALL_CHECK_INPLACE(_args, _team) \ do { \ if (UCC_IS_INPLACE(_args)) { \ diff --git a/src/components/tl/ucp/alltoall/alltoall_xgvmi.c b/src/components/tl/ucp/alltoall/alltoall_xgvmi.c new file mode 100644 index 0000000000..26b24aa7ab --- /dev/null +++ b/src/components/tl/ucp/alltoall/alltoall_xgvmi.c @@ -0,0 +1,73 @@ +/** + * Copyright(c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#include "tl_ucp_ep.h" +#include "tl_ucp_coll.h" +#include "tl_ucp_dpu_offload.h" + +void ucc_tl_ucp_dpu_xgvmi_rdma_progress_alltoall(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_rank_t size = (ucc_rank_t)task->subset.map.ep_num; + ucc_datatype_t dtype = TASK_ARGS(task).src.info.datatype; + size_t dt_size = ucc_dt_size(dtype); + uint32_t count = coll_task->bargs.args.src.info.count; + uint32_t host_team_size = size; + ucc_base_team_t *base_team = coll_task->team; + ucc_tl_ucp_team_t *tl_team = ucc_derived_of(base_team, ucc_tl_ucp_team_t); + ucc_coll_task_t *allgather_task = task->dpu_xgvmi.allgather_task; + ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(tl_team); + ucp_request_param_t req_param = {0}; + int i = 0; + ucc_rank_t rank = UCC_TL_TEAM_RANK(tl_team); + size_t data_size = (count * dt_size) / host_team_size; + ucs_status_ptr_t *requests = task->dpu_xgvmi.requests; + int *posted = &task->dpu_xgvmi.gets_posted; + int *completed = &task->dpu_xgvmi.gets_completed; + void *src_addr; + void *dst_addr; + ucp_rkey_h rkey; + ucp_ep_h ep; + ucc_rank_t offset; + + if (allgather_task != NULL) { + ucc_tl_ucp_dpu_xgvmi_key_exchange_progress(coll_task); + return; + } + + req_param.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH; + + for (i = *posted; i < host_team_size; i++) { + offset = (i + rank) % host_team_size; + req_param.memh = task->dpu_xgvmi.bufs->dst_ebuf->memh; + src_addr = PTR_OFFSET(task->dpu_xgvmi.bufs->sbufs[offset], rank * data_size); + dst_addr = PTR_OFFSET(task->dpu_xgvmi.bufs->rbufs[rank], offset * data_size); + rkey = task->dpu_xgvmi.bufs->src_rkeys[offset]; + ucc_tl_ucp_get_ep(tl_team, offset, &ep); + + requests[i] = ucp_get_nbx( + ep, dst_addr, + data_size, (uint64_t)src_addr, + rkey, &req_param); + + *posted += 1; + } + + ucp_worker_progress(tl_ctx->worker.ucp_worker); + + for (i = *completed; i < *posted; i++) { + if (ucc_tl_ucp_dpu_xgvmi_req_test(requests[i], task) == UCC_OK) { + if (requests[i]) ucp_request_free(requests[i]); + *completed += 1; + } else { + break; + } + } + + if (*completed == host_team_size) { + task->super.status = UCC_OK; + } +} diff --git a/src/components/tl/ucp/tl_ucp_coll.c b/src/components/tl/ucp/tl_ucp_coll.c index 76e5792f08..acb066908a 100644 --- a/src/components/tl/ucp/tl_ucp_coll.c +++ b/src/components/tl/ucp/tl_ucp_coll.c @@ -23,6 +23,7 @@ #include "fanin/fanin.h" #include "fanout/fanout.h" #include "scatterv/scatterv.h" +#include "tl_ucp_dpu_offload.h" const ucc_tl_ucp_default_alg_desc_t ucc_tl_ucp_default_alg_descs[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR] = { @@ -269,7 +270,7 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str, *init = ucc_tl_ucp_allgather_sparbit_init; break; case UCC_TL_UCP_ALLGATHER_ALG_XGVMI: - *init = ucc_tl_ucp_allgather_xgvmi_init; + *init = ucc_tl_ucp_dpu_xgvmi_init; break; default: status = UCC_ERR_INVALID_PARAM; @@ -322,6 +323,9 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str, case UCC_TL_UCP_ALLTOALL_ALG_ONESIDED: *init = ucc_tl_ucp_alltoall_onesided_init; break; + case UCC_TL_UCP_ALLTOALL_ALG_XGVMI: + *init = ucc_tl_ucp_dpu_xgvmi_init; + break; default: status = UCC_ERR_INVALID_PARAM; break; diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index b6514bf8b0..97c829b03c 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -137,14 +137,13 @@ typedef struct ucc_tl_ucp_task { ucc_tl_ucp_dpu_offload_buf_info_t *bufs; } 
allreduce_sliding_window; struct { - ucs_status_ptr_t *put_requests; ucc_tl_ucp_allreduce_sw_host_allgather *allgather_data; ucc_coll_task_t *allgather_task; ucc_tl_ucp_dpu_offload_buf_info_t *bufs; ucs_status_ptr_t *requests; int gets_posted; int gets_completed; - } allgather_xgvmi; + } dpu_xgvmi; struct { int phase; ucc_knomial_pattern_t p; diff --git a/src/components/tl/ucp/tl_ucp_dpu_offload.c b/src/components/tl/ucp/tl_ucp_dpu_offload.c index 70956ad9e4..f5b4cfeea5 100644 --- a/src/components/tl/ucp/tl_ucp_dpu_offload.c +++ b/src/components/tl/ucp/tl_ucp_dpu_offload.c @@ -5,6 +5,11 @@ */ #include "tl_ucp_dpu_offload.h" +#include "allgather/allgather.h" +#include "alltoall/alltoall.h" +#include "barrier/barrier.h" +#include "tl_ucp_ep.h" +#include "tl_ucp_dpu_offload.h" ucc_status_t ucc_tl_ucp_allreduce_sliding_window_register( ucp_context_h ucp_context, ucc_tl_ucp_team_t *tl_team, @@ -40,3 +45,486 @@ ucc_status_t ucc_tl_ucp_allreduce_sliding_window_register( return UCC_OK; } + +ucc_status_t +ucc_tl_ucp_dpu_xgvmi_task_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_tl_ucp_task_t *task) +{ + void *src_buf = coll_args->args.src.info.buffer; + void *dst_buf = coll_args->args.dst.info.buffer; + ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); + ucc_rank_t team_size = UCC_TL_TEAM_SIZE(tl_team); + int inplace = UCC_IS_INPLACE(coll_args->args); + ucc_tl_ucp_allreduce_sw_global_work_buf_info_t + *gwbi_p = NULL; + size_t allgather_size = + sizeof(ucc_tl_ucp_allreduce_sw_host_allgather_t); + ucc_tl_ucp_allreduce_sw_host_allgather_t + *allgather_data; + ucc_rank_t i; + void *buffer; + void *ptr; + size_t bufs_sz, allgather_data_sz, rbufs_sz, dst_rkeys_sz, + dst_ebuf_sz, sbufs_sz, src_rkeys_sz, src_ebuf_sz; + + ucc_assert(team_size > 0); + + bufs_sz = sizeof(ucc_tl_ucp_dpu_offload_buf_info_t); + allgather_data_sz = allgather_size * (team_size + 1); + rbufs_sz = sizeof(void *) * team_size; + dst_rkeys_sz = sizeof(ucp_rkey_h) * team_size; + dst_ebuf_sz = sizeof(struct ucc_tl_ucp_allreduce_sw_export_buf); + + if (!inplace) { + sbufs_sz = sizeof(void *) * team_size; + src_rkeys_sz = sizeof(ucp_rkey_h) * team_size; + src_ebuf_sz = sizeof(struct ucc_tl_ucp_allreduce_sw_export_buf); + } else { + sbufs_sz = 0; + src_rkeys_sz = 0; + src_ebuf_sz = 0; + } + + buffer = ucc_malloc(bufs_sz + allgather_data_sz + rbufs_sz + + dst_rkeys_sz + dst_ebuf_sz + sbufs_sz + + src_rkeys_sz + src_ebuf_sz); + if (buffer == NULL) { + tl_error(UCC_TL_TEAM_LIB(tl_team), "error while allocating task"); + return UCC_ERR_NO_RESOURCE; + } + + ptr = buffer; + + task->dpu_xgvmi.bufs = ptr; + + ptr = allgather_data = PTR_OFFSET(ptr, bufs_sz); + task->dpu_xgvmi.allgather_data = allgather_data; + + gwbi_p = coll_args->args.global_work_buffer; + task->super.bargs.args.global_work_buffer = gwbi_p; + + ptr = task->dpu_xgvmi.bufs->rbufs = PTR_OFFSET(ptr, allgather_data_sz); + ptr = task->dpu_xgvmi.bufs->dst_rkeys = PTR_OFFSET(ptr, rbufs_sz); + for (i = 0; i < team_size; i++) { + task->dpu_xgvmi.bufs->dst_rkeys[i] = NULL; + } + + ptr = task->dpu_xgvmi.bufs->dst_ebuf = PTR_OFFSET(ptr, dst_rkeys_sz); + task->dpu_xgvmi.bufs->dst_ebuf->memh = NULL; + + allgather_data->dst_buf = dst_buf; + + task->dpu_xgvmi.allgather_data = allgather_data; + task->dpu_xgvmi.allgather_task = NULL; + + if (!inplace) { + allgather_data->src_buf = src_buf; + + ptr = task->dpu_xgvmi.bufs->sbufs = PTR_OFFSET(ptr, dst_ebuf_sz); + ptr = task->dpu_xgvmi.bufs->src_rkeys = PTR_OFFSET(ptr, sbufs_sz); + for (i = 0; i < team_size; i++) 
{ + task->dpu_xgvmi.bufs->src_rkeys[i] = NULL; + } + + task->dpu_xgvmi.bufs->src_ebuf = PTR_OFFSET(ptr, src_rkeys_sz); + task->dpu_xgvmi.bufs->src_ebuf->memh = NULL; + } else { + task->dpu_xgvmi.bufs->src_ebuf = NULL; + } + + return UCC_OK; +} + +ucc_status_t ucc_tl_ucp_dpu_xgvmi_allgather_info_finalize( + ucc_tl_ucp_task_t *sw_task) +{ + ucs_status_t ucs_status = UCS_OK; + ucc_base_team_t *base_team = sw_task->super.team; + ucc_tl_ucp_team_t *tl_team = ucc_derived_of(base_team, ucc_tl_ucp_team_t); + ucc_rank_t team_size = UCC_TL_TEAM_SIZE(tl_team); + void *recvbuf = sw_task->dpu_xgvmi. + allgather_task->bargs.args.dst.info.buffer; + ucc_tl_ucp_allreduce_sw_host_allgather_t *all_host_allgather = recvbuf; + ucc_status_t status = UCC_OK; + int inplace = UCC_IS_INPLACE(TASK_ARGS(sw_task)); + ucc_rank_t i; + ucp_ep_h ep; + ucp_rkey_h src_unpacked, dst_unpacked; + + ucc_assert(team_size > 0); + + for (i = 0; i < team_size; i++) { + status = ucc_tl_ucp_get_ep(tl_team, i, &ep); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + + ucs_status = ucp_ep_rkey_unpack( + ep, all_host_allgather[i].packed_dst_key, &dst_unpacked); + if (UCS_OK != ucs_status) { + tl_error(UCC_TL_TEAM_LIB(tl_team), "dst rkey unpack failed"); + return ucs_status_to_ucc_status(ucs_status); + } + + sw_task->dpu_xgvmi.bufs->rbufs[i] = + all_host_allgather[i].dst_buf; + sw_task->dpu_xgvmi.bufs->dst_rkeys[i] = dst_unpacked; + + if (!inplace) { + ucs_status = ucp_ep_rkey_unpack( + ep, all_host_allgather[i].packed_src_key, &src_unpacked); + if (UCS_OK != ucs_status) { + tl_error(UCC_TL_TEAM_LIB(tl_team), "src rkey unpack failed"); + return ucs_status_to_ucc_status(ucs_status); + } + + sw_task->dpu_xgvmi.bufs->sbufs[i] = + all_host_allgather[i].src_buf; + sw_task->dpu_xgvmi.bufs->src_rkeys[i] = src_unpacked; + } else { + sw_task->dpu_xgvmi.bufs->sbufs = + sw_task->dpu_xgvmi.bufs->rbufs; + sw_task->dpu_xgvmi.bufs->src_rkeys = + sw_task->dpu_xgvmi.bufs->dst_rkeys; + } + } + + return status; +} + +void +ucc_tl_ucp_dpu_xgvmi_free_task(ucc_coll_task_t *coll_task) +{ + ucc_base_team_t *team = coll_task->team; + ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + int inplace = UCC_IS_INPLACE(coll_task->bargs.args); + ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(tl_team); + + if (task->dpu_xgvmi.bufs) { + if (!inplace) { + if (task->dpu_xgvmi.bufs->src_ebuf->memh != NULL) { + ucp_mem_unmap(tl_ctx->worker.ucp_context, + task->dpu_xgvmi.bufs->src_ebuf->memh); + task->dpu_xgvmi.bufs->src_ebuf->memh = NULL; + } + } + + if (task->dpu_xgvmi.bufs->dst_ebuf->memh != NULL) { + ucp_mem_unmap(tl_ctx->worker.ucp_context, + task->dpu_xgvmi.bufs->dst_ebuf->memh); + } + ucc_free(task->dpu_xgvmi.bufs); + } +} + +ucc_status_t +ucc_tl_ucp_dpu_xgvmi_start(ucc_coll_task_t *coll_task) +{ + ucc_base_coll_args_t *coll_args = &coll_task->bargs; + ucc_schedule_t *schedule = ucc_derived_of(coll_task, + ucc_schedule_t); + ucc_base_team_t *base_team = schedule->super.team; + ucc_tl_ucp_team_t *team = ucc_derived_of(base_team, + ucc_tl_ucp_team_t); + ucc_rank_t team_size = UCC_TL_TEAM_SIZE(team); + ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(team); + int inplace = UCC_IS_INPLACE(coll_args->args); + ucc_status_t status = UCC_OK; + ucc_tl_ucp_allreduce_sw_global_work_buf_info_t + *gwbi_p = coll_args->args.global_work_buffer; + ucc_tl_ucp_task_t *rdma_task = ucc_derived_of(schedule->tasks[0], + ucc_tl_ucp_task_t); + ucc_tl_ucp_allreduce_sw_host_allgather_t 
*allgather_data; + + allgather_data = rdma_task->dpu_xgvmi.allgather_data; + + rdma_task->dpu_xgvmi.gets_posted = 0; + rdma_task->dpu_xgvmi.gets_completed = 0; + memset(rdma_task->dpu_xgvmi.requests, 0, + team_size * sizeof(sizeof(ucs_status_ptr_t))); + + // Register the src buf + if (!inplace) { + status = ucc_tl_ucp_allreduce_sliding_window_register( + tl_ctx->worker.ucp_context, team, + rdma_task->dpu_xgvmi.bufs->src_ebuf, + gwbi_p->packed_src_memh); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(rdma_task), "failed to register src memh: %s", + ucc_status_string(status)); + goto out; + } + ucc_assert( + rdma_task->dpu_xgvmi.bufs->src_ebuf->packed_key_len + <= ALLREDUCE_PACKED_KEY_MAX_LEN); + memcpy(allgather_data->packed_src_key, + rdma_task->dpu_xgvmi.bufs->src_ebuf->packed_key, + rdma_task->dpu_xgvmi.bufs->src_ebuf->packed_key_len); + } + + // Register the dst buf + status = ucc_tl_ucp_allreduce_sliding_window_register( + tl_ctx->worker.ucp_context, team, + rdma_task->dpu_xgvmi.bufs->dst_ebuf, + gwbi_p->packed_dst_memh); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(rdma_task), "failed to register dst memh: %s", + ucc_status_string(status)); + goto out; + } + ucc_assert( + rdma_task->dpu_xgvmi.bufs->dst_ebuf->packed_key_len + <= ALLREDUCE_PACKED_KEY_MAX_LEN); + memcpy(allgather_data->packed_dst_key, + rdma_task->dpu_xgvmi.bufs->dst_ebuf->packed_key, + rdma_task->dpu_xgvmi.bufs->dst_ebuf->packed_key_len); + + UCC_CHECK_GOTO(ucc_tl_ucp_allgather_ring_start( + rdma_task->dpu_xgvmi.allgather_task), + out, status); + + return ucc_schedule_start(coll_task); + +out: + tl_error(UCC_TASK_LIB(rdma_task), "failed to start allgather sliding window: %s", + ucc_status_string(status)); + return status; +} + +ucc_status_t +ucc_tl_ucp_dpu_xgvmi_finalize(ucc_coll_task_t *coll_task) +{ + ucc_schedule_t *schedule = ucc_derived_of(coll_task, ucc_schedule_t); + ucc_status_t status; + + status = ucc_schedule_finalize(coll_task); + ucc_tl_ucp_put_schedule(schedule); + + return status; +} + +ucc_status_t +ucc_tl_ucp_dpu_xgvmi_rdma_task_post( + ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, + ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + + ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); + + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); +} + +void ucc_tl_ucp_dpu_xgvmi_free_rkeys( + ucc_coll_task_t *coll_task) +{ + ucc_base_team_t *team = coll_task->team; + ucc_rank_t team_size = (ucc_rank_t)team->params.size; + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + int inplace = UCC_IS_INPLACE(coll_task->bargs.args); + ucc_rank_t i; + + for (i = 0; i < team_size; i++) { + if (!inplace && task->dpu_xgvmi.bufs->src_rkeys[i] != NULL) { + ucp_rkey_destroy(task->dpu_xgvmi.bufs->src_rkeys[i]); + } + if (task->dpu_xgvmi.bufs->dst_rkeys[i] != NULL) { + ucp_rkey_destroy(task->dpu_xgvmi.bufs->dst_rkeys[i]); + } + } +} + +ucc_status_t +ucc_tl_ucp_dpu_xgvmi_rdma_task_finalize( + ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_status_t st = UCC_OK; + + ucc_tl_ucp_dpu_xgvmi_free_rkeys(coll_task); + ucc_tl_ucp_dpu_xgvmi_free_task(coll_task); + + st = ucc_tl_ucp_coll_finalize(coll_task); + + if (ucc_unlikely(st != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), "failed to finalize collective"); + } + + return st; +} + +ucc_status_t +ucc_tl_ucp_dpu_xgvmi_req_test(ucs_status_ptr_t request, + ucc_tl_ucp_task_t *task) +{ + if (request == NULL) { + return UCC_OK; + 
} else if (UCS_PTR_IS_ERR(request)) { + tl_error(UCC_TASK_LIB(task), "unable to complete UCX request=%p: %d", + request, UCS_PTR_STATUS(request)); + return ucs_status_to_ucc_status(UCS_PTR_STATUS(request)); + } else { + return ucs_status_to_ucc_status(ucp_request_check_status(request)); + } +} + +void ucc_tl_ucp_dpu_xgvmi_key_exchange_progress( + ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_coll_task_t *allgather_task = + task->dpu_xgvmi.allgather_task; + ucc_status_t status = allgather_task->super.status; + + if (status < 0) { + goto err; + } + if (UCC_INPROGRESS == status) { + ucc_tl_ucp_allgather_ring_progress(allgather_task); + return; + } + ucc_assert(status == UCC_OK); + + // copy from allgather recvbuf into xgvmi task + UCC_CHECK_GOTO( + ucc_tl_ucp_dpu_xgvmi_allgather_info_finalize(task), + err, status); + +out: + ucc_tl_ucp_coll_finalize(allgather_task); + task->dpu_xgvmi.allgather_task = NULL; + return; +err: + ucc_tl_ucp_dpu_xgvmi_free_task(coll_task); + tl_error(coll_task->team->context->lib, + "key exchange failure: %s", + ucc_status_string(status)); + goto out; +} + +ucc_status_t +ucc_tl_ucp_dpu_xgvmi_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h) +{ + ucc_schedule_t *schedule = NULL; + ucc_status_t status = UCC_OK; + ucc_tl_ucp_team_t *tl_team = + ucc_derived_of(team, ucc_tl_ucp_team_t); + size_t allgather_size = + sizeof(ucc_tl_ucp_allreduce_sw_host_allgather_t); + ucc_rank_t size = UCC_TL_TEAM_SIZE(tl_team); + ucc_base_coll_args_t bargs = { + .mask = 0, + .args = { + .coll_type = UCC_COLL_TYPE_ALLGATHER, + .mask = 0, + .src.info = {.buffer = NULL, + .count = allgather_size, + .datatype = UCC_DT_UINT8, + .mem_type = UCC_MEMORY_TYPE_HOST}, + .dst.info = {.buffer = NULL, + .count = allgather_size * size, + .datatype = UCC_DT_UINT8, + .mem_type = UCC_MEMORY_TYPE_HOST} + } + }; + ucc_base_coll_args_t barrier_coll_args = { + .team = team->params.team, + .args.coll_type = UCC_COLL_TYPE_BARRIER, + }; + ucc_tl_ucp_allreduce_sw_host_allgather_t *allgather_data; + ucc_tl_ucp_task_t *rdma_task; + ucc_coll_task_t *barrier_task; + + status = ucc_tl_ucp_get_schedule(tl_team, coll_args, + (ucc_tl_ucp_schedule_t **)&schedule); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + + *task_h = &schedule->super; + schedule->super.post = ucc_tl_ucp_dpu_xgvmi_start; + schedule->super.progress = NULL; + schedule->super.finalize = ucc_tl_ucp_dpu_xgvmi_finalize; + + schedule->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; + + rdma_task = ucc_tl_ucp_init_task(coll_args, team); + if (ucc_unlikely(!rdma_task)) { + tl_error(UCC_TL_TEAM_LIB(tl_team), "Couldnt allocate task"); + return UCC_ERR_NO_MEMORY; + } + + status = ucc_tl_ucp_dpu_xgvmi_task_init(coll_args, team, + rdma_task); + if (status != UCC_OK) { + tl_error(UCC_TL_TEAM_LIB(tl_team), "failed to init task: %s", + ucc_status_string(status)); + goto out; + } + + allgather_data = rdma_task->dpu_xgvmi.allgather_data; + bargs.args.src.info.buffer = allgather_data; + bargs.args.dst.info.buffer = PTR_OFFSET(allgather_data, allgather_size); + + rdma_task->super.post = ucc_tl_ucp_dpu_xgvmi_rdma_task_post; + rdma_task->super.finalize = ucc_tl_ucp_dpu_xgvmi_rdma_task_finalize; + + switch (coll_args->args.coll_type) { + case UCC_COLL_TYPE_ALLTOALL: + rdma_task->super.progress = ucc_tl_ucp_dpu_xgvmi_rdma_progress_alltoall; + break; + case UCC_COLL_TYPE_ALLGATHER: + rdma_task->super.progress = ucc_tl_ucp_dpu_xgvmi_rdma_progress_allgather; + 
break; + default: + tl_error(UCC_TL_TEAM_LIB(tl_team), "coll_type %s is not supported", + ucc_coll_type_str(coll_args->args.coll_type)); + break; + } + + rdma_task->dpu_xgvmi.requests = ucc_malloc(sizeof(ucs_status_ptr_t) * size); + + UCC_CHECK_GOTO(ucc_tl_ucp_allgather_ring_init(&bargs, team, + &rdma_task->dpu_xgvmi.allgather_task), + free_rdma_task, status); + + status = ucc_tl_ucp_coll_init(&barrier_coll_args, team, + &barrier_task); + if (status < 0) { + tl_error(UCC_TL_TEAM_LIB(tl_team), + "failure during sliding window barrier init: %s", + ucc_status_string(status)); + goto free_allgather_task; + } + + UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, &rdma_task->super), out, status); + UCC_CHECK_GOTO(ucc_event_manager_subscribe(&schedule->super, + UCC_EVENT_SCHEDULE_STARTED, + &rdma_task->super, + ucc_task_start_handler), + free_barrier_task, status); + + UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, barrier_task), out, status); + UCC_CHECK_GOTO(ucc_event_manager_subscribe(&rdma_task->super, + UCC_EVENT_COMPLETED, + barrier_task, + ucc_task_start_handler), + free_barrier_task, status); + + return status; + +free_barrier_task: + ucc_tl_ucp_coll_finalize(barrier_task); +free_allgather_task: + ucc_tl_ucp_coll_finalize(rdma_task->dpu_xgvmi.allgather_task); +free_rdma_task: + ucc_tl_ucp_dpu_xgvmi_free_task(&rdma_task->super); +out: + ucc_tl_ucp_put_schedule(schedule); + return status; +} diff --git a/src/components/tl/ucp/tl_ucp_dpu_offload.h b/src/components/tl/ucp/tl_ucp_dpu_offload.h index 8416331621..5d53b810f6 100644 --- a/src/components/tl/ucp/tl_ucp_dpu_offload.h +++ b/src/components/tl/ucp/tl_ucp_dpu_offload.h @@ -49,5 +49,22 @@ ucc_status_t ucc_tl_ucp_allreduce_sliding_window_register( ucp_context_h ucp_context, ucc_tl_ucp_team_t *tl_team, struct ucc_tl_ucp_allreduce_sw_export_buf *ebuf, void *packed_memh); +void ucc_tl_ucp_dpu_xgvmi_free_rkeys( + ucc_coll_task_t *coll_task); + +ucc_status_t +ucc_tl_ucp_dpu_xgvmi_rdma_task_finalize( + ucc_coll_task_t *coll_task); + +ucc_status_t +ucc_tl_ucp_dpu_xgvmi_req_test(ucs_status_ptr_t request, + ucc_tl_ucp_task_t *task); + +void ucc_tl_ucp_dpu_xgvmi_key_exchange_progress(ucc_coll_task_t *coll_task); + +ucc_status_t +ucc_tl_ucp_dpu_xgvmi_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h); #endif
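
For reviewers: a minimal standalone sketch of the get addressing implemented by the new ucc_tl_ucp_dpu_xgvmi_rdma_progress_alltoall() path. The constants below are hypothetical and no UCC/UCX calls are made; each rank pulls the block destined for it (remote offset rank * data_size) out of every peer's exported source buffer and lands it in that peer's slot of its own destination buffer (local offset peer * data_size), with data_size = (count * dt_size) / team_size, visiting peers in the same (i + rank) % team_size rotation as the progress loop.

/*
 * Illustration only -- mirrors the index arithmetic of
 * ucc_tl_ucp_dpu_xgvmi_rdma_progress_alltoall(); TEAM_SIZE, COUNT and
 * DT_SIZE are made-up values, and no UCC/UCX APIs are used.
 */
#include <stdio.h>
#include <stddef.h>

#define TEAM_SIZE 4   /* host_team_size                 */
#define COUNT     16  /* TASK_ARGS(task).src.info.count */
#define DT_SIZE   8   /* ucc_dt_size(dtype)             */

int main(void)
{
    /* alltoall: each peer contributes one 1/TEAM_SIZE block per rank */
    size_t data_size = ((size_t)COUNT * DT_SIZE) / TEAM_SIZE;
    int    rank, i;

    for (rank = 0; rank < TEAM_SIZE; rank++) {
        for (i = 0; i < TEAM_SIZE; i++) {
            /* same rotation as the progress loop */
            int    peer       = (i + rank) % TEAM_SIZE;
            size_t remote_off = (size_t)rank * data_size; /* block destined for me */
            size_t local_off  = (size_t)peer * data_size; /* this peer's slot      */

            printf("rank %d: get %zu bytes from peer %d (sbuf+%zu) into rbuf+%zu\n",
                   rank, data_size, peer, remote_off, local_off);
        }
    }
    return 0;
}

The allgather variant in allgather_xgvmi.c keeps the same rotation but pulls each peer's entire exported source buffer (remote offset 0, data_size = count * dt_size) into that peer's slot. With the new algorithm table entry, the path should be selectable like the other TL/UCP algorithms, e.g. UCC_TL_UCP_TUNE=alltoall:@xgvmi (assuming the usual coll:@alg selection syntax).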