diff --git a/.github/workflows/codestyle.yaml b/.github/workflows/codestyle.yaml
index 7f5abd443a..7f56636c59 100644
--- a/.github/workflows/codestyle.yaml
+++ b/.github/workflows/codestyle.yaml
@@ -37,7 +37,7 @@ jobs:
           fi
         fi
         H1="CODESTYLE|REVIEW|CORE|UTIL|TEST|API|DOCS|TOOLS|BUILD|MC|EC|SCHEDULE|TOPO"
-        H2="CI|CL/|TL/|MC/|EC/|UCP|SHM|NCCL|SHARP|BASIC|HIER|CUDA|CPU|EE|RCCL|ROCM|SELF|MLX5"
+        H2="CI|CL/|TL/|MC/|EC/|UCP|SHM|NCCL|SHARP|BASIC|HIER|DOCA_UROM|CUDA|CPU|EE|RCCL|ROCM|SELF|MLX5"
         if ! echo $msg | grep -qP '^Merge |^'"(($H1)|($H2))"'+: \w'
         then
           echo "Wrong header"
diff --git a/Makefile.am b/Makefile.am
index 5b2643c4f0..11e305342c 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -7,6 +7,7 @@ if !DOCS_ONLY
 SUBDIRS = \
     src \
+    contrib \
     tools/info \
     cmake
diff --git a/config/m4/doca_urom.m4 b/config/m4/doca_urom.m4
new file mode 100644
index 0000000000..8295c946c0
--- /dev/null
+++ b/config/m4/doca_urom.m4
@@ -0,0 +1,75 @@
+#
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# See file LICENSE for terms.
+#
+
+AC_DEFUN([CHECK_DOCA_UROM],[
+AS_IF([test "x$doca_urom_checked" != "xyes"],[
+    doca_urom_happy="no"
+    AC_ARG_WITH([doca_urom],
+        [AS_HELP_STRING([--with-doca_urom=(DIR)], [Enable the use of DOCA_UROM (default is guess).])],
+        [], [with_doca_urom=guess])
+    AS_IF([test "x$with_doca_urom" != "xno"],
+    [
+        save_CPPFLAGS="$CPPFLAGS"
+        save_LDFLAGS="$LDFLAGS"
+        AS_IF([test ! -z "$with_doca_urom" -a "x$with_doca_urom" != "xyes" -a "x$with_doca_urom" != "xguess"],
+        [
+            AS_IF([test ! -d $with_doca_urom],
+                  [AC_MSG_ERROR([Provided "--with-doca_urom=${with_doca_urom}" location does not exist])])
+            check_doca_urom_dir="$with_doca_urom"
+            check_doca_urom_libdir="$with_doca_urom/lib64"
+            CPPFLAGS="-I$with_doca_urom/include $UCS_CPPFLAGS $save_CPPFLAGS"
+            LDFLAGS="-L$check_doca_urom_libdir $save_LDFLAGS"
+        ])
+        AS_IF([test ! -z "$with_doca_urom_libdir" -a "x$with_doca_urom_libdir" != "xyes"],
+        [
+            check_doca_urom_libdir="$with_doca_urom_libdir"
+            LDFLAGS="-L$check_doca_urom_libdir $save_LDFLAGS"
+        ])
+        AC_CHECK_HEADERS([doca_urom.h],
+        [
+            AC_CHECK_LIB([doca_urom], [doca_urom_service_create],
+            [
+                doca_urom_happy="yes"
+            ],
+            [
+                echo "CPPFLAGS: $CPPFLAGS"
+                doca_urom_happy="no"
+            ], [-ldoca_common -ldoca_argp -ldoca_urom])
+        ],
+        [
+            doca_urom_happy="no"
+        ])
+        AS_IF([test "x$doca_urom_happy" = "xyes"],
+        [
+            AS_IF([test "x$check_doca_urom_dir" != "x"],
+            [
+                AC_MSG_RESULT([DOCA_UROM dir: $check_doca_urom_dir])
+                AC_SUBST(DOCA_UROM_CPPFLAGS, "-I$check_doca_urom_dir/include/ $doca_urom_old_headers")
+            ])
+            AS_IF([test "x$check_doca_urom_libdir" != "x"],
+            [
+                AC_SUBST(DOCA_UROM_LDFLAGS, "-L$check_doca_urom_libdir")
+            ])
+            AC_SUBST(DOCA_UROM_LIBADD, "-ldoca_common -ldoca_argp -ldoca_urom")
+            AC_DEFINE([HAVE_DOCA_UROM], 1, [Enable DOCA_UROM support])
+        ],
+        [
+            AS_IF([test "x$with_doca_urom" != "xguess"],
+            [
+                AC_MSG_ERROR([DOCA_UROM support is requested but DOCA_UROM packages cannot be found! $CPPFLAGS $LDFLAGS])
+            ],
+            [
+                AC_MSG_WARN([DOCA_UROM not found])
+            ])
+        ])
+        CPPFLAGS="$save_CPPFLAGS"
+        LDFLAGS="$save_LDFLAGS"
+    ],
+    [
+        AC_MSG_WARN([DOCA_UROM was explicitly disabled])
+    ])
+    doca_urom_checked=yes
+    AM_CONDITIONAL([HAVE_DOCA_UROM], [test "x$doca_urom_happy" != xno])
+])])
diff --git a/configure.ac b/configure.ac
index 246e297c90..e2416d9281 100644
--- a/configure.ac
+++ b/configure.ac
@@ -162,6 +162,7 @@ AS_IF([test "x$with_docs_only" = xyes],
     AM_CONDITIONAL([HAVE_IBVERBS],[false])
     AM_CONDITIONAL([HAVE_RDMACM],[false])
     AM_CONDITIONAL([HAVE_MLX5DV],[false])
+    AM_CONDITIONAL([HAVE_DOCA_UROM], [false])
 ],
 [
     AM_CONDITIONAL([DOCS_ONLY], [false])
@@ -172,6 +173,7 @@ AS_IF([test "x$with_docs_only" = xyes],
     m4_include([config/m4/cuda.m4])
     m4_include([config/m4/nccl.m4])
     m4_include([config/m4/rocm.m4])
+    m4_include([config/m4/doca_urom.m4])
     m4_include([config/m4/rccl.m4])
     m4_include([config/m4/sharp.m4])
     m4_include([config/m4/mpi.m4])
@@ -205,6 +207,9 @@ AS_IF([test "x$with_docs_only" = xyes],
         mc_modules="${mc_modules}:rocm"
     fi
 
+    CHECK_DOCA_UROM
+    AC_MSG_RESULT([DOCA_UROM support: $doca_urom_happy])
+
     CHECK_GTEST
     AC_MSG_RESULT([GTEST support: $gtest_happy])
 
@@ -224,11 +229,13 @@ LDFLAGS="$LDFLAGS $UCS_LDFLAGS $UCS_LIBADD"
 CHECK_TL_COLL_PLUGINS
 AC_CONFIG_FILES([
     Makefile
+    contrib/Makefile
     src/Makefile
     src/ucc/api/ucc_version.h
     src/core/ucc_version.c
     src/components/cl/basic/Makefile
     src/components/cl/hier/Makefile
+    src/components/cl/doca_urom/Makefile
     src/components/mc/cpu/Makefile
     src/components/mc/cuda/Makefile
     src/components/ec/cpu/Makefile
@@ -265,6 +272,7 @@ AC_MSG_NOTICE([ C++ compiler: ${CXX} ${CXXFLAGS} ${BASE_CXXFLAGS}])
 AS_IF([test "x$cuda_happy" = "xyes"],[
     AC_MSG_NOTICE([ NVCC gencodes: ${NVCC_ARCH}])
 ])
+AC_MSG_NOTICE([ DOCA UROM enabled: ${doca_urom_happy}])
 AC_MSG_NOTICE([ Perftest: ${mpi_enable}])
 AC_MSG_NOTICE([ Gtest: ${gtest_enable}])
 AC_MSG_NOTICE([ MC modules: <$(echo ${mc_modules}|tr ':' ' ') >])
diff --git a/contrib/Makefile.am b/contrib/Makefile.am
new file mode 100644
index 0000000000..1e41aa9be1
--- /dev/null
+++ b/contrib/Makefile.am
@@ -0,0 +1,22 @@
+#
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+
+if HAVE_DOCA_UROM
+
+sources = \
+    doca_urom_ucc_plugin/common/urom_ucc.h \
+    doca_urom_ucc_plugin/dpu/worker_ucc_p2p.c \
+    doca_urom_ucc_plugin/dpu/worker_ucc.h \
+    doca_urom_ucc_plugin/dpu/worker_ucc.c
+
+plugindir = $(moduledir)/doca_plugins
+
+plugin_LTLIBRARIES = libucc_doca_urom_plugin.la
+libucc_doca_urom_plugin_la_SOURCES = $(sources)
+libucc_doca_urom_plugin_la_CPPFLAGS = $(AM_CPPFLAGS) $(BASE_CPPFLAGS) $(UCX_CPPFLAGS) $(DOCA_UROM_CPPFLAGS)
+libucc_doca_urom_plugin_la_CFLAGS = $(BASE_CFLAGS)
+libucc_doca_urom_plugin_la_LDFLAGS = -version-info $(SOVERSION) --as-needed $(DOCA_UROM_LDFLAGS)
+libucc_doca_urom_plugin_la_LIBADD = $(UCX_LIBADD) $(DOCA_UROM_LIBADD) $(UCC_TOP_BUILDDIR)/src/libucc.la
+
+endif
diff --git a/contrib/doca_urom_ucc_plugin/common/urom_ucc.h b/contrib/doca_urom_ucc_plugin/common/urom_ucc.h
new file mode 100644
index 0000000000..7568a84f4e
--- /dev/null
+++ b/contrib/doca_urom_ucc_plugin/common/urom_ucc.h
@@ -0,0 +1,171 @@
+/*
+ * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED.
+ *
+ * This software product is a proprietary product of NVIDIA CORPORATION &
+ * AFFILIATES (the "Company") and all right, title, and interest in and to the
+ * software product, including all associated intellectual property rights, are
+ * and shall remain exclusively with the Company.
+ *
+ * This software product is governed by the End User License Agreement
+ * provided with the software product.
+ *
+ */
+
+#ifndef UROM_UCC_H_
+#define UROM_UCC_H_
+
+#include
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* UCC serialization helper: returns the current *_iter position and advances it by _offset */
+#define urom_ucc_serialize_next_raw(_iter, _type, _offset) \
+        ({ \
+                _type *_result = (_type *)(*(_iter)); \
+                *(_iter) = UCS_PTR_BYTE_OFFSET(*(_iter), _offset); \
+                _result; \
+        })
+
+/* UCC command types */
+enum urom_worker_ucc_cmd_type {
+        UROM_WORKER_CMD_UCC_LIB_CREATE,                  /* UCC library create command */
+        UROM_WORKER_CMD_UCC_LIB_DESTROY,                 /* UCC library destroy command */
+        UROM_WORKER_CMD_UCC_CONTEXT_CREATE,              /* UCC context create command */
+        UROM_WORKER_CMD_UCC_CONTEXT_DESTROY,             /* UCC context destroy command */
+        UROM_WORKER_CMD_UCC_TEAM_CREATE,                 /* UCC team create command */
+        UROM_WORKER_CMD_UCC_COLL,                        /* UCC collective create command */
+        UROM_WORKER_CMD_UCC_CREATE_PASSIVE_DATA_CHANNEL, /* UCC passive data channel command */
+};
+
+/*
+ * UCC library create command structure
+ *
+ * Input parameters for creating the library handle. The semantics of the parameters are defined by ucc.h.
+ * On successful completion of urom_worker_cmd_ucc_lib_create, the UROM worker generates a notification on
+ * the notification queue. The notification carries a reference to the library handle that is local to the
+ * worker. The implementation can choose to create shadow handles or safely pack the library handle on the
+ * BlueCC worker to the AEU.
+ */
+struct urom_worker_cmd_ucc_lib_create {
+        void *params; /* UCC library parameters */
+};
+
+/* UCC context create command structure */
+struct urom_worker_cmd_ucc_context_create {
+        union {
+                int64_t start;  /* The start index */
+                int64_t *array; /* Set stride to <= 0 if array is used */
+        };
+        int64_t stride; /* Stride between participant ids (<= 0 means the id array is used) */
+        int64_t size;   /* Number of participants */
+        void *base_va;  /* Shared buffer address */
+        uint64_t len;   /* Buffer length */
+};
+
+/* UCC passive data channel command structure */
+struct urom_worker_cmd_ucc_pass_dc {
+        void *ucp_addr;  /* UCP worker address on host */
+        size_t addr_len; /* UCP worker address length */
+};
+
+/* UCC context destroy command structure */
+struct urom_worker_cmd_ucc_context_destroy {
+        void *context_h; /* UCC context pointer */
+};
+
+/* UCC team create command structure */
+struct urom_worker_cmd_ucc_team_create {
+        int64_t start;   /* Team start index */
+        int64_t stride;  /* Stride between team member ids */
+        int64_t size;    /* Team size */
+        void *context_h; /* UCC context */
+};
+
+/* UCC team destroy command structure */
+struct urom_worker_cmd_ucc_team_destroy {
+        void *team; /* UCC team to destroy */
+};
+
+/* UCC collective command structure */
+struct urom_worker_cmd_ucc_coll {
+        void *coll_args;         /* Collective arguments */
+        void *team;              /* UCC team */
+        int use_xgvmi;           /* If operation uses XGVMI */
+        void *work_buffer;       /* Work buffer */
+        size_t work_buffer_size; /* Buffer size */
+        size_t team_size;        /* Team size */
+};
+
+/* UROM UCC worker command structure */
+struct urom_worker_ucc_cmd {
+        uint64_t cmd_type;      /* Type of command as defined by urom_worker_ucc_cmd_type */
+        uint64_t dpu_worker_id; /* DPU worker id as part of the team */
+        union {
+                struct urom_worker_cmd_ucc_lib_create lib_create_cmd;           /* Lib create command */
+                struct urom_worker_cmd_ucc_context_create context_create_cmd;   /* Context create command */
+                struct urom_worker_cmd_ucc_context_destroy context_destroy_cmd; /* Context destroy command */
+                struct urom_worker_cmd_ucc_team_create team_create_cmd;         /* Team create command */
+                struct urom_worker_cmd_ucc_team_destroy team_destroy_cmd;       /* Team destroy command */
+                struct urom_worker_cmd_ucc_coll coll_cmd;                       /* UCC collective command */
+                struct urom_worker_cmd_ucc_pass_dc pass_dc_create_cmd;          /* Passive data channel command */
+        };
+};
+
+/* UCC notification types */
+enum urom_worker_ucc_notify_type {
+        UROM_WORKER_NOTIFY_UCC_LIB_CREATE_COMPLETE,           /* Create UCC library on DPU notification */
+        UROM_WORKER_NOTIFY_UCC_LIB_DESTROY_COMPLETE,          /* Destroy UCC library on DPU notification */
+        UROM_WORKER_NOTIFY_UCC_CONTEXT_CREATE_COMPLETE,       /* Create UCC context on DPU notification */
+        UROM_WORKER_NOTIFY_UCC_CONTEXT_DESTROY_COMPLETE,      /* Destroy UCC context on DPU notification */
+        UROM_WORKER_NOTIFY_UCC_TEAM_CREATE_COMPLETE,          /* Create UCC team on DPU notification */
+        UROM_WORKER_NOTIFY_UCC_COLLECTIVE_COMPLETE,           /* UCC collective completion notification */
+        UROM_WORKER_NOTIFY_UCC_PASSIVE_DATA_CHANNEL_COMPLETE, /* UCC data channel completion notification */
+};
+
+/* UCC context create notification structure */
+struct urom_worker_ucc_notify_context_create {
+        void *context; /* Pointer to UCC context */
+};
+
+/* UCC team create notification structure */
+struct urom_worker_ucc_notify_team_create {
+        void *team; /* Pointer to UCC team */
+};
+
+/* UCC collective notification structure */
+struct urom_worker_ucc_notify_collective {
+        ucc_status_t status; /* UCC collective status */
+};
+
+/* UCC passive data channel notification structure */
+struct urom_worker_ucc_notify_pass_dc {
+        ucc_status_t status; /* UCC data channel status */
+};
+
+/* UROM UCC worker notification structure */
+struct urom_worker_notify_ucc {
+        uint64_t notify_type;   /* Notify type as defined by urom_worker_ucc_notify_type */
+        uint64_t dpu_worker_id; /* DPU worker id */
+        union {
+                struct urom_worker_ucc_notify_context_create context_create_nqe; /* Context create notification */
+                struct urom_worker_ucc_notify_team_create team_create_nqe;       /* Team create notification */
+                struct urom_worker_ucc_notify_collective coll_nqe;               /* Collective notification */
+                struct urom_worker_ucc_notify_pass_dc pass_dc_nqe;               /* Passive data channel notification */
+        };
+};
+
+typedef struct ucc_worker_key_buf {
+        size_t src_len;
+        size_t dst_len;
+        char rkeys[1024];
+} ucc_worker_key_buf;
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
+
+#endif /* UROM_UCC_H_ */
diff --git a/contrib/doca_urom_ucc_plugin/dpu/worker_ucc.c b/contrib/doca_urom_ucc_plugin/dpu/worker_ucc.c
new file mode 100644
index 0000000000..321be660f4
--- /dev/null
+++ b/contrib/doca_urom_ucc_plugin/dpu/worker_ucc.c
@@ -0,0 +1,2180 @@
+/*
+ * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED.
+ *
+ * This software product is a proprietary product of NVIDIA CORPORATION &
+ * AFFILIATES (the "Company") and all right, title, and interest in and to the
+ * software product, including all associated intellectual property rights, are
+ * and shall remain exclusively with the Company.
+ *
+ * This software product is governed by the End User License Agreement
+ * provided with the software product.
+ * + */ + +#define _GNU_SOURCE + +#include +#include +#include +#include + +#include +#include + +#include + +#include "worker_ucc.h" +#include "../common/urom_ucc.h" + +DOCA_LOG_REGISTER(UROM::WORKER::UCC); + +static uint64_t plugin_version = 0x01; /* UCC plugin DPU version */ +static volatile uint64_t *queue_front; /* Front queue node */ +static volatile uint64_t *queue_tail; /* Tail queue node */ +static volatile uint64_t *queue_size; /* Queue size */ +static int ucc_component_enabled; /* Shared variable between UCC worker threads */ +static pthread_t context_progress_thread; /* UCC progress thread context */ +static uint64_t queue_lock = 0; /* Threads queue lock */ +static pthread_t *progress_thread = NULL; /* Progress threads array */ + +/* UCC opts structure */ +struct worker_ucc_opts worker_ucc_opts = { + .num_progress_threads = 1, + .ppw = 32, + .tpp = 1, + .list_size = 64, + .num_psync = 128, + .dpu_worker_binding_stride = 1, +}; + +/* Progress thread arguments structure */ +struct thread_args { + uint64_t thread_id; /* Progress thread id */ + struct urom_worker_ucc *ucc_worker; /* UCC worker context */ +}; + +// determine number of cores by counting the number of lines containing +// "processor" in /proc/cpuinfo +int get_ncores() +{ + FILE *fptr; + char str[100]; + char *pos; + int index, count = 0; + static int core_count = 0; + + // just read the file once and return the stored value on subsequent calls + if (core_count != 0) { + return core_count; + } + + fptr = fopen("/proc/cpuinfo", "rb"); + + if (fptr == NULL) { + printf("Failed to open /proc/cpuinfo\n"); + exit(EXIT_FAILURE); + } + + while ((fgets(str, 100, fptr)) != NULL) { + index = 0; + while ((pos = strstr(str + index, "processor")) != NULL) { + index = (pos - str) + 1; + count++; + } + } + + fclose(fptr); + core_count = count; + return count; +} + +void dpu_thread_set_affinity_specific_core(int core_id) +{ + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + + if (core_id >=0 && core_id < get_ncores()) { + CPU_SET(core_id, &cpuset); + pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset); + } else { + printf("bad core id: %d\n", core_id); + exit(-1); + } +} + +void dpu_thread_set_affinity(int thread_id) +{ + int coreid = thread_id; + int do_stride = worker_ucc_opts.dpu_worker_binding_stride; + int num_threads = worker_ucc_opts.num_progress_threads; + int num_cores; + int stride; + cpu_set_t cpuset; + + num_cores = get_ncores(); + stride = num_cores/num_threads; + + CPU_ZERO(&cpuset); + + if(do_stride) { + stride = do_stride; + if (num_threads % 2 != 0) { + stride = 1; + } + coreid *= stride; + } + + if (coreid >=0 && coreid < num_cores) { + CPU_SET(coreid, &cpuset); + pthread_setaffinity_np(progress_thread[thread_id], + sizeof(cpuset), &cpuset); + } +} + +/* + * Find available queue element + * + * @ctx_id [in]: UCC context id + * @ucc_worker [in]: UCC command descriptor + * @ret_qe [out]: set available queue element + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t find_qe_slot(uint64_t ctx_id, struct urom_worker_ucc *ucc_worker, struct ucc_queue_element **ret_qe) +{ + int thread_id = ctx_id % worker_ucc_opts.num_progress_threads; + uint64_t next = (queue_tail[thread_id] + 1) % worker_ucc_opts.list_size; + int curr = queue_tail[thread_id]; + + if (next == queue_front[thread_id]) { + *ret_qe = NULL; + return DOCA_ERROR_FULL; + } + + *ret_qe = &ucc_worker->queue[thread_id][curr]; + if ((*ret_qe)->in_use != 0) { + *ret_qe = NULL; + return DOCA_ERROR_BAD_STATE; + } + 
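+        /* Slot is free: advance the tail so subsequent lookups do not hand out the same element. */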
queue_tail[thread_id] = next; + return DOCA_SUCCESS; +} + +/* + * Open UCC worker plugin + * + * @ctx [in]: DOCA UROM worker context + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_open(struct urom_worker_ctx *ctx) +{ + uint64_t i, j; + doca_error_t result; + ucs_status_t status; + ucp_params_t ucp_params; + ucp_config_t *ucp_config; + ucp_worker_params_t worker_params; + struct urom_worker_ucc *ucc_worker; + + dpu_thread_set_affinity_specific_core(get_ncores() - 1); + + if (ctx == NULL) + return DOCA_ERROR_INVALID_VALUE; + + ucc_worker = calloc(1, sizeof(*ucc_worker)); + if (ucc_worker == NULL) { + DOCA_LOG_ERR("Failed to allocate UCC worker context"); + return DOCA_ERROR_NO_MEMORY; + } + + if (worker_ucc_opts.num_progress_threads < MIN_THREADS) { + worker_ucc_opts.num_progress_threads = MIN_THREADS; + DOCA_LOG_WARN("Number of threads for UCC Offload must be 1 or more, set to 1"); + } + + ucc_worker->ctx_id = 0; + ucc_worker->nr_connections = 0; + ucc_worker->ucc_data = calloc(worker_ucc_opts.ppw * worker_ucc_opts.tpp, sizeof(struct ucc_data)); + if (ucc_worker->ucc_data == NULL) { + DOCA_LOG_ERR("Failed to allocate UCC worker context"); + result = DOCA_ERROR_NO_MEMORY; + goto ucc_free; + } + + ucc_worker->queue = (struct ucc_queue_element **)malloc(sizeof(struct ucc_queue_element *) * + worker_ucc_opts.num_progress_threads); + if (ucc_worker->queue == NULL) { + DOCA_LOG_ERR("Failed to allocate UCC elements queue"); + result = DOCA_ERROR_NO_MEMORY; + goto ucc_data_free; + } + for (i = 0; i < worker_ucc_opts.num_progress_threads; i++) { + ucc_worker->queue[i] = calloc(worker_ucc_opts.list_size, sizeof(struct ucc_queue_element)); + if (ucc_worker->queue[i] == NULL) { + DOCA_LOG_ERR("Failed to allocate queue elements"); + result = DOCA_ERROR_NO_MEMORY; + goto queue_free; + } + } + + queue_front = (volatile uint64_t *)calloc(worker_ucc_opts.num_progress_threads, sizeof(uint64_t)); + if (queue_front == NULL) { + result = DOCA_ERROR_NO_MEMORY; + goto queue_free; + } + + queue_tail = (volatile uint64_t *)calloc(worker_ucc_opts.num_progress_threads, sizeof(uint64_t)); + if (queue_tail == NULL) { + result = DOCA_ERROR_NO_MEMORY; + goto queue_front_free; + } + + queue_size = (volatile uint64_t *)calloc(worker_ucc_opts.num_progress_threads, sizeof(uint64_t)); + if (queue_size == NULL) { + result = DOCA_ERROR_NO_MEMORY; + goto queue_tail_free; + } + + status = ucp_config_read(NULL, NULL, &ucp_config); + if (status != UCS_OK) { + DOCA_LOG_ERR("Failed to read UCP config"); + goto queue_size_free; + } + + status = ucp_config_modify(ucp_config, "PROTO_ENABLE", "y"); + if (status != UCS_OK) { + DOCA_LOG_ERR("Failed to read UCP config"); + ucp_config_release(ucp_config); + goto queue_size_free; + } + + ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES; + ucp_params.features = UCP_FEATURE_TAG | UCP_FEATURE_RMA | UCP_FEATURE_AMO64 | UCP_FEATURE_EXPORTED_MEMH; + status = ucp_init(&ucp_params, ucp_config, &ucc_worker->ucp_data.ucp_context); + ucp_config_release(ucp_config); + if (status != UCS_OK) { + DOCA_LOG_ERR("Failed to initialized UCP"); + result = DOCA_ERROR_DRIVER; + goto queue_size_free; + } + + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_MULTI; + status = ucp_worker_create(ucc_worker->ucp_data.ucp_context, &worker_params, &ucc_worker->ucp_data.ucp_worker); + if (status != UCS_OK) { + DOCA_LOG_ERR("Unable to create ucp worker"); + result = DOCA_ERROR_DRIVER; + goto ucp_cleanup; + } + 
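+        /* Per-destination lookup tables: UCP endpoints, memory handles, remote keys, and host context ids. */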
+ ucc_worker->ucp_data.eps = kh_init(ep); + if (ucc_worker->ucp_data.eps == NULL) { + DOCA_LOG_ERR("Failed to init EP hashtable map"); + result = DOCA_ERROR_DRIVER; + goto worker_destroy; + } + + ucc_worker->ucp_data.memh = kh_init(memh); + if (ucc_worker->ucp_data.memh == NULL) { + DOCA_LOG_ERR("Failed to init memh hashtable map"); + result = DOCA_ERROR_DRIVER; + goto eps_destroy; + } + + ucc_worker->ucp_data.rkeys = kh_init(rkeys); + if (ucc_worker->ucp_data.rkeys == NULL) { + DOCA_LOG_ERR("Failed to init rkeys hashtable map"); + result = DOCA_ERROR_DRIVER; + goto memh_destroy; + } + + ucc_worker->ids = kh_init(ctx_id); + if (ucc_worker->ids == NULL) { + DOCA_LOG_ERR("Failed to init ids hashtable map"); + result = DOCA_ERROR_DRIVER; + goto rkeys_destroy; + } + + ucc_worker->super = ctx; + ucc_worker->list_lock = 0; + ucc_component_enabled = 1; + ucs_list_head_init(&ucc_worker->completed_reqs); + + ctx->plugin_ctx = ucc_worker; + DOCA_LOG_INFO("UCC worker open flow is done"); + return DOCA_SUCCESS; + +rkeys_destroy: + kh_destroy(rkeys, ucc_worker->ucp_data.rkeys); +memh_destroy: + kh_destroy(memh, ucc_worker->ucp_data.memh); +eps_destroy: + kh_destroy(ep, ucc_worker->ucp_data.eps); +worker_destroy: + ucp_worker_destroy(ucc_worker->ucp_data.ucp_worker); +ucp_cleanup: + ucp_cleanup(ucc_worker->ucp_data.ucp_context); +queue_size_free: + free((void *)queue_size); +queue_tail_free: + free((void *)queue_tail); +queue_front_free: + free((void *)queue_front); +queue_free: + for (j = 0; j < i; j++) + free(ucc_worker->queue[j]); + free(ucc_worker->queue); +ucc_data_free: + free(ucc_worker->ucc_data); +ucc_free: + free(ucc_worker); + return result; +} + +static void ucc_worker_join_and_free_threads() +{ + uint64_t i; + if (progress_thread) { + for (i = 0; i < worker_ucc_opts.num_progress_threads; i++) { + pthread_join(progress_thread[i], NULL); + } + free(progress_thread); + progress_thread = NULL; + } +} + +/* + * Close UCC worker plugin + * + * @worker_ctx [in]: DOCA UROM worker context + */ +static void urom_worker_ucc_close(struct urom_worker_ctx *worker_ctx) +{ + uint64_t i; + struct urom_worker_ucc *ucc_worker = worker_ctx->plugin_ctx; + + if (worker_ctx == NULL) + return; + + ucc_component_enabled = 0; + + ucc_worker_join_and_free_threads(); + + /* Destroy hash tables */ + kh_destroy(rkeys, ucc_worker->ucp_data.rkeys); + kh_destroy(memh, ucc_worker->ucp_data.memh); + kh_destroy(ep, ucc_worker->ucp_data.eps); + kh_destroy(ctx_id, ucc_worker->ids); + + /* UCP cleanup */ + ucp_worker_destroy(ucc_worker->ucp_data.ucp_worker); + ucp_cleanup(ucc_worker->ucp_data.ucp_context); + + /* UCC worker resources destroy */ + free((void *)queue_size); + free((void *)queue_tail); + free((void *)queue_front); + free(ucc_worker->ucc_data); + + /* Queue elements destroy */ + for (i = 0; i < worker_ucc_opts.num_progress_threads; i++) + free(ucc_worker->queue[i]); + + free(ucc_worker->queue); + + /* UCC worker destroy */ + free(ucc_worker); +} + +/* + * Unpacking UCC worker command + * + * @packed_cmd [in]: packed worker command + * @packed_cmd_len [in]: packed worker command length + * @cmd [out]: set unpacked UROM worker command + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_cmd_unpack(void *packed_cmd, size_t packed_cmd_len, struct urom_worker_cmd **cmd) +{ + void *ptr; + uint64_t extended_mem = 0; + ucc_coll_args_t *coll_args; + int is_count_64, is_disp_64; + struct urom_worker_ucc_cmd *ucc_cmd; + size_t team_size, count_pack_size, disp_pack_size; + + if 
(packed_cmd_len < sizeof(struct urom_worker_ucc_cmd)) { + DOCA_LOG_INFO("Invalid packed command length"); + return DOCA_ERROR_INVALID_VALUE; + } + + *cmd = packed_cmd; + ptr = packed_cmd + ucs_offsetof(struct urom_worker_cmd, plugin_cmd) + sizeof(struct urom_worker_ucc_cmd); + ucc_cmd = (struct urom_worker_ucc_cmd *)(*cmd)->plugin_cmd; + + switch (ucc_cmd->cmd_type) { + case UROM_WORKER_CMD_UCC_LIB_CREATE: + ucc_cmd->lib_create_cmd.params = ptr; + extended_mem += sizeof(ucc_lib_params_t); + break; + case UROM_WORKER_CMD_UCC_COLL: + coll_args = ptr; + ucc_cmd->coll_cmd.coll_args = ptr; + ptr += sizeof(ucc_coll_args_t); + extended_mem += sizeof(ucc_coll_args_t); + if (ucc_cmd->coll_cmd.work_buffer_size > 0) { + ucc_cmd->coll_cmd.work_buffer = ptr; + ptr += ucc_cmd->coll_cmd.work_buffer_size; + extended_mem += ucc_cmd->coll_cmd.work_buffer_size; + } + if (coll_args->coll_type == UCC_COLL_TYPE_ALLTOALLV || + coll_args->coll_type == UCC_COLL_TYPE_ALLGATHERV || coll_args->coll_type == UCC_COLL_TYPE_GATHERV || + coll_args->coll_type == UCC_COLL_TYPE_REDUCE_SCATTERV || + coll_args->coll_type == UCC_COLL_TYPE_SCATTERV) { + team_size = ucc_cmd->coll_cmd.team_size; + is_count_64 = ((coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS) && + (coll_args->flags & UCC_COLL_ARGS_FLAG_COUNT_64BIT)); + is_disp_64 = ((coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS) && + (coll_args->flags & UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT)); + + count_pack_size = ((is_count_64) ? sizeof(uint64_t) : sizeof(uint32_t)) * team_size; + disp_pack_size = ((is_disp_64) ? sizeof(uint64_t) : sizeof(uint32_t)) * team_size; + + coll_args->src.info_v.counts = ptr; + ptr += count_pack_size; + extended_mem += count_pack_size; + coll_args->dst.info_v.counts = ptr; + ptr += count_pack_size; + extended_mem += count_pack_size; + + coll_args->src.info_v.displacements = ptr; + ptr += disp_pack_size; + extended_mem += disp_pack_size; + coll_args->dst.info_v.displacements = ptr; + ptr += disp_pack_size; + extended_mem += disp_pack_size; + } + break; + + case UROM_WORKER_CMD_UCC_CREATE_PASSIVE_DATA_CHANNEL: + ucc_cmd->pass_dc_create_cmd.ucp_addr = ptr; + extended_mem += ucc_cmd->pass_dc_create_cmd.addr_len; + break; + } + + if ((*cmd)->len != extended_mem + sizeof(struct urom_worker_ucc_cmd)) { + DOCA_LOG_ERR("Invalid UCC command length"); + return DOCA_ERROR_INVALID_VALUE; + } + + return DOCA_SUCCESS; +} + +/* + * UCC worker safe push notification function + * + * @ucc_worker [in]: UCC worker context + * @nd [in]: UROM worker notification descriptor + */ +static void ucc_worker_safe_push_notification(struct urom_worker_ucc *ucc_worker, struct urom_worker_notif_desc *nd) +{ + uint64_t lvalue = 0; + + lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 0, 1); + while (lvalue != 0) + lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 0, 1); + + ucs_list_add_tail(&ucc_worker->completed_reqs, &nd->entry); + + lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 1, 0); +} + +/* + * UCC worker host destination remove + * + * @ucc_worker [in]: UCC worker context + * @dest_id [in]: Host client dest id + */ +static void worker_ucc_dest_remove(struct urom_worker_ucc *ucc_worker, uint64_t dest_id) +{ + khint_t k; + + k = kh_get(ctx_id, ucc_worker->ids, dest_id); + if (k == kh_end(ucc_worker->ids)) { + DOCA_LOG_ERR("Destination id - %lu does not exist", dest_id); + return; + } + kh_del(ctx_id, ucc_worker->ids, k); + ucc_worker->ctx_id--; +} + +/* + * UCC worker host destinations lookup function + * + * @ucc_worker [in]: UCC worker context + * @dest_id [in]: Host 
client dest id + * @ctx_id [out]: Host client context id + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t worker_ucc_dest_lookup(struct urom_worker_ucc *ucc_worker, uint64_t dest_id, uint64_t *ctx_id) +{ + int ret; + khint_t k; + + k = kh_get(ctx_id, ucc_worker->ids, dest_id); + if (k != kh_end(ucc_worker->ids)) { + *ctx_id = kh_value(ucc_worker->ids, k); + return DOCA_SUCCESS; + } + + *ctx_id = ucc_worker->ctx_id; + k = kh_put(ctx_id, ucc_worker->ids, dest_id, &ret); + if (ret < 0) { + DOCA_LOG_ERR("Failed to put new context id"); + return DOCA_ERROR_DRIVER; + } + + ucc_worker->ctx_id++; + + kh_value(ucc_worker->ids, k) = *ctx_id; + DOCA_LOG_DBG("UCC worker added connection %ld", *ctx_id); + return DOCA_SUCCESS; +} + +/* + * Handle UCC library create command + * + * @ucc_worker [in]: UCC worker context + * @cmd_desc [in]: UCC command descriptor + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_lib_create(struct urom_worker_ucc *ucc_worker, + struct urom_worker_cmd_desc *cmd_desc) +{ + uint64_t ctx_id, i; + doca_error_t result; + ucc_status_t ucc_status; + ucc_lib_config_h lib_config; + ucc_lib_params_t *lib_params; + struct urom_worker_notify *notif; + struct urom_worker_notif_desc *nd; + struct urom_worker_notify_ucc *ucc_notif; + struct urom_worker_cmd *cmd = (struct urom_worker_cmd *)&cmd_desc->worker_cmd; + struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *)cmd->plugin_cmd; + + /* Prepare notification */ + nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); + if (nd == NULL) + return DOCA_ERROR_NO_MEMORY; + + nd->dest_id = cmd_desc->dest_id; + + notif = (struct urom_worker_notify *)&nd->worker_notif; + notif->type = cmd->type; + notif->urom_context = cmd->urom_context; + notif->len = sizeof(*ucc_notif); + notif->status = DOCA_SUCCESS; + + ucc_notif = (struct urom_worker_notify_ucc *)notif->plugin_notif; + ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_LIB_CREATE_COMPLETE; + ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; + + lib_params = ucc_cmd->lib_create_cmd.params; + lib_params->mask |= UCC_LIB_PARAM_FIELD_THREAD_MODE; + lib_params->thread_mode = UCC_THREAD_MULTIPLE; + + result = worker_ucc_dest_lookup(ucc_worker, cmd_desc->dest_id, &ctx_id); + if (result != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to lookup command destination"); + goto fail; + } + + ucc_worker->nr_connections++; + + if (ucc_worker->nr_connections > worker_ucc_opts.ppw) { + DOCA_LOG_ERR("Too many processes connected to a single worker"); + result = DOCA_ERROR_FULL; + goto dest_remove; + } + + if (UCC_OK != ucc_lib_config_read(NULL, NULL, &lib_config)) { + DOCA_LOG_ERR("Failed to read UCC lib config"); + result = DOCA_ERROR_DRIVER; + goto reduce_conn; + } + + for (i = 0; i < worker_ucc_opts.tpp; i++) { + ucc_status = ucc_init(lib_params, lib_config, &ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].ucc_lib); + if (ucc_status != UCC_OK) { + DOCA_LOG_ERR("Failed to init UCC lib"); + result = DOCA_ERROR_DRIVER; + goto reduce_conn; + } + } + ucc_lib_config_release(lib_config); + + DOCA_LOG_DBG("Created UCC lib successfully"); + notif->status = DOCA_SUCCESS; + ucc_worker_safe_push_notification(ucc_worker, nd); + return notif->status; + +reduce_conn: + ucc_worker->nr_connections--; +dest_remove: + worker_ucc_dest_remove(ucc_worker, cmd_desc->dest_id); +fail: + DOCA_LOG_ERR("Failed to create UCC lib"); + notif->status = result; + ucc_worker_safe_push_notification(ucc_worker, nd); + return result; +} + +/* + * 
UCC library destroy + * + * @ucc_worker [in]: UCC worker context + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t ucc_worker_lib_destroy(struct urom_worker_ucc *ucc_worker) +{ + uint64_t j, k; + int64_t i; + ucc_status_t status; + doca_error_t result = DOCA_SUCCESS; + + ucc_component_enabled = 0; + + ucc_worker_join_and_free_threads(); + + for (j = 0; j < ucc_worker->nr_connections; j++) { + for (k = 0; k < worker_ucc_opts.tpp; k++) { + for (i = 0; i < ucc_worker->ucc_data[j*worker_ucc_opts.tpp + k].n_teams; i++) { + if (!ucc_worker->ucc_data[j*worker_ucc_opts.tpp + k].ucc_team[i]) + continue; + status = ucc_team_destroy(ucc_worker->ucc_data[j*worker_ucc_opts.tpp + k].ucc_team[i]); + if (status != UCC_OK) { + DOCA_LOG_ERR("Failed to destroy UCC team of data index %lu and team index %ld", j, i); + result = DOCA_ERROR_DRIVER; + } + free(ucc_worker->ucc_data[j*worker_ucc_opts.tpp + k].pSync); + } + if (ucc_worker->ucc_data[j*worker_ucc_opts.tpp + k].ucc_context) { + status = ucc_context_destroy(ucc_worker->ucc_data[j*worker_ucc_opts.tpp + k].ucc_context); + if (status != UCC_OK) { + DOCA_LOG_ERR("Failed to destroy UCC context of UCC data index %lu", j); + result = DOCA_ERROR_DRIVER; + } + ucc_worker->ucc_data[j*worker_ucc_opts.tpp + k].ucc_context = NULL; + } + if (ucc_worker->ucc_data[j*worker_ucc_opts.tpp + k].ucc_lib) { + status = ucc_finalize(ucc_worker->ucc_data[j*worker_ucc_opts.tpp + k].ucc_lib); + if (status != UCC_OK) { + DOCA_LOG_ERR("Failed to finalize UCC lib of UCC data index %lu", j); + result = DOCA_ERROR_DRIVER; + } + } + } + } + + return result; +} + +/* + * Handle UCC library destroy command + * + * @ucc_worker [in]: UCC worker context + * @cmd_desc [in]: UCC command descriptor + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_lib_destroy(struct urom_worker_ucc *ucc_worker, + struct urom_worker_cmd_desc *cmd_desc) +{ + struct urom_worker_notify *notif; + struct urom_worker_notif_desc *nd; + struct urom_worker_notify_ucc *ucc_notif; + struct urom_worker_cmd *cmd = (struct urom_worker_cmd *)&cmd_desc->worker_cmd; + struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *)cmd->plugin_cmd; + + /* Prepare notification */ + nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); + if (nd == NULL) + return DOCA_ERROR_NO_MEMORY; + + nd->dest_id = cmd_desc->dest_id; + + notif = (struct urom_worker_notify *)&nd->worker_notif; + notif->type = cmd->type; + notif->urom_context = cmd->urom_context; + notif->len = sizeof(*ucc_notif); + + ucc_notif = (struct urom_worker_notify_ucc *)notif->plugin_notif; + ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_LIB_DESTROY_COMPLETE; + ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; + + notif->status = ucc_worker_lib_destroy(ucc_worker); + ucc_worker_safe_push_notification(ucc_worker, nd); + return notif->status; +} + +/* + * Thread progress handles queue collective element + * + * @qe [in]: UCC thread queue element + * @ucc_worker [in]: UCC worker context + * @thread_id [in]: UCC thread id + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t handle_progress_thread_coll_element(struct ucc_queue_element *qe, + struct urom_worker_ucc *ucc_worker, + int thread_id) +{ + int64_t lvalue = 0; + ucc_status_t tmp_status, ucc_status = UCC_OK; + struct ucc_queue_element *qe_back; + doca_error_t status = DOCA_SUCCESS; + struct urom_worker_notify_ucc *ucc_notif; + + if (!qe->posted) { + ucc_status = 
ucc_collective_post(qe->coll_req); + if (UCC_OK != ucc_status) { + DOCA_LOG_ERR("Failed to post UCC collective: %s", ucc_status_string(ucc_status)); + status = DOCA_ERROR_DRIVER; + goto exit; + } + qe->posted = 1; + } + + ucc_status = ucc_collective_test(qe->coll_req); + if (ucc_status == UCC_INPROGRESS) { + ucc_context_progress(ucc_worker->ucc_data[qe->ctx_id].ucc_context); + lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); + while (lvalue != 0) + lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); + status = find_qe_slot(qe->ctx_id, ucc_worker, &qe_back); + lvalue = ucs_atomic_cswap64(&queue_lock, 1, 0); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to find queue slot for team creation"); + ucc_status = UCC_ERR_NO_RESOURCE; + goto exit; + } + *qe_back = *qe; + qe->in_use = 0; + queue_front[thread_id] = (queue_front[thread_id] + 1) % worker_ucc_opts.list_size; + return DOCA_ERROR_IN_PROGRESS; + } else if (ucc_status == UCC_OK) { + if (qe->barrier) { + pthread_barrier_wait(qe->barrier); + if (qe->nd != NULL) { + pthread_barrier_destroy(qe->barrier); + free(qe->barrier); + qe->barrier = NULL; + } + } + if (qe->key_duplicate_per_rank) { + free(qe->key_duplicate_per_rank); + qe->key_duplicate_per_rank = NULL; + } + if (qe->old_dest) { + DOCA_LOG_DBG("Putting data back to host %p with size %lu", qe->old_dest, qe->data_size); + if (qe->dest_packed_key != NULL) { + status = ucc_rma_put_host(ucc_worker->ucc_data[qe->ctx_id].local_work_buffer + + qe->data_size, + qe->old_dest, + qe->data_size, + qe->ctx_id, + qe->dest_packed_key, + ucc_worker); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to find queue slot for team creation"); + goto exit; + } + } else { + status = ucc_rma_put(ucc_worker->ucc_data[qe->ctx_id].local_work_buffer + qe->data_size, + qe->old_dest, + qe->data_size, + MAX_HOST_DEST_ID, + qe->myrank, + qe->ctx_id, + ucc_worker); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to find queue slot for team creation"); + goto exit; + } + } + } + if (qe->gwbi != NULL && qe->nd != NULL) { + free(qe->gwbi); + } + } else { + DOCA_LOG_ERR("ucc_collective_test() returned failure (%d)", ucc_status); + status = DOCA_ERROR_DRIVER; + goto exit; + } + + status = DOCA_SUCCESS; + tmp_status = ucc_collective_test(qe->coll_req); + if (tmp_status != UCC_OK) { + ucc_status = (ucc_status == UCC_OK) ? tmp_status : ucc_status; + status = DOCA_ERROR_DRIVER; + } + tmp_status = ucc_collective_finalize(qe->coll_req); + if (tmp_status != UCC_OK) { + ucc_status = (ucc_status == UCC_OK) ? 
tmp_status : ucc_status; + status = DOCA_ERROR_DRIVER; + } + +exit: + if (qe->nd != NULL) { + ucc_notif = (struct urom_worker_notify_ucc *)qe->nd->worker_notif.plugin_notif; + ucc_notif->coll_nqe.status = ucc_status; + qe->nd->worker_notif.status = status; + ucc_worker_safe_push_notification(ucc_worker, qe->nd); + } + queue_front[thread_id] = (queue_front[thread_id] + 1) % worker_ucc_opts.list_size; + ucs_atomic_add64(&queue_size[thread_id], -1); + qe->in_use = 0; + return status; +} + +/* + * Thread progress handles queue team element + * + * @qe [in]: UCC thread queue element + * @ucc_worker [in]: UCC worker context + * @thread_id [in]: UCC thread id + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t handle_progress_thread_team_element(struct ucc_queue_element *qe, + struct urom_worker_ucc *ucc_worker, + int thread_id) +{ + int64_t lvalue = 0; + ucc_status_t ucc_status = UCC_OK; + struct ucc_queue_element *qe_back; + doca_error_t status = DOCA_SUCCESS; + struct urom_worker_notify_ucc *ucc_notif = NULL; + + if(qe->nd != NULL) { + ucc_notif = (struct urom_worker_notify_ucc *)qe->nd->worker_notif.plugin_notif; + } + + ucc_status = ucc_team_create_test(ucc_worker->ucc_data[qe->ctx_id].ucc_team[qe->team_id]); + if (ucc_status == UCC_INPROGRESS) { + ucc_status = ucc_context_progress(ucc_worker->ucc_data[qe->ctx_id].ucc_context); + lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); + while (lvalue != 0) + lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); + status = find_qe_slot(qe->ctx_id, ucc_worker, &qe_back); + lvalue = ucs_atomic_cswap64(&queue_lock, 1, 0); + if (status != DOCA_SUCCESS) + goto exit; + *qe_back = *qe; + queue_front[thread_id] = (queue_front[thread_id] + 1) % worker_ucc_opts.list_size; + qe->in_use = 0; + return DOCA_ERROR_IN_PROGRESS; + } else if (ucc_status != UCC_OK) { + DOCA_LOG_ERR("UCC team create test failed (%d) on team %ld for ctx %ld", + ucc_status, + qe->team_id, + qe->ctx_id); + if (ucc_notif) + ucc_notif->team_create_nqe.team = NULL; + status = DOCA_ERROR_DRIVER; + } else { + if (qe->barrier) { + pthread_barrier_wait(qe->barrier); + if (qe->nd != NULL) { + pthread_barrier_destroy(qe->barrier); + free(qe->barrier); + qe->barrier = NULL; + } + } + DOCA_LOG_INFO("Finished team creation (%ld:%ld)", qe->ctx_id, qe->team_id); + if (ucc_notif) + ucc_notif->team_create_nqe.team = ucc_worker->ucc_data[qe->ctx_id].ucc_team[qe->team_id]; + status = DOCA_SUCCESS; + } + +exit: + free(qe->coll_ctx); + if (qe->nd != NULL) { + qe->nd->worker_notif.status = status; + ucc_worker_safe_push_notification(ucc_worker, qe->nd); + } + queue_front[thread_id] = (queue_front[thread_id] + 1) % worker_ucc_opts.list_size; + ucs_atomic_add64(&queue_size[thread_id], -1); + qe->in_use = 0; + return status; +} + +/* + * Progress context thread main function + * + * @arg [in]: UCC worker arg + * @return: NULL (dummy return because of pthread requirement) + */ +static void *urom_worker_ucc_progress_thread(void *arg) +{ + struct ucc_queue_element *qe; + doca_error_t status = DOCA_SUCCESS; + struct thread_args *targs = (struct thread_args *)arg; + int i, front, size, thread_id = targs->thread_id; + struct urom_worker_ucc *ucc_worker = targs->ucc_worker; + + dpu_thread_set_affinity(thread_id); + + while (ucc_component_enabled) { + size = queue_size[thread_id]; + for (i = 0; i < size; i++) { + front = queue_front[thread_id]; + qe = &ucc_worker->queue[thread_id][front]; + if (qe->in_use != 1) { + DOCA_LOG_WARN("Found queue element in queue and marked not in use"); + 
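+                                /* Unexpected state: leave the element in place and re-examine it on the next pass. */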
continue; + } + if (qe->type == UCC_WORKER_QUEUE_ELEMENT_TYPE_TEAM_CREATE) { + status = handle_progress_thread_team_element(qe, ucc_worker, thread_id); + if (status == DOCA_ERROR_IN_PROGRESS) + continue; + + if (status != DOCA_SUCCESS) + goto exit; + } else if (qe->type == UCC_WORKER_QUEUE_ELEMENT_TYPE_COLLECTIVE) { + status = handle_progress_thread_coll_element(qe, ucc_worker, thread_id); + if (status == DOCA_ERROR_IN_PROGRESS) + continue; + + if (status != DOCA_SUCCESS) + goto exit; + } else + DOCA_LOG_ERR("Unknown queue element type"); + } + sched_yield(); + } +exit: + pthread_exit(NULL); +} + +/* UCC oob allgather request */ +struct oob_allgather_req { + void *sbuf; /* Local buffer */ + void *rbuf; /* Remote buffer */ + size_t msglen; /* Message length */ + void *oob_coll_ctx; /* OOB collective context */ + int iter; /* Interation */ + int index; /* Current process index */ + int *status; /* Request status */ +}; + +/* + * UCC OOB allgather free + * + * @req [in]: allgather request data + * @return: UCC_OK on success and UCC_ERR otherwise + */ +static ucc_status_t urom_worker_ucc_oob_allgather_free(void *req) +{ + free(req); + return UCC_OK; +} + +/* + * UCC oob allgather function + * + * @sbuf [in]: local buffer to send to other processes + * @rbuf [in]: global buffer to includes other processes source buffer + * @msglen [in]: source buffer length + * @oob_coll_ctx [in]: collection info + * @req [out]: set allgather request data + * @return: UCC_OK on success and UCC_ERR otherwise + */ +static ucc_status_t urom_worker_ucc_oob_allgather(void *sbuf, void *rbuf, size_t msglen, void *oob_coll_ctx, void **req) +{ + char *recv_buf; + int index, size, i; + struct oob_allgather_req *oob_req; + struct coll_ctx *ctx = (struct coll_ctx *)oob_coll_ctx; + + size = ctx->size; + index = ctx->index; + + oob_req = malloc(sizeof(*oob_req)); + if (oob_req == NULL) { + DOCA_LOG_ERR("Failed to allocate OOB UCC request"); + return UCC_ERR_NO_MEMORY; + } + + oob_req->sbuf = sbuf; + oob_req->rbuf = rbuf; + oob_req->msglen = msglen; + oob_req->oob_coll_ctx = oob_coll_ctx; + oob_req->iter = 0; + + oob_req->status = calloc(ctx->size * 2, sizeof(int)); + *req = oob_req; + + for (i = 0; i < size; i++) { + recv_buf = (char *)rbuf + i * msglen; + ucc_recv_nb(recv_buf, msglen, i, ctx->ucc_worker, &oob_req->status[i]); + } + + for (i = 0; i < size; i++) + ucc_send_nb(sbuf, msglen, index, i, ctx->ucc_worker, &oob_req->status[i + size]); + + return UCC_OK; +} + +/* + * UCC oob allgather test function + * + * @req [in]: UCC allgather request + * @return: UCC_OK on success and UCC_ERR otherwise + */ +static ucc_status_t urom_worker_ucc_oob_allgather_test(void *req) +{ + struct coll_ctx *ctx; + struct oob_allgather_req *oob_req; + int i, probe_count, nr_done, size, nr_probes = 5; + + oob_req = (struct oob_allgather_req *)req; + ctx = (struct coll_ctx *)oob_req->oob_coll_ctx; + size = ctx->size; + + for (probe_count = 0; probe_count < nr_probes; probe_count++) { + nr_done = 0; + for (i = 0; i < size * 2; i++) { + if (oob_req->status[i] != 1 && ctx->ucc_worker->ucp_data.ucp_worker != NULL) + ucp_worker_progress(ctx->ucc_worker->ucp_data.ucp_worker); + else + ++nr_done; + } + if (nr_done == size * 2) + return UCC_OK; + } + + return UCC_INPROGRESS; +} + +/* + * Handle UCC context creation of progress threads + * + * @arg [in]: UCC worker context argument + * @return: NULL (dummy return because of pthread requirement) + */ +static void *urom_worker_ucc_ctx_progress_thread(void *arg) +{ + int ret; + uint64_t ctx_id; + char 
str_buf[256]; + ucc_status_t ucc_status; + doca_error_t status; + struct coll_ctx **coll_ctx; + ucc_mem_map_t **maps = NULL; + ucc_context_config_h ctx_config; + struct ctx_thread_args *args = (struct ctx_thread_args *)arg; + size_t len = args->len; + int64_t size = args->size; + int64_t start = args->start; + int64_t stride = args->stride; + int64_t myrank = args->myrank; + uint64_t dest_id = args->dest_id; + struct urom_worker_ucc *ucc_worker = args->ucc_worker; + ucc_context_params_t ctx_params = {0}; + struct urom_worker_notif_desc *nd; + struct urom_worker_notify *notif; + struct urom_worker_notify_ucc *ucc_notif; + struct thread_args *targs; + uint64_t n_threads, i, j; + + nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); + if (nd == NULL) { + status = DOCA_ERROR_NO_MEMORY; + goto exit; + } + + nd->dest_id = args->dest_id; + notif = (struct urom_worker_notify *)&nd->worker_notif; + notif->type = args->notif_type; + notif->len = sizeof(*ucc_notif); + notif->urom_context = args->urom_context; + + ucc_notif = (struct urom_worker_notify_ucc *)notif->plugin_notif; + ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_CONTEXT_CREATE_COMPLETE; + ucc_notif->dpu_worker_id = args->myrank; + + status = worker_ucc_dest_lookup(ucc_worker, dest_id, &ctx_id); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to lookup command destination"); + goto fail; + } + + maps = (ucc_mem_map_t **) calloc(worker_ucc_opts.tpp, sizeof(ucc_mem_map_t *)); + coll_ctx = (struct coll_ctx **) calloc(worker_ucc_opts.tpp, sizeof(struct coll_ctx *)); + + for (i = 0; i < worker_ucc_opts.tpp; i++) { + uint64_t thread_ctx_id = ctx_id*worker_ucc_opts.tpp + i; + + if (ucc_worker->ucc_data[thread_ctx_id].ucc_lib == NULL) { + DOCA_LOG_ERR("Attempting to create UCC context without first initializing a UCC lib"); + status = DOCA_ERROR_BAD_STATE; + goto fail; + } + + if (ucc_context_config_read(ucc_worker->ucc_data[thread_ctx_id].ucc_lib, NULL, &ctx_config) != UCC_OK) { + DOCA_LOG_ERR("Failed to read UCC context config"); + status = DOCA_ERROR_DRIVER; + goto fail; + } + + /* Set to sliding window */ + if (UCC_OK != ucc_context_config_modify(ctx_config, "tl/ucp", "TUNE", "allreduce:0-inf:@2")) { + DOCA_LOG_ERR("Failed to modify TL_UCP_TUNE UCC lib config"); + status = DOCA_ERROR_DRIVER; + goto cfg_release; + } + + /* Set estimated num of eps */ + sprintf(str_buf, "%ld", size); + ucc_status = ucc_context_config_modify(ctx_config, NULL, "ESTIMATED_NUM_EPS", str_buf); + if (ucc_status != UCC_OK) { + DOCA_LOG_ERR("UCC context config modify failed for estimated_num_eps"); + status = DOCA_ERROR_DRIVER; + goto cfg_release; + } + + ucc_worker->ucc_data[thread_ctx_id].local_work_buffer = calloc(1, len * 2); + if (ucc_worker->ucc_data[thread_ctx_id].local_work_buffer == NULL) { + DOCA_LOG_ERR("Failed to allocate local work buffer"); + status = DOCA_ERROR_NO_MEMORY; + goto cfg_release; + } + + ucc_worker->ucc_data[thread_ctx_id].pSync = calloc(worker_ucc_opts.num_psync, sizeof(long)); + if (ucc_worker->ucc_data[thread_ctx_id].pSync == NULL) { + DOCA_LOG_ERR("Failed to pSync array"); + status = DOCA_ERROR_NO_MEMORY; + goto buf_free; + } + ucc_worker->ucc_data[thread_ctx_id].len = len * 2; + + maps[i] = (ucc_mem_map_t *)calloc(3, sizeof(ucc_mem_map_t)); + if (maps[i] == NULL) { + DOCA_LOG_ERR("Failed to allocate UCC memory map array"); + status = DOCA_ERROR_NO_MEMORY; + goto psync_free; + } + + maps[i][0].address = ucc_worker->ucc_data[thread_ctx_id].local_work_buffer; + maps[i][0].len = len * 2; + maps[i][1].address = 
ucc_worker->ucc_data[thread_ctx_id].pSync; + maps[i][1].len = worker_ucc_opts.num_psync * sizeof(long); + + coll_ctx[i] = (struct coll_ctx *)malloc(sizeof(struct coll_ctx)); + if (coll_ctx[i] == NULL) { + DOCA_LOG_ERR("Failed to allocate UCC worker coll context"); + status = DOCA_ERROR_NO_MEMORY; + goto maps_free; + } + + if (stride <= 0) /* This is an array of ids */ + coll_ctx[i]->pids = (int64_t *)start; + else + coll_ctx[i]->start = start; + + coll_ctx[i]->stride = stride; + coll_ctx[i]->size = size; + coll_ctx[i]->index = myrank; + coll_ctx[i]->ucc_worker = ucc_worker; + + ctx_params.mask = UCC_CONTEXT_PARAM_FIELD_OOB | UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS; + ctx_params.oob.allgather = urom_worker_ucc_oob_allgather; + ctx_params.oob.req_test = urom_worker_ucc_oob_allgather_test; + ctx_params.oob.req_free = urom_worker_ucc_oob_allgather_free; + ctx_params.oob.coll_info = (void *)coll_ctx[i]; + ctx_params.oob.n_oob_eps = size; + ctx_params.oob.oob_ep = myrank; + ctx_params.mem_params.segments = maps[i]; + ctx_params.mem_params.n_segments = 2; + + ucc_status = ucc_context_create(ucc_worker->ucc_data[thread_ctx_id].ucc_lib, + &ctx_params, + ctx_config, + &ucc_worker->ucc_data[thread_ctx_id].ucc_context); + if (ucc_status != UCC_OK) { + DOCA_LOG_ERR("Failed to create UCC context"); + status = DOCA_ERROR_DRIVER; + goto coll_free; + } + ucc_context_config_release(ctx_config); + } + + if (ctx_id == 0) { + n_threads = worker_ucc_opts.num_progress_threads; + targs = calloc(n_threads, sizeof(*targs)); + if (targs == NULL) { + DOCA_LOG_ERR("Failed to create threads args"); + status = DOCA_ERROR_NO_MEMORY; + goto context_destroy; + } + progress_thread = calloc(n_threads, sizeof(*progress_thread)); + if (progress_thread == NULL) { + DOCA_LOG_ERR("Failed to create threads args"); + status = DOCA_ERROR_NO_MEMORY; + goto targs_free; + } + + DOCA_LOG_DBG("Creating [%ld] progress %lu threads", myrank, n_threads); + for (i = 0; i < n_threads; i++) { + targs[i].thread_id = i; + targs[i].ucc_worker = ucc_worker; + ret = pthread_create(&progress_thread[i], + NULL, + urom_worker_ucc_progress_thread, + (void *)&targs[i]); + if (ret != 0) { + DOCA_LOG_ERR("Failed to create progress thread"); + status = DOCA_ERROR_IO_FAILED; + goto threads_free; + } + } + } + + status = DOCA_SUCCESS; + ucc_notif->context_create_nqe.context = ucc_worker->ucc_data[ctx_id].ucc_context; + DOCA_LOG_DBG("UCC context created, ctx_id %lu, context %p", ctx_id, ucc_worker->ucc_data[ctx_id].ucc_context); + goto exit; + +threads_free: + for (j = 0; j < i; j++) + pthread_cancel(progress_thread[j]); + free(progress_thread); +targs_free: + free(targs); +context_destroy: + for(i = 0; i < worker_ucc_opts.tpp; i++) { + if(ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].ucc_context) ucc_context_destroy(ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].ucc_context); + } +coll_free: + for(i = 0; i < worker_ucc_opts.tpp; i++) { + if(coll_ctx[i]) free(coll_ctx[i]); + } + free(coll_ctx); +maps_free: + for(i = 0; i < worker_ucc_opts.tpp; i++) { + if(maps[i]) free(maps[i]); + } + free(maps); +psync_free: + for(i = 0; i < worker_ucc_opts.tpp; i++) { + if(ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].pSync) + free(ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].pSync); + } +buf_free: + for(i = 0; i < worker_ucc_opts.tpp; i++) { + if(ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].local_work_buffer) + free(ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].local_work_buffer); + } +cfg_release: + 
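+        /* The context config read for the failing tpp iteration has not been released yet. */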
ucc_context_config_release(ctx_config); +fail: +exit: + nd->worker_notif.status = status; + ucc_worker_safe_push_notification(ucc_worker, nd); + free(args); + pthread_exit(NULL); +} + +/* + * Handle UCC context create command + * + * @ucc_worker [in]: UCC worker context + * @cmd_desc [in]: UCC command descriptor + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_context_create(struct urom_worker_ucc *ucc_worker, + struct urom_worker_cmd_desc *cmd_desc) +{ + int ret; + struct ctx_thread_args *args; + struct urom_worker_cmd *cmd = (struct urom_worker_cmd *)&cmd_desc->worker_cmd; + struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *)cmd->plugin_cmd; + struct urom_worker_notif_desc *nd; + struct urom_worker_notify_ucc *ucc_notif; + + args = calloc(1, sizeof(*args)); + if (args == NULL) + return DOCA_ERROR_NO_MEMORY; + + args->notif_type = cmd->type; + args->urom_context = cmd->urom_context; + args->start = ucc_cmd->context_create_cmd.start; + args->stride = ucc_cmd->context_create_cmd.stride; + args->size = ucc_cmd->context_create_cmd.size; + args->myrank = ucc_cmd->dpu_worker_id; + args->base_va = ucc_cmd->context_create_cmd.base_va; + args->len = ucc_cmd->context_create_cmd.len; + args->dest_id = cmd_desc->dest_id; + args->ucc_worker = ucc_worker; + + ret = pthread_create(&context_progress_thread, NULL, urom_worker_ucc_ctx_progress_thread, (void *)args); + if (ret != 0) { + nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); + if (nd == NULL) + return DOCA_ERROR_NO_MEMORY; + + nd->dest_id = cmd_desc->dest_id; + nd->worker_notif.status = DOCA_ERROR_IO_FAILED; + nd->worker_notif.type = cmd->type; + nd->worker_notif.len = sizeof(*ucc_notif); + nd->worker_notif.urom_context = cmd->urom_context; + ucc_notif = (struct urom_worker_notify_ucc *)nd->worker_notif.plugin_notif; + ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; + ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_CONTEXT_CREATE_COMPLETE; + ucc_worker_safe_push_notification(ucc_worker, nd); + return DOCA_ERROR_IO_FAILED; + } + + return DOCA_SUCCESS; +} + +/* + * Handle UCC context destroy command + * + * @ucc_worker [in]: UCC worker context + * @cmd_desc [in]: UCC command descriptor + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_context_destroy(struct urom_worker_ucc *ucc_worker, + struct urom_worker_cmd_desc *cmd_desc) +{ + uint64_t ctx_id, i; + doca_error_t status; + struct urom_worker_notify *notif; + struct urom_worker_notif_desc *nd; + struct urom_worker_notify_ucc *ucc_notif; + struct urom_worker_cmd *cmd = (struct urom_worker_cmd *)&cmd_desc->worker_cmd; + struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *)cmd->plugin_cmd; + + /* Prepare notification */ + nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); + if (nd == NULL) + return DOCA_ERROR_NO_MEMORY; + + nd->dest_id = cmd_desc->dest_id; + + notif = (struct urom_worker_notify *)&nd->worker_notif; + notif->type = cmd->type; + notif->urom_context = cmd->urom_context; + notif->len = sizeof(*ucc_notif); + + ucc_notif = (struct urom_worker_notify_ucc *)notif->plugin_notif; + ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_CONTEXT_DESTROY_COMPLETE; + ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; + + status = worker_ucc_dest_lookup(ucc_worker, cmd_desc->dest_id, &ctx_id); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to lookup command destination"); + goto exit; + } + + for (i = 0; i < worker_ucc_opts.tpp; i++) { + uint64_t 
thread_ctx_id = ctx_id*worker_ucc_opts.tpp + i; + + if (ucc_worker->ucc_data[thread_ctx_id].ucc_context) { + if (ucc_context_destroy(ucc_worker->ucc_data[thread_ctx_id].ucc_context) != UCC_OK) { + DOCA_LOG_ERR("Failed to destroy UCC context"); + status = DOCA_ERROR_DRIVER; + goto exit; + } + ucc_worker->ucc_data[thread_ctx_id].ucc_context = NULL; + } + } + + status = DOCA_SUCCESS; +exit: + notif->status = status; + ucc_worker_safe_push_notification(ucc_worker, nd); + return status; +} + +/* + * Handle UCC team command + * + * @ucc_worker [in]: UCC worker context + * @cmd_desc [in]: UCC command descriptor + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_team_create(struct urom_worker_ucc *ucc_worker, + struct urom_worker_cmd_desc *cmd_desc) +{ + uint64_t ctx_id, i; + ucc_ep_map_t map; + doca_error_t status; + ucc_status_t ucc_status; + size_t curr_team = 0; + struct coll_ctx *coll_ctx; + struct ucc_queue_element *qe; + ucc_team_params_t team_params; + struct urom_worker_notify *notif; + struct urom_worker_notif_desc *nd; + struct urom_worker_notify_ucc *ucc_notif; + struct urom_worker_cmd *cmd = (struct urom_worker_cmd *)&cmd_desc->worker_cmd; + struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *)cmd->plugin_cmd; + pthread_barrier_t *barrier; + uint64_t lvalue; + + /* Prepare notification */ + nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); + if (nd == NULL) + return DOCA_ERROR_NO_MEMORY; + + nd->dest_id = cmd_desc->dest_id; + + notif = (struct urom_worker_notify *)&nd->worker_notif; + notif->type = cmd->type; + notif->urom_context = cmd->urom_context; + notif->len = sizeof(*ucc_notif); + ucc_notif = (struct urom_worker_notify_ucc *)notif->plugin_notif; + ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_TEAM_CREATE_COMPLETE; + ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; + + status = worker_ucc_dest_lookup(ucc_worker, cmd_desc->dest_id, &ctx_id); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to lookup command destination"); + goto exit; + } + + barrier = malloc(sizeof(pthread_barrier_t)); + pthread_barrier_init(barrier, NULL, worker_ucc_opts.tpp); + + for (i = 0; i < worker_ucc_opts.tpp; i++) { + uint64_t thread_ctx_id = ctx_id*worker_ucc_opts.tpp + i; + + curr_team = ucc_worker->ucc_data[thread_ctx_id].n_teams; + if (ucc_worker->ucc_data[thread_ctx_id].ucc_context == NULL || + ucc_cmd->team_create_cmd.context_h != ucc_worker->ucc_data[ctx_id].ucc_context) { + DOCA_LOG_ERR("Attempting to create UCC team over non-existent context"); + status = DOCA_ERROR_INVALID_VALUE; + goto exit; + } + + if (ucc_cmd->team_create_cmd.stride <= 0) { + map.type = UCC_EP_MAP_ARRAY; + map.ep_num = ucc_cmd->team_create_cmd.size; + map.array.map = (void *)ucc_cmd->team_create_cmd.start; + map.array.elem_size = 8; + } else { + map.type = UCC_EP_MAP_STRIDED; + map.ep_num = ucc_cmd->team_create_cmd.size; + map.strided.start = ucc_cmd->team_create_cmd.start; + map.strided.stride = ucc_cmd->team_create_cmd.stride; + } + + team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_TEAM_SIZE | UCC_TEAM_PARAM_FIELD_EP_MAP | + UCC_TEAM_PARAM_FIELD_EP_RANGE; + team_params.ep = ucc_cmd->dpu_worker_id; + team_params.ep_map = map; + team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG; + team_params.team_size = ucc_cmd->team_create_cmd.size; + + coll_ctx = (struct coll_ctx *)malloc(sizeof(*coll_ctx)); + if (coll_ctx == NULL) { + DOCA_LOG_ERR("Failed to allocate collective context"); + status = DOCA_ERROR_NO_MEMORY; + goto exit; + } 
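+                /* Record the member layout (start/stride/size) and this DPU's rank for the team-create progress path. */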
+ + coll_ctx->start = ucc_cmd->team_create_cmd.start; + coll_ctx->stride = ucc_cmd->team_create_cmd.stride; + coll_ctx->size = ucc_cmd->team_create_cmd.size; + coll_ctx->index = ucc_cmd->dpu_worker_id; + coll_ctx->ucc_worker = ucc_worker; + + if (ucc_worker->ucc_data[thread_ctx_id].ucc_team == NULL) { + ucc_worker->ucc_data[thread_ctx_id].ucc_team = malloc(sizeof(ucc_worker->ucc_data[thread_ctx_id].ucc_team)); + if (ucc_worker->ucc_data[thread_ctx_id].ucc_team == NULL) { + status = DOCA_ERROR_NO_MEMORY; + goto coll_free; + } + } + + ucc_status = ucc_team_create_post(&ucc_worker->ucc_data[thread_ctx_id].ucc_context, + 1, + &team_params, + &ucc_worker->ucc_data[thread_ctx_id].ucc_team[curr_team]); + + if (ucc_status != UCC_OK) { + DOCA_LOG_ERR("ucc_team_create_post() failed"); + status = DOCA_ERROR_DRIVER; + goto team_free; + } + ucc_worker->ucc_data[thread_ctx_id].n_teams++; + + lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); + while (lvalue != 0) + lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); + status = find_qe_slot(thread_ctx_id, ucc_worker, &qe); + lvalue = ucs_atomic_cswap64(&queue_lock, 1, 0); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to find queue slot for team creation"); + goto team_free; + } + + qe->type = UCC_WORKER_QUEUE_ELEMENT_TYPE_TEAM_CREATE; + qe->coll_ctx = coll_ctx; + qe->dest_id = cmd_desc->dest_id; + qe->ctx_id = thread_ctx_id; + qe->team_id = curr_team; + qe->myrank = ucc_cmd->dpu_worker_id; + qe->in_use = 1; + qe->barrier = barrier; + if (i == 0) + qe->nd = nd; + else + qe->nd = NULL; + ucs_atomic_add64(&queue_size[thread_ctx_id % worker_ucc_opts.num_progress_threads], 1); + + continue; + +team_free: + free(ucc_worker->ucc_data[thread_ctx_id].ucc_team); +coll_free: + free(coll_ctx); + goto exit; + } + + return DOCA_SUCCESS; + +exit: + notif->status = status; + ucc_worker_safe_push_notification(ucc_worker, nd); + return status; +} + +size_t urom_worker_get_dt_size(ucc_datatype_t dt) +{ + size_t size_mod = 8; + switch (dt) { + case UCC_DT_INT8: + case UCC_DT_UINT8: + size_mod = sizeof(char); + break; + case UCC_DT_INT32: + case UCC_DT_UINT32: + case UCC_DT_FLOAT32: + size_mod = sizeof(int); + break; + case UCC_DT_INT64: + case UCC_DT_UINT64: + case UCC_DT_FLOAT64: + size_mod = sizeof(uint64_t); + break; + case UCC_DT_INT128: + case UCC_DT_UINT128: + case UCC_DT_FLOAT128: + size_mod = sizeof(__int128_t); + break; + default: + break; + } + return size_mod; +} + + +static doca_error_t post_nthreads_colls( + uint64_t ctx_id, struct urom_worker_ucc *ucc_worker, ucc_coll_args_t *coll_args, + ucc_team_h ucc_team, uint64_t myrank, int in_place, + ucc_tl_ucp_allreduce_sw_global_work_buf_info_t *gwbi, + struct urom_worker_notif_desc *nd, + struct urom_worker_cmd_desc *cmd_desc, + struct urom_worker_notify *notif, + ucc_worker_key_buf * key_duplicate_per_rank) +{ + pthread_barrier_t *barrier = NULL; + int64_t team_idx = 0; + ucc_coll_req_h coll_req; + struct ucc_queue_element *qe; + ucc_status_t ucc_status; + doca_error_t status = DOCA_SUCCESS; + size_t i; + uint64_t lvalue; + int64_t j; + + size_t threads = worker_ucc_opts.tpp; + size_t src_count = coll_args->src.info.count; + size_t dst_count = coll_args->dst.info.count; + size_t src_thread_count = src_count / threads; + size_t dst_thread_count = dst_count / threads; + size_t src_thread_size = + src_thread_count * urom_worker_get_dt_size(coll_args->src.info.datatype); + size_t dst_thread_size = + dst_thread_count * urom_worker_get_dt_size(coll_args->dst.info.datatype); + void *src_buf = coll_args->src.info.buffer; 
+	void *dst_buf = coll_args->dst.info.buffer;
+
+	coll_args->mask |= UCC_COLL_ARGS_FIELD_FLAGS | UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER;
+	coll_args->flags |= UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS;
+
+	coll_args->global_work_buffer = gwbi;
+
+	barrier = malloc(sizeof(pthread_barrier_t));
+	if (barrier == NULL) {
+		DOCA_LOG_ERR("Failed to allocate barrier for multi-threaded collective");
+		status = DOCA_ERROR_NO_MEMORY;
+		goto fail;
+	}
+	pthread_barrier_init(barrier, NULL, worker_ucc_opts.tpp);
+
+	/* Each thread posts the collective on a 1/tpp slice of the source and
+	 * destination buffers; the last thread also takes the remainder */
+	for (i = 0; i < threads; i++) {
+		uint64_t thread_ctx_id = ctx_id * worker_ucc_opts.tpp + i;
+
+		gwbi = malloc(sizeof(ucc_tl_ucp_allreduce_sw_global_work_buf_info_t));
+		if (gwbi == NULL) {
+			DOCA_LOG_ERR("Failed to initialize UCC collective: couldn't allocate global work buffer");
+			status = DOCA_ERROR_NO_MEMORY;
+			goto fail;
+		}
+
+		gwbi->packed_src_memh = key_duplicate_per_rank[i].rkeys;
+		gwbi->packed_dst_memh = key_duplicate_per_rank[i].rkeys + key_duplicate_per_rank[i].src_len;
+
+		coll_args->global_work_buffer = gwbi;
+
+		if (!in_place)
+			coll_args->src.info.count = src_thread_count;
+		coll_args->dst.info.count = dst_thread_count;
+
+		if (!in_place)
+			coll_args->src.info.buffer = src_buf + i * src_thread_size;
+		coll_args->dst.info.buffer = dst_buf + i * dst_thread_size;
+
+		if (i == threads - 1) {
+			if (!in_place)
+				coll_args->src.info.count += src_count % threads;
+			coll_args->dst.info.count += dst_count % threads;
+		}
+
+		if (i == 0) {
+			/* The threads created these teams at the same time, so the
+			 * team index is the same in every thread's array.
+			 * TODO: is there a better way to associate these teams with
+			 * each other? Maybe use a map? */
+			for (j = 0; j < ucc_worker->ucc_data[thread_ctx_id].n_teams; j++) {
+				if (ucc_worker->ucc_data[thread_ctx_id].ucc_team[j] == ucc_team) {
+					team_idx = j;
+					break;
+				}
+			}
+		}
+
+		ucc_status = ucc_collective_init(coll_args, &coll_req, ucc_worker->ucc_data[thread_ctx_id].ucc_team[team_idx]);
+		if (UCC_OK != ucc_status) {
+			DOCA_LOG_ERR("Failed to initialize UCC collective: %s", ucc_status_string(ucc_status));
+			status = DOCA_ERROR_DRIVER;
+			goto fail;
+		}
+
+		if (thread_ctx_id >= worker_ucc_opts.num_progress_threads) {
+			DOCA_LOG_ERR("Warning--possible deadlock: multiple threads posting to the same queue, and the qe is going to barrier. 
Ensure tpp < num progress threads to avoid this\n"); + } + + lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); + while (lvalue != 0) + lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); + status = find_qe_slot(thread_ctx_id, ucc_worker, &qe); + lvalue = ucs_atomic_cswap64(&queue_lock, 1, 0); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to find queue slot for team creation"); + goto req_destroy; + } + + qe->type = UCC_WORKER_QUEUE_ELEMENT_TYPE_COLLECTIVE; + qe->coll_req = coll_req; + qe->myrank = myrank; + qe->dest_id = cmd_desc->dest_id; + qe->old_dest = NULL; + qe->data_size = 0; + qe->gwbi = gwbi; + qe->dest_packed_key = NULL; + qe->ctx_id = thread_ctx_id; + qe->in_use = 1; + qe->posted = 0; + qe->barrier = barrier; + qe->key_duplicate_per_rank = key_duplicate_per_rank; + if (i == 0) + qe->nd = nd; + else + qe->nd = NULL; + + ucs_atomic_add64(&queue_size[thread_ctx_id % worker_ucc_opts.num_progress_threads], 1); + } + + return DOCA_SUCCESS; + +req_destroy: + ucc_collective_finalize(coll_req); +fail: + notif->status = status; + ucc_worker_safe_push_notification(ucc_worker, nd); + return status; +} + + +/* + * Handle UCC collective init command + * + * @ucc_worker [in]: UCC worker context + * @cmd_desc [in]: UCC command descriptor + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_coll_init(struct urom_worker_ucc *ucc_worker, struct urom_worker_cmd_desc *cmd_desc) +{ + ucc_worker_key_buf *key_duplicate_per_rank; + ucc_worker_key_buf *keys; + uint64_t i; + + uint64_t lvalue; + size_t size = 0; + size_t size_mod = 8; + uint64_t ctx_id, myrank; + ucc_team_h team; + void *work_buffer; + doca_error_t status; + void *old_dest = NULL; + ucc_coll_req_h coll_req; + ucc_status_t ucc_status; + void *packed_key = NULL; + ucc_coll_args_t *coll_args; + struct ucc_queue_element *qe; + struct urom_worker_notify *notif; + struct urom_worker_notif_desc *nd; + struct urom_worker_notify_ucc *ucc_notif; + struct urom_worker_cmd *cmd = (struct urom_worker_cmd *)&cmd_desc->worker_cmd; + struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *)cmd->plugin_cmd; + ucc_tl_ucp_allreduce_sw_global_work_buf_info_t *gwbi = NULL; + int in_place = 0; + + /* Prepare notification */ + nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); + if (nd == NULL) + return DOCA_ERROR_NO_MEMORY; + + nd->dest_id = cmd_desc->dest_id; + + notif = (struct urom_worker_notify *)&nd->worker_notif; + notif->type = cmd->type; + notif->urom_context = cmd->urom_context; + notif->len = sizeof(*ucc_notif); + + ucc_notif = (struct urom_worker_notify_ucc *)notif->plugin_notif; + ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_COLLECTIVE_COMPLETE; + ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; + + if (ucc_cmd->coll_cmd.team == NULL) { + DOCA_LOG_ERR("Attempting to perform UCC collective without a UCC team"); + status = DOCA_ERROR_INVALID_VALUE; + goto fail; + } + + status = worker_ucc_dest_lookup(ucc_worker, cmd_desc->dest_id, &ctx_id); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to lookup command destination"); + goto fail; + } + + if (ucc_cmd->coll_cmd.work_buffer_size > 0 && ucc_cmd->coll_cmd.work_buffer) + work_buffer = ucc_cmd->coll_cmd.work_buffer; + else + work_buffer = NULL; + + team = ucc_cmd->coll_cmd.team; + coll_args = ucc_cmd->coll_cmd.coll_args; + myrank = ucc_cmd->dpu_worker_id; + + COLL_CHECK(ucc_worker, ctx_id, status); + + if ( (coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS ) && + (coll_args->flags & UCC_COLL_ARGS_FLAG_IN_PLACE) ) { + in_place = 1; + } 
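+
+	/*
+	 * The remainder of the handler drops unsupported host callbacks and then
+	 * dispatches on the collective type: for alltoall(v) without XGVMI the
+	 * host source buffer is staged into local_work_buffer with an RMA get and
+	 * the destination is redirected there as well, while allreduce/allgather
+	 * require XGVMI and fan the work out over the tpp duplicate contexts via
+	 * post_nthreads_colls().
+	 */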
+
+	if (coll_args->mask & UCC_COLL_ARGS_FIELD_CB)
+		/* Callbacks into host memory cannot be supported from the DPU, so drop the flag */
+		coll_args->mask = coll_args->mask & (~UCC_COLL_ARGS_FIELD_CB);
+
+	if (coll_args->coll_type == UCC_COLL_TYPE_ALLTOALL || coll_args->coll_type == UCC_COLL_TYPE_ALLTOALLV) {
+		if (!ucc_cmd->coll_cmd.use_xgvmi) {
+			size_mod = urom_worker_get_dt_size(coll_args->src.info.datatype);
+			size = coll_args->src.info.count * size_mod;
+			if (coll_args->mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER) {
+				/* Perform get based on passed information */
+				keys = work_buffer;
+				status = ucc_rma_get_host(ucc_worker->ucc_data[ctx_id].local_work_buffer,
+							  coll_args->src.info.buffer,
+							  size,
+							  ctx_id,
+							  keys->rkeys,
+							  ucc_worker);
+				if (status != DOCA_SUCCESS) {
+					DOCA_LOG_ERR("UCC component unable to obtain source buffer");
+					goto fail;
+				}
+				packed_key = keys->rkeys + keys->src_len;
+			} else {
+				/* Perform get based on domain information */
+				status = ucc_rma_get(ucc_worker->ucc_data[ctx_id].local_work_buffer,
+						     coll_args->src.info.buffer,
+						     size,
+						     MAX_HOST_DEST_ID,
+						     myrank,
+						     ctx_id,
+						     ucc_worker);
+				if (status != DOCA_SUCCESS) {
+					DOCA_LOG_ERR("UCC component unable to obtain source buffer");
+					goto fail;
+				}
+			}
+			coll_args->src.info.buffer = ucc_worker->ucc_data[ctx_id].local_work_buffer;
+			old_dest = coll_args->dst.info.buffer;
+			coll_args->dst.info.buffer = ucc_worker->ucc_data[ctx_id].local_work_buffer + size;
+		}
+		if (!(coll_args->mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER) || !work_buffer) {
+			coll_args->mask |= UCC_COLL_ARGS_FIELD_FLAGS | UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER;
+			coll_args->flags |= UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS;
+			coll_args->global_work_buffer =
+				ucc_worker->ucc_data[ctx_id].pSync +
+				(ucc_worker->ucc_data[ctx_id].psync_offset % worker_ucc_opts.num_psync);
+			ucc_worker->ucc_data[ctx_id].psync_offset++;
+		} else {
+			if (work_buffer != NULL)
+				coll_args->global_work_buffer = work_buffer;
+		}
+	} else if (coll_args->coll_type == UCC_COLL_TYPE_ALLREDUCE || coll_args->coll_type == UCC_COLL_TYPE_ALLGATHER) {
+		if (!ucc_cmd->coll_cmd.use_xgvmi) {
+			DOCA_LOG_ERR("Failed to initialize UCC collective: Allreduce/Allgather must use XGVMI");
+			status = DOCA_ERROR_DRIVER;
+			goto fail;
+		}
+		if (!(coll_args->mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER) || !work_buffer) {
+			DOCA_LOG_ERR("Failed to initialize UCC collective: Allreduce/Allgather must use a global work buffer");
+			status = DOCA_ERROR_DRIVER;
+			goto fail;
+		}
+
+		keys = work_buffer;
+
+		gwbi = malloc(sizeof(ucc_tl_ucp_allreduce_sw_global_work_buf_info_t));
+		if (gwbi == NULL) {
+			DOCA_LOG_ERR("Failed to initialize UCC collective: couldn't allocate global work buffer");
+			status = DOCA_ERROR_NO_MEMORY;
+			goto fail;
+		}
+
+		gwbi->packed_src_memh = keys->rkeys;
+		gwbi->packed_dst_memh = keys->rkeys + keys->src_len;
+
+		/* Give each of the tpp worker threads its own copy of the packed keys */
+		key_duplicate_per_rank = malloc(sizeof(ucc_worker_key_buf) * worker_ucc_opts.tpp);
+		if (key_duplicate_per_rank == NULL) {
+			DOCA_LOG_ERR("Failed to initialize UCC collective: couldn't allocate key_duplicate_per_rank");
+			free(gwbi);
+			status = DOCA_ERROR_NO_MEMORY;
+			goto fail;
+		}
+		for (i = 0; i < worker_ucc_opts.tpp; i++) {
+			memcpy(key_duplicate_per_rank[i].rkeys, keys->rkeys, keys->src_len + keys->dst_len);
+			key_duplicate_per_rank[i].src_len = keys->src_len;
+			key_duplicate_per_rank[i].dst_len = keys->dst_len;
+		}
+
+		status = post_nthreads_colls(
+			ctx_id, ucc_worker, coll_args,
+			team, myrank, in_place,
+			gwbi, nd, cmd_desc, notif, key_duplicate_per_rank);
+
+		return status;
+	}
+
+	ucc_status = ucc_collective_init(coll_args, &coll_req, team);
+	if (UCC_OK != ucc_status) {
+		DOCA_LOG_ERR("Failed to initialize UCC 
collective: %s", ucc_status_string(ucc_status)); + status = DOCA_ERROR_DRIVER; + goto fail; + } + + ucc_status = ucc_collective_post(coll_req); + if (UCC_OK != ucc_status) { + DOCA_LOG_ERR("Failed to post UCC collective: %s", ucc_status_string(ucc_status)); + status = DOCA_ERROR_DRIVER; + goto req_destroy; + } + + lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); + while (lvalue != 0) + lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); + status = find_qe_slot(ctx_id, ucc_worker, &qe); + lvalue = ucs_atomic_cswap64(&queue_lock, 1, 0); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to find queue slot for team creation"); + goto req_destroy; + } + + qe->type = UCC_WORKER_QUEUE_ELEMENT_TYPE_COLLECTIVE; + qe->coll_req = coll_req; + qe->myrank = myrank; + qe->dest_id = cmd_desc->dest_id; + if (!ucc_cmd->coll_cmd.use_xgvmi) { + DOCA_LOG_DBG("Setting old dest to %p", old_dest); + qe->old_dest = old_dest; + qe->data_size = size; + } else { + qe->old_dest = NULL; + qe->data_size = 0; + } + qe->gwbi = gwbi; + qe->dest_packed_key = packed_key; + qe->ctx_id = ctx_id; + qe->in_use = 1; + qe->posted = 1; + qe->barrier = NULL; + qe->nd = nd; + ucs_atomic_add64(&queue_size[ctx_id % worker_ucc_opts.num_progress_threads], 1); + return DOCA_SUCCESS; +req_destroy: + ucc_collective_finalize(coll_req); +fail: + notif->status = status; + ucc_worker_safe_push_notification(ucc_worker, nd); + return status; +} + +/* + * Handle UCC passive data channel create command + * + * @ucc_worker [in]: UCC worker context + * @cmd_desc [in]: UCC command descriptor + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_pass_dc_create(struct urom_worker_ucc *ucc_worker, + struct urom_worker_cmd_desc *cmd_desc) +{ + uint64_t ctx_id; + ucp_ep_h new_ep; + doca_error_t status; + ucs_status_t ucs_status; + ucp_ep_params_t ep_params; + struct urom_worker_notify *notif; + struct urom_worker_notif_desc *nd; + struct urom_worker_notify_ucc *ucc_notif; + struct urom_worker_cmd *cmd = (struct urom_worker_cmd *)&cmd_desc->worker_cmd; + struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *)cmd->plugin_cmd; + + /* Prepare notification */ + nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); + if (nd == NULL) + return DOCA_ERROR_NO_MEMORY; + + nd->dest_id = cmd_desc->dest_id; + + notif = (struct urom_worker_notify *)&nd->worker_notif; + notif->type = cmd->type; + notif->urom_context = cmd->urom_context; + notif->len = sizeof(*ucc_notif); + + ucc_notif = (struct urom_worker_notify_ucc *)notif->plugin_notif; + ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_PASSIVE_DATA_CHANNEL_COMPLETE; + ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; + + status = worker_ucc_dest_lookup(ucc_worker, cmd_desc->dest_id, &ctx_id); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to lookup command destination"); + goto fail; + } + + if (ucc_worker->ucc_data[ctx_id].host == NULL) { + ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | UCP_EP_PARAM_FIELD_ERR_HANDLER | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE; + ep_params.err_handler.cb = urom_ep_err_cb; + ep_params.err_handler.arg = NULL; + ep_params.err_mode = UCP_ERR_HANDLING_MODE_PEER; + ep_params.address = ucc_cmd->pass_dc_create_cmd.ucp_addr; + + ucs_status = ucp_ep_create(ucc_worker->ucp_data.ucp_worker, &ep_params, &new_ep); + if (ucs_status != UCS_OK) { + DOCA_LOG_ERR("ucp_ep_create() returned: %s", ucs_status_string(ucs_status)); + status = DOCA_ERROR_DRIVER; + goto fail; + } + + ucc_worker->ucc_data[ctx_id].host = new_ep; + 
DOCA_LOG_DBG("Created passive data channel for host for rank %lu", ucc_cmd->dpu_worker_id); + } else + DOCA_LOG_DBG("Passive data channel already created"); + status = DOCA_SUCCESS; +fail: + notif->status = status; + ucc_worker_safe_push_notification(ucc_worker, nd); + return status; +} + +/* + * Handle UROM UCC worker commands function + * + * @ctx [in]: DOCA UROM worker context + * @cmd_list [in]: command descriptor list to handle + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_worker_cmd(struct urom_worker_ctx *ctx, ucs_list_link_t *cmd_list) +{ + struct urom_worker_cmd *cmd; + doca_error_t status = DOCA_SUCCESS; + struct urom_worker_ucc_cmd *ucc_cmd; + struct urom_worker_cmd_desc *cmd_desc; + struct urom_worker_ucc *ucc_worker = (struct urom_worker_ucc *)ctx->plugin_ctx; + + while (!ucs_list_is_empty(cmd_list)) { + cmd_desc = ucs_list_extract_head(cmd_list, struct urom_worker_cmd_desc, entry); + status = urom_worker_ucc_cmd_unpack(&cmd_desc->worker_cmd, cmd_desc->worker_cmd.len, &cmd); + if (status != DOCA_SUCCESS) { + free(cmd_desc); + return status; + } + ucc_cmd = (struct urom_worker_ucc_cmd *)cmd->plugin_cmd; + switch (ucc_cmd->cmd_type) { + case UROM_WORKER_CMD_UCC_LIB_CREATE: + status = urom_worker_ucc_lib_create(ucc_worker, cmd_desc); + break; + case UROM_WORKER_CMD_UCC_LIB_DESTROY: + status = urom_worker_ucc_lib_destroy(ucc_worker, cmd_desc); + break; + case UROM_WORKER_CMD_UCC_CONTEXT_CREATE: + status = urom_worker_ucc_context_create(ucc_worker, cmd_desc); + break; + case UROM_WORKER_CMD_UCC_CONTEXT_DESTROY: + status = urom_worker_ucc_context_destroy(ucc_worker, cmd_desc); + break; + case UROM_WORKER_CMD_UCC_TEAM_CREATE: + status = urom_worker_ucc_team_create(ucc_worker, cmd_desc); + break; + case UROM_WORKER_CMD_UCC_COLL: + status = urom_worker_ucc_coll_init(ucc_worker, cmd_desc); + break; + case UROM_WORKER_CMD_UCC_CREATE_PASSIVE_DATA_CHANNEL: + status = urom_worker_ucc_pass_dc_create(ucc_worker, cmd_desc); + break; + default: + DOCA_LOG_INFO("Invalid UCC command type: %lu", ucc_cmd->cmd_type); + status = DOCA_ERROR_INVALID_VALUE; + break; + } + free(cmd_desc); + if (status != DOCA_SUCCESS) + return status; + } + + return status; +} + +/* + * Get UCC worker address + * + * UROM worker calls the function twice, first one to get address length and second one to get address data + * + * @worker_ctx [in]: DOCA UROM worker context + * @addr [out]: set worker address + * @addr_len [out]: set worker address length + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_addr(struct urom_worker_ctx *worker_ctx, void *addr, uint64_t *addr_len) +{ + ucs_status_t status; + struct urom_worker_ucc *ucc_worker = (struct urom_worker_ucc *)worker_ctx->plugin_ctx; + + if (ucc_worker->ucp_data.worker_address == NULL) { + status = ucp_worker_get_address(ucc_worker->ucp_data.ucp_worker, + &ucc_worker->ucp_data.worker_address, + &ucc_worker->ucp_data.ucp_addrlen); + if (status != UCS_OK) { + DOCA_LOG_ERR("Failed to get ucp worker address"); + return DOCA_ERROR_INITIALIZATION; + } + } + + if (*addr_len < ucc_worker->ucp_data.ucp_addrlen) { + /* Return required buffer size on error */ + *addr_len = ucc_worker->ucp_data.ucp_addrlen; + return DOCA_ERROR_INVALID_VALUE; + } + + *addr_len = ucc_worker->ucp_data.ucp_addrlen; + memcpy(addr, ucc_worker->ucp_data.worker_address, *addr_len); + return DOCA_SUCCESS; +} + +/* + * Check UCC worker tasks progress to get notifications + * + * @ctx [in]: DOCA 
UROM worker context + * @notif_list [out]: set notification descriptors for completed tasks + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_progress(struct urom_worker_ctx *ctx, ucs_list_link_t *notif_list) +{ + uint64_t lvalue = 0; + struct urom_worker_notif_desc *nd; + struct urom_worker_ucc *ucc_worker = (struct urom_worker_ucc *)ctx->plugin_ctx; + + if (ucs_list_is_empty(&ucc_worker->completed_reqs)) + return DOCA_ERROR_EMPTY; + + if (ucc_component_enabled) { + lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 0, 1); + while (lvalue != 0) + lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 0, 1); + } + + while (!ucs_list_is_empty(&ucc_worker->completed_reqs)) { + nd = ucs_list_extract_head(&ucc_worker->completed_reqs, struct urom_worker_notif_desc, entry); + ucs_list_add_tail(notif_list, &nd->entry); + } + + if (ucc_component_enabled) + lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 1, 0); + + return DOCA_SUCCESS; +} + +/* + * Packing UCC notification + * + * @notif [in]: UCC notification to pack + * @packed_notif_len [in/out]: set packed notification command buffer size + * @packed_notif [out]: set packed notification command buffer + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t urom_worker_ucc_notif_pack(struct urom_worker_notify *notif, + size_t *packed_notif_len, + void *packed_notif) +{ + int pack_len; + void *pack_head; + void *pack_tail = packed_notif; + + /* Pack base command */ + pack_len = ucs_offsetof(struct urom_worker_notify, plugin_notif) + sizeof(struct urom_worker_notify_ucc); + pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); + + memcpy(pack_head, notif, pack_len); + *packed_notif_len = pack_len; + return DOCA_SUCCESS; +} + +/* Define UROM UCC plugin interface, set plugin functions */ +static struct urom_worker_ucc_iface urom_worker_ucc = { + .super.open = urom_worker_ucc_open, + .super.close = urom_worker_ucc_close, + .super.addr = urom_worker_ucc_addr, + .super.worker_cmd = urom_worker_ucc_worker_cmd, + .super.progress = urom_worker_ucc_progress, + .super.notif_pack = urom_worker_ucc_notif_pack, +}; + +doca_error_t urom_plugin_get_iface(struct urom_plugin_iface *iface) +{ + if (iface == NULL) + return DOCA_ERROR_INVALID_VALUE; + DOCA_STRUCT_CTOR(urom_worker_ucc.super); + *iface = urom_worker_ucc.super; + return DOCA_SUCCESS; +} + +doca_error_t urom_plugin_get_version(uint64_t *version) +{ + if (version == NULL) + return DOCA_ERROR_INVALID_VALUE; + *version = plugin_version; + return DOCA_SUCCESS; +} diff --git a/contrib/doca_urom_ucc_plugin/dpu/worker_ucc.h b/contrib/doca_urom_ucc_plugin/dpu/worker_ucc.h new file mode 100644 index 0000000000..48d1698547 --- /dev/null +++ b/contrib/doca_urom_ucc_plugin/dpu/worker_ucc.h @@ -0,0 +1,305 @@ +/* + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED. + * + * This software product is a proprietary product of NVIDIA CORPORATION & + * AFFILIATES (the "Company") and all right, title, and interest in and to the + * software product, including all associated intellectual property rights, are + * and shall remain exclusively with the Company. + * + * This software product is governed by the End User License Agreement + * provided with the software product. 
+ * + */ + +#ifndef WORKER_UCC_H_ +#define WORKER_UCC_H_ + +#include +#include +#include + +#include +#include +#include + +#include "../common/urom_ucc.h" + +#define MAX_HOST_DEST_ID INT_MAX /* Maximum destination host id */ +#define MIN_THREADS 1 /* Minimum number of threads per UCC worker */ + +/* Collective operation check macro */ +#define COLL_CHECK(ucc_worker, ctx_id, status) \ + { \ + if (ucc_worker->ucc_data[ctx_id].ucc_lib == NULL) { \ + DOCA_LOG_ERR("Attempting to perform ucc collective without initialization"); \ + status = DOCA_ERROR_NOT_FOUND; \ + goto fail; \ + } \ +\ + if (ucc_worker->ucc_data[ctx_id].ucc_context == NULL) { \ + DOCA_LOG_ERR("Attempting to perform ucc collective without a ucc context"); \ + status = DOCA_ERROR_NOT_FOUND; \ + goto fail; \ + } \ + } + +/* UCC serializing next raw, iter points to the offset place and returns the buffer start */ +#define urom_ucc_serialize_next_raw(_iter, _type, _offset) \ + ({ \ + _type *_result = (_type *)(*(_iter)); \ + *(_iter) = UCS_PTR_BYTE_OFFSET(*(_iter), _offset); \ + _result; \ + }) + +/* Worker UCC options */ +struct worker_ucc_opts { + uint64_t num_progress_threads; /* Number of threads */ + uint64_t dpu_worker_binding_stride; /* Each worker thread is bound to this far apart core # from each other */ + uint64_t ppw; /* Number of processes per worker */ + uint64_t tpp; /* Threads per host process--create this many duplicate ucc contexts/teams/collectives per single host cmd */ + uint64_t list_size; /* Size of progress list */ + uint64_t num_psync; /* Number of synchronization/work scratch buffers to allocate for collectives */ +}; + +/* UCC worker queue elements types */ +enum ucc_worker_queue_element_type { + UCC_WORKER_QUEUE_ELEMENT_TYPE_TEAM_CREATE, /* Team element queue type */ + UCC_WORKER_QUEUE_ELEMENT_TYPE_COLLECTIVE, /* Collective element queue type */ +}; + +/* UROM UCC worker interface */ +struct urom_worker_ucc_iface { + struct urom_plugin_iface super; /* DOCA UROM worker plugin interface */ +}; + +/* UCC data structure */ +struct ucc_data { + ucc_lib_h ucc_lib; /* UCC lib handle */ + ucc_lib_attr_t ucc_lib_attr; /* UCC lib attribute structure */ + ucc_context_h ucc_context; /* UCC context */ + ucc_team_h *ucc_team; /* Array of UCC team members */ + int64_t n_teams; /* Array size */ + long *pSync; /* Pointer to synchronization/work scratch buffers */ + uint64_t psync_offset; /* Synchronization buffer offset */ + void *local_work_buffer; /* Local work buffer */ + size_t len; /* Buffer length */ + ucp_ep_h host; /* The host data endpoint */ +}; + +/* EP map */ +KHASH_MAP_INIT_INT64(ep, ucp_ep_h); +/* Memory handles map */ +KHASH_MAP_INIT_INT64(memh, ucp_mem_h); +/* Remote key map */ +KHASH_MAP_INIT_INT64(rkeys, ucp_rkey_h); + +/* UCP data structure */ +struct ucc_ucp_data { + ucp_context_h ucp_context; /* UCP context */ + ucp_worker_h ucp_worker; /* UCP worker */ + ucp_address_t *worker_address; /* UCP worker address */ + size_t ucp_addrlen; /* UCP worker address length */ + khash_t(ep) * eps; /* EP hashtable map */ + khash_t(memh) * memh; /* Memh hashtable map */ + khash_t(rkeys) * rkeys; /* Rkey hashtable map */ +}; + +/* Context ids map */ +KHASH_MAP_INIT_INT64(ctx_id, uint64_t); + +struct urom_worker_ucc { + struct urom_worker_ctx *super; /* DOCA base worker context */ + struct ucc_data *ucc_data; /* UCC data structure */ + struct ucc_ucp_data ucp_data; /* UCP data structure */ + uint64_t list_lock; /* List lock field */ + ucs_list_link_t completed_reqs; /* List of completed requests */ + struct 
ucc_queue_element **queue; /* Elements queue */ + khash_t(ctx_id) * ids; /* Ids hashtable map */ + uint64_t ctx_id; /* Context id, incremented with every new dest id */ + uint64_t nr_connections; /* Number of connections */ +}; + +/* UCC worker thread args */ +struct ctx_thread_args { + uint64_t notif_type; /* Notification type */ + uint64_t urom_context; /* UROM context */ + int64_t start; /* Start index */ + int64_t stride; /* Number of strides */ + int64_t size; /* The work buffer size */ + int64_t myrank; /* Current thread rank */ + void *base_va; /* Buffer host address */ + size_t len; /* Total buffer length */ + uint64_t dest_id; /* Destination id */ + struct urom_worker_ucc *ucc_worker; /* UCC worker structure */ +}; + +/* UCC collective context structure */ +struct coll_ctx { + union { + int64_t start; /* Collective start for single team */ + int64_t *pids; /* Collective team pids */ + }; + int64_t stride; /* Number of strides */ + int64_t size; /* The work buffer size */ + int64_t index; /* Current collective member index */ + struct urom_worker_ucc *ucc_worker; /* UCC worker context */ +}; + +typedef struct ucc_tl_ucp_allreduce_sw_global_work_buf_info { + void *packed_src_memh; + void *packed_dst_memh; +} ucc_tl_ucp_allreduce_sw_global_work_buf_info_t; + +/* UCC queue element structure */ +struct ucc_queue_element { + enum ucc_worker_queue_element_type type; /* Element type */ + volatile int64_t in_use; /* If element in use */ + volatile int64_t posted; /* If element was posted */ + uint64_t dest_id; /* Element destination id */ + uint64_t ctx_id; /* Element context id */ + uint64_t myrank; /* Element rank */ + pthread_barrier_t *barrier; /* If not null, call this barrier */ + void *old_dest; /* Old element destination */ + size_t data_size; /* Data size */ + ucc_coll_req_h coll_req; /* UCC collective request */ + struct coll_ctx *coll_ctx; /* UCC worker collective context */ + uint64_t team_id; /* Team id */ + void *dest_packed_key; /* Destination data packed key */ + struct urom_worker_notif_desc *nd; /* Element notification descriptor */ + ucc_tl_ucp_allreduce_sw_global_work_buf_info_t *gwbi; /* gwbi ptr */ + ucc_worker_key_buf * key_duplicate_per_rank; /* per-rank copy of keys */ +}; + +/* + * Execute RMA put operation for target buffer + * + * @buffer [in]: target buffer + * @target [in]: pointer to target + * @msglen [in]: message length + * @dest [in]: destination id + * @myrank [in]: current rank in UCC team + * @ctx_id [in]: current context id + * @ucc_worker [in]: UCC worker context + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_rma_put(void *buffer, + void *target, + size_t msglen, + uint64_t dest, + uint64_t myrank, + uint64_t ctx_id, + struct urom_worker_ucc *ucc_worker); + +/* + * Execute RMA get operation on target buffer + * + * @buffer [in]: target buffer + * @target [in]: pointer to target + * @msglen [in]: message length + * @dest [in]: destination id + * @myrank [in]: current rank in UCC team + * @ctx_id [in]: current context id + * @ucc_worker [in]: UCC worker context + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_rma_get(void *buffer, + void *target, + size_t msglen, + uint64_t dest, + uint64_t myrank, + uint64_t ctx_id, + struct urom_worker_ucc *ucc_worker); + +/* + * Execute UCP send operation + * + * @msg [in]: send message + * @len [in]: message length + * @myrank [in]: current rank in UCC team + * @dest [in]: destination id + * @ucc_worker [in]: UCC worker context + * @req 
[out]: request result + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_send_nb(void *msg, + size_t len, + int64_t myrank, + int64_t dest, + struct urom_worker_ucc *ucc_worker, + int *req); + +/* + * Execute UCP recv operation + * + * @msg [in]: recv buffer + * @len [in]: buffer length + * @dest [in]: destination id + * @ucc_worker [in]: UCC worker context + * @req [out]: request result + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_recv_nb(void *msg, size_t len, int64_t dest, struct urom_worker_ucc *ucc_worker, int *req); + +/* + * Execute RMA get host information + * + * @buffer [in]: target buffer + * @target [in]: pointer to target + * @msglen [in]: message length + * @ctx_id [in]: context id + * @packed_key [in]: packed key + * @ucc_worker [in]: UCC worker context + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_rma_get_host(void *buffer, + void *target, + size_t msglen, + uint64_t ctx_id, + void *packed_key, + struct urom_worker_ucc *ucc_worker); + +/* + * Execute RMA put host information + * + * @buffer [in]: target buffer + * @target [in]: pointer to target + * @msglen [in]: message length + * @ctx_id [in]: context id + * @packed_key [in]: packed key + * @ucc_worker [in]: UCC worker context + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_rma_put_host(void *buffer, + void *target, + size_t msglen, + uint64_t ctx_id, + void *packed_key, + struct urom_worker_ucc *ucc_worker); + +/* + * UCP endpoint error handling context + * + * @arg [in]: user argument + * @ep [in]: EP handler + * @ucs_status [in]: UCS status + */ +void urom_ep_err_cb(void *arg, ucp_ep_h ep, ucs_status_t ucs_status); + +/* + * Get DOCA worker plugin interface for UCC plugin. + * DOCA UROM worker will load the urom_plugin_get_iface symbol to get the UCC interface + * + * @iface [out]: Set DOCA UROM plugin interface for UCC + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t urom_plugin_get_iface(struct urom_plugin_iface *iface); + +/* + * Get UCC plugin version, will be used to verify that the host and DPU plugin versions are compatible + * + * @version [out]: Set the UCC worker plugin version + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t urom_plugin_get_version(uint64_t *version); + +#endif /* WORKER_UCC_H_ */ diff --git a/contrib/doca_urom_ucc_plugin/dpu/worker_ucc_p2p.c b/contrib/doca_urom_ucc_plugin/dpu/worker_ucc_p2p.c new file mode 100644 index 0000000000..c4cb38d35c --- /dev/null +++ b/contrib/doca_urom_ucc_plugin/dpu/worker_ucc_p2p.c @@ -0,0 +1,524 @@ +/* + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED. + * + * This software product is a proprietary product of NVIDIA CORPORATION & + * AFFILIATES (the "Company") and all right, title, and interest in and to the + * software product, including all associated intellectual property rights, are + * and shall remain exclusively with the Company. + * + * This software product is governed by the End User License Agreement + * provided with the software product. 
+ * + */ + +#include +#include + +#include "worker_ucc.h" +#include "../common/urom_ucc.h" + +DOCA_LOG_REGISTER(UCC::DOCA_CL : WORKER_UCC_P2P); + +void urom_ep_err_cb(void *arg, ucp_ep_h ep, ucs_status_t ucs_status) +{ + (void)arg; + (void)ep; + + DOCA_LOG_ERR("Endpoint error detected, status: %s", ucs_status_string(ucs_status)); +} + +/* + * UCC worker EP lookup function + * + * @ucc_worker [in]: UCC worker context + * @dest [in]: destination id + * @ep [out]: set UCP endpoint + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t worker_ucc_ep_lookup(struct urom_worker_ucc *ucc_worker, uint64_t dest, ucp_ep_h *ep) +{ + int ret; + khint_t k; + void *addr; + ucp_ep_h new_ep; + doca_error_t status; + ucs_status_t ucs_status; + ucp_ep_params_t ep_params; + + k = kh_get(ep, ucc_worker->ucp_data.eps, dest); + if (k != kh_end(ucc_worker->ucp_data.eps)) { + *ep = kh_value(ucc_worker->ucp_data.eps, k); + return DOCA_SUCCESS; + } + + /* Create new EP */ + status = doca_urom_worker_domain_addr_lookup(ucc_worker->super, dest, &addr); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Id not found in domain:: %#lx", dest); + return DOCA_ERROR_NOT_FOUND; + } + + ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | UCP_EP_PARAM_FIELD_ERR_HANDLER | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE; + ep_params.err_handler.cb = urom_ep_err_cb; + ep_params.err_handler.arg = NULL; + ep_params.err_mode = UCP_ERR_HANDLING_MODE_PEER; + ep_params.address = addr; + + ucs_status = ucp_ep_create(ucc_worker->ucp_data.ucp_worker, &ep_params, &new_ep); + if (ucs_status != UCS_OK) { + DOCA_LOG_ERR("ucp_ep_create() returned: %s", ucs_status_string(ucs_status)); + return DOCA_ERROR_INITIALIZATION; + } + + k = kh_put(ep, ucc_worker->ucp_data.eps, dest, &ret); + if (ret <= 0) + return DOCA_ERROR_DRIVER; + kh_value(ucc_worker->ucp_data.eps, k) = new_ep; + + *ep = new_ep; + DOCA_LOG_DBG("Created EP for dest: %#lx", dest); + return DOCA_SUCCESS; +} + +/* + * UCC worker memh lookup function + * + * @ucc_worker [in]: UCC worker context + * @dest [in]: destination id + * @memh [out]: set memory handle + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t worker_ucc_memh_lookup(struct urom_worker_ucc *ucc_worker, uint64_t dest, ucp_mem_h *memh) +{ + int ret; + khint_t k; + void *mem_handle; + ucp_mem_h memh_id; + doca_error_t status; + size_t memh_len = 0; + ucs_status_t ucs_status; + ucp_mem_map_params_t mmap_params = {0}; + + k = kh_get(memh, ucc_worker->ucp_data.memh, dest); + if (k != kh_end(ucc_worker->ucp_data.memh)) { + *memh = kh_value(ucc_worker->ucp_data.memh, k); + return DOCA_SUCCESS; + } + + /* Lookup memory handle */ + status = doca_urom_worker_domain_memh_lookup(ucc_worker->super, dest, 0, &memh_len, &mem_handle); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Id not found in domain:: %#lx", dest); + return DOCA_ERROR_NOT_FOUND; + } + + mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER; + mmap_params.exported_memh_buffer = mem_handle; + + ucs_status = ucp_mem_map(ucc_worker->ucp_data.ucp_context, &mmap_params, &memh_id); + if (ucs_status != UCS_OK) { + DOCA_LOG_ERR("Failed to map packed memh %p", mem_handle); + return DOCA_ERROR_NOT_FOUND; + } + + k = kh_put(memh, ucc_worker->ucp_data.memh, dest, &ret); + if (ret <= 0) { + DOCA_LOG_ERR("Failed to add memh to hashtable map"); + if (ucp_mem_unmap(ucc_worker->ucp_data.ucp_context, memh_id) != UCS_OK) + DOCA_LOG_ERR("Failed to unmap memh"); + return DOCA_ERROR_DRIVER; + } + 
kh_value(ucc_worker->ucp_data.memh, k) = memh_id; + + *memh = memh_id; + DOCA_LOG_DBG("Assigned memh %p for dest: %#lx", memh_id, dest); + return DOCA_SUCCESS; +} + +/* + * UCC worker memory key lookup function + * + * @ucc_worker [in]: UCC worker context + * @dest [in]: destination id + * @ep [in]: destination endpoint + * @va [in]: memory host address + * @ret_rkey [out]: set remote memory key + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t worker_ucc_key_lookup(struct urom_worker_ucc *ucc_worker, + uint64_t dest, + ucp_ep_h ep, + uint64_t va, + void **ret_rkey) +{ + khint_t k; + int ret; + void *packed_key; + size_t packed_key_len; + ucp_rkey_h rkey; + doca_error_t status; + ucs_status_t ucs_status; + int seg; + + k = kh_get(rkeys, ucc_worker->ucp_data.rkeys, dest); + if (k != kh_end(ucc_worker->ucp_data.rkeys)) { + *ret_rkey = kh_value(ucc_worker->ucp_data.rkeys, k); + return DOCA_SUCCESS; + } + + status = doca_urom_worker_domain_seg_lookup(ucc_worker->super, dest, va, &seg); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Id not found in domain: %#lx", dest); + return DOCA_ERROR_NOT_FOUND; + } + + status = doca_urom_worker_domain_mkey_lookup(ucc_worker->super, dest, seg, &packed_key_len, &packed_key); + if (status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Id not found in domain: %#lx", dest); + return DOCA_ERROR_NOT_FOUND; + } + + ucs_status = ucp_ep_rkey_unpack(ep, packed_key, &rkey); + if (ucs_status != UCS_OK) + return DOCA_ERROR_NOT_FOUND; + + k = kh_put(rkeys, ucc_worker->ucp_data.rkeys, dest, &ret); + if (ret <= 0) + return DOCA_ERROR_DRIVER; + kh_value(ucc_worker->ucp_data.rkeys, k) = rkey; + + *ret_rkey = rkey; + DOCA_LOG_DBG("Assigned rkey for dest: %#lx", dest); + return DOCA_SUCCESS; +} + +/* + * UCC send tag completion callback + * + * @request [in]: UCP send request + * @status [in]: send task status + * @user_data [in]: UCC data + */ +static void send_completion_cb(void *request, ucs_status_t status, void *user_data) +{ + int *req = (int *)user_data; + + if (status != UCS_OK) + *req = -1; + else + *req = 1; + + ucp_request_free(request); +} + +/* + * UCC recv tag completion callback + * + * @request [in]: UCP recv request + * @status [in]: recv task status + * @info [in]: recv task info + * @user_data [in]: UCC data + */ +static void recv_completion_cb(void *request, ucs_status_t status, const ucp_tag_recv_info_t *info, void *user_data) +{ + (void)info; + int *req = (int *)user_data; + + if (status != UCS_OK) + *req = -1; + else + *req = 1; + + ucp_request_free(request); +} + +doca_error_t ucc_send_nb(void *msg, + size_t len, + int64_t myrank, + int64_t dest, + struct urom_worker_ucc *ucc_worker, + int *req) +{ + ucp_ep_h ep = NULL; + doca_error_t urom_status; + ucs_status_ptr_t ucp_status; + ucp_request_param_t req_param = {0}; + + *req = 0; + req_param.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA; + req_param.datatype = ucp_dt_make_contig(len); + req_param.cb.send = send_completion_cb; + req_param.user_data = (void *)req; + + urom_status = worker_ucc_ep_lookup(ucc_worker, dest, &ep); + if (urom_status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to send to %ld in UCC oob", dest); + return DOCA_ERROR_NOT_FOUND; + } + + /* Process tag send */ + ucp_status = ucp_tag_send_nbx(ep, msg, 1, myrank, &req_param); + if (ucp_status != UCS_OK) { + if (UCS_PTR_IS_ERR(ucp_status)) { + ucp_request_cancel(ucc_worker->ucp_data.ucp_worker, ucp_status); + ucp_request_free(ucp_status); + return 
DOCA_ERROR_NOT_FOUND; + } + } else + *req = 1; + + return DOCA_SUCCESS; +} + +doca_error_t ucc_recv_nb(void *msg, size_t len, int64_t dest, struct urom_worker_ucc *ucc_worker, int *req) +{ + ucs_status_ptr_t ucp_status; + ucp_request_param_t req_param = {}; + + *req = 0; + req_param.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA; + req_param.datatype = ucp_dt_make_contig(len); + req_param.cb.recv = recv_completion_cb; + req_param.user_data = (void *)req; + + /* Process tag recv */ + ucp_status = ucp_tag_recv_nbx(ucc_worker->ucp_data.ucp_worker, msg, 1, dest, 0xffff, &req_param); + if (ucp_status != UCS_OK) { + if (UCS_PTR_IS_ERR(ucp_status)) { + ucp_request_cancel(ucc_worker->ucp_data.ucp_worker, ucp_status); + ucp_request_free(ucp_status); + return DOCA_ERROR_NOT_FOUND; + } + } else + *req = 1; + return DOCA_SUCCESS; +} + +doca_error_t ucc_rma_put(void *buffer, + void *target, + size_t msglen, + uint64_t dest, + uint64_t myrank, + uint64_t ctx_id, + struct urom_worker_ucc *ucc_worker) +{ + ucp_ep_h ep; + ucp_rkey_h rkey; + ucp_mem_h memh = NULL; + doca_error_t urom_status; + ucs_status_ptr_t ucp_status; + uint64_t rva = (uint64_t)target; + ucp_request_param_t req_param = {0}; + + if (dest == MAX_HOST_DEST_ID) + ep = ucc_worker->ucc_data[ctx_id].host; + else { + urom_status = worker_ucc_ep_lookup(ucc_worker, dest, &ep); + if (urom_status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to find peer %ld to complete collective", dest); + return DOCA_ERROR_NOT_FOUND; + } + } + + if (dest != MAX_HOST_DEST_ID) { + urom_status = worker_ucc_memh_lookup(ucc_worker, dest, &memh); + if (urom_status != DOCA_SUCCESS) + DOCA_LOG_ERR("Failed to lookup key for peer %ld", dest); + } + + if (dest == MAX_HOST_DEST_ID) { + urom_status = worker_ucc_key_lookup(ucc_worker, myrank, ep, rva, (void **)&rkey); + if (urom_status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to lookup rkey for peer %ld", dest); + return DOCA_ERROR_NOT_FOUND; + } + } else { + urom_status = worker_ucc_key_lookup(ucc_worker, dest, ep, rva, (void **)&rkey); + if (urom_status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to lookup rkey for peer %ld", dest); + return DOCA_ERROR_NOT_FOUND; + } + } + + if (memh != NULL) { + req_param.op_attr_mask = UCP_OP_ATTR_FIELD_MEMH; + req_param.memh = memh; + } + + ucp_status = ucp_put_nbx(ep, buffer, msglen, rva, rkey, &req_param); + if (ucp_status != UCS_OK) { + if (UCS_PTR_IS_ERR(ucp_status)) { + ucp_request_free(ucp_status); + return DOCA_ERROR_NOT_FOUND; + } + while (ucp_request_check_status(ucp_status) == UCS_INPROGRESS) + ucp_worker_progress(ucc_worker->ucp_data.ucp_worker); + ucp_request_free(ucp_status); + } + return DOCA_SUCCESS; +} + +doca_error_t ucc_rma_get(void *buffer, + void *target, + size_t msglen, + uint64_t dest, + uint64_t myrank, + uint64_t ctx_id, + struct urom_worker_ucc *ucc_worker) +{ + ucp_ep_h ep = NULL; + ucp_mem_h memh = NULL; + ucp_rkey_h rkey = NULL; + doca_error_t urom_status; + ucs_status_ptr_t ucp_status; + uint64_t rva = (uint64_t)target; + ucp_request_param_t req_param = {0}; + + if (dest == MAX_HOST_DEST_ID) + ep = ucc_worker->ucc_data[ctx_id].host; + else { + urom_status = worker_ucc_ep_lookup(ucc_worker, dest, &ep); + if (urom_status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to find peer %ld to complete collective", dest); + return DOCA_ERROR_NOT_FOUND; + } + } + + if (dest != MAX_HOST_DEST_ID) { + urom_status = worker_ucc_memh_lookup(ucc_worker, dest, &memh); + if (urom_status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to 
lookup key for peer %ld", dest); + return DOCA_ERROR_NOT_FOUND; + } + } + + if (dest == MAX_HOST_DEST_ID) { + urom_status = worker_ucc_key_lookup(ucc_worker, myrank, ep, rva, (void **)&rkey); + if (urom_status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to lookup rkey for peer %ld", dest); + return DOCA_ERROR_NOT_FOUND; + } + } else { + urom_status = worker_ucc_key_lookup(ucc_worker, dest, ep, rva, (void **)&rkey); + if (urom_status != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to lookup rkey for peer %ld", dest); + return DOCA_ERROR_NOT_FOUND; + } + } + + if (memh != NULL) { + req_param.op_attr_mask = UCP_OP_ATTR_FIELD_MEMH; + req_param.memh = memh; + } + + ucp_status = ucp_get_nbx(ep, buffer, msglen, rva, rkey, &req_param); + if (ucp_status != UCS_OK) { + if (UCS_PTR_IS_ERR(ucp_status)) { + ucp_request_free(ucp_status); + ucp_rkey_destroy(rkey); + return DOCA_ERROR_NOT_FOUND; + } + while (ucp_request_check_status(ucp_status) == UCS_INPROGRESS) + ucp_worker_progress(ucc_worker->ucp_data.ucp_worker); + ucp_request_free(ucp_status); + } + ucp_rkey_destroy(rkey); + return DOCA_SUCCESS; +} + +doca_error_t ucc_rma_get_host(void *buffer, + void *target, + size_t msglen, + uint64_t ctx_id, + void *packed_key, + struct urom_worker_ucc *ucc_worker) +{ + ucp_ep_h ep = NULL; + ucp_rkey_h rkey = NULL; + ucs_status_t ucs_status; + ucs_status_ptr_t ucp_status; + uint64_t rva = (uint64_t)target; + ucp_request_param_t req_param = {0}; + + if (packed_key == NULL) + return DOCA_ERROR_INVALID_VALUE; + + ep = ucc_worker->ucc_data[ctx_id].host; + + ucs_status = ucp_ep_rkey_unpack(ep, packed_key, &rkey); + if (ucs_status != UCS_OK) { + DOCA_LOG_ERR("Failed to unpack rkey"); + return DOCA_ERROR_NOT_FOUND; + } + + ucp_status = ucp_get_nbx(ep, buffer, msglen, rva, rkey, &req_param); + if (UCS_OK != ucp_status) { + if (UCS_PTR_IS_ERR(ucp_status)) { + DOCA_LOG_ERR("Failed to perform ucp_get_nbx(): %s\n", ucs_status_string(UCS_PTR_STATUS(ucp_status))); + ucp_request_free(ucp_status); + ucp_rkey_destroy(rkey); + return DOCA_ERROR_NOT_FOUND; + } + while (UCS_INPROGRESS == ucp_request_check_status(ucp_status)) + ucp_worker_progress(ucc_worker->ucp_data.ucp_worker); + + if (UCS_PTR_IS_ERR(ucp_status)) { + DOCA_LOG_ERR("Failed to perform ucp_get_nbx(): %s\n", ucs_status_string(UCS_PTR_STATUS(ucp_status))); + ucp_request_free(ucp_status); + ucp_rkey_destroy(rkey); + return DOCA_ERROR_NOT_FOUND; + } + ucp_request_free(ucp_status); + } + ucp_rkey_destroy(rkey); + return DOCA_SUCCESS; +} + +doca_error_t ucc_rma_put_host(void *buffer, + void *target, + size_t msglen, + uint64_t ctx_id, + void *packed_key, + struct urom_worker_ucc *ucc_worker) +{ + ucp_ep_h ep = NULL; + ucp_rkey_h rkey = NULL; + ucs_status_t ucs_status; + ucs_status_ptr_t ucp_status; + uint64_t rva = (uint64_t)target; + ucp_request_param_t req_param = {0}; + + if (packed_key == NULL) + return DOCA_ERROR_INVALID_VALUE; + + ep = ucc_worker->ucc_data[ctx_id].host; + + ucs_status = ucp_ep_rkey_unpack(ep, packed_key, &rkey); + if (ucs_status != UCS_OK) { + DOCA_LOG_ERR("Failed to unpack rkey"); + return DOCA_ERROR_NOT_FOUND; + } + + ucp_status = ucp_put_nbx(ep, buffer, msglen, rva, rkey, &req_param); + if (UCS_OK != ucp_status) { + if (UCS_PTR_IS_ERR(ucp_status)) { + DOCA_LOG_ERR("Failed to perform ucp_put_nbx(): %s\n", ucs_status_string(UCS_PTR_STATUS(ucp_status))); + ucp_request_free(ucp_status); + ucp_rkey_destroy(rkey); + return DOCA_ERROR_NOT_FOUND; + } + while (UCS_INPROGRESS == ucp_request_check_status(ucp_status)) { + 
ucp_worker_progress(ucc_worker->ucp_data.ucp_worker); + } + if (UCS_PTR_IS_ERR(ucp_status)) { + DOCA_LOG_ERR("Failed to perform ucp_put_nbx(): %s\n", ucs_status_string(UCS_PTR_STATUS(ucp_status))); + ucp_request_free(ucp_status); + ucp_rkey_destroy(rkey); + return DOCA_ERROR_NOT_FOUND; + } + ucp_request_free(ucp_status); + } + ucp_rkey_destroy(rkey); + return DOCA_SUCCESS; +} diff --git a/src/Makefile.am b/src/Makefile.am index c505c31344..c6c9cb01db 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -14,6 +14,10 @@ mc_dirs += components/mc/cuda ec_dirs += components/ec/cuda endif +if HAVE_DOCA_UROM +cl_dirs += components/cl/doca_urom +endif + if HAVE_ROCM mc_dirs += components/mc/rocm ec_dirs += components/ec/rocm diff --git a/src/components/cl/doca_urom/Makefile.am b/src/components/cl/doca_urom/Makefile.am new file mode 100644 index 0000000000..a7f5589ba1 --- /dev/null +++ b/src/components/cl/doca_urom/Makefile.am @@ -0,0 +1,24 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# + +sources = \ + cl_doca_urom.h \ + cl_doca_urom.c \ + cl_doca_urom_lib.c \ + cl_doca_urom_context.c \ + cl_doca_urom_team.c \ + cl_doca_urom_common.c \ + cl_doca_urom_common.h \ + cl_doca_urom_worker_ucc.c \ + cl_doca_urom_worker_ucc.h \ + cl_doca_urom_coll.c + +module_LTLIBRARIES = libucc_cl_doca_urom.la +libucc_cl_doca_urom_la_SOURCES = $(sources) +libucc_cl_doca_urom_la_CPPFLAGS = $(AM_CPPFLAGS) $(BASE_CPPFLAGS) $(DOCA_UROM_CPPFLAGS) $(DOCA_UROM_UCC_CPPFLAGS) -I$(top_srcdir)/contrib/doca_urom_ucc_plugin/common +libucc_cl_doca_urom_la_CFLAGS = $(BASE_CFLAGS) +libucc_cl_doca_urom_la_LDFLAGS = -version-info $(SOVERSION) --as-needed $(DOCA_UROM_UCC_LDFLAGS) $(DOCA_UROM_LDFLAGS) +libucc_cl_doca_urom_la_LIBADD = $(DOCA_UROM_UCC_LIBADD) $(DOCA_UROM_LIBADD) $(UCC_TOP_BUILDDIR)/src/libucc.la + +include $(top_srcdir)/config/module.am diff --git a/src/components/cl/doca_urom/cl_doca_urom.c b/src/components/cl/doca_urom/cl_doca_urom.c new file mode 100644 index 0000000000..ae09833074 --- /dev/null +++ b/src/components/cl/doca_urom/cl_doca_urom.c @@ -0,0 +1,76 @@ +/** + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#include "cl_doca_urom.h" +#include "utils/ucc_malloc.h" + +ucc_status_t ucc_cl_doca_urom_get_lib_attr(const ucc_base_lib_t *lib, + ucc_base_lib_attr_t *base_attr); + +ucc_status_t ucc_cl_doca_urom_get_context_attr(const ucc_base_context_t *context, + ucc_base_ctx_attr_t *base_attr); + +ucc_status_t ucc_cl_doca_urom_get_lib_properties(ucc_base_lib_properties_t *prop); + +static ucc_config_field_t ucc_cl_doca_urom_lib_config_table[] = { + {"", "", NULL, ucc_offsetof(ucc_cl_doca_urom_lib_config_t, super), + UCC_CONFIG_TYPE_TABLE(ucc_cl_lib_config_table)}, + + {NULL}}; + +static ucs_config_field_t ucc_cl_doca_urom_context_config_table[] = { + {"", "", NULL, ucc_offsetof(ucc_cl_doca_urom_context_config_t, super), + UCC_CONFIG_TYPE_TABLE(ucc_cl_context_config_table)}, + + {"PLUGIN_ENVS", "", + "Comma separated envs to pass to the worker plugin", + ucc_offsetof(ucc_cl_doca_urom_context_config_t, plugin_envs), + UCC_CONFIG_TYPE_STRING_ARRAY}, + + {"DEVICE", "mlx5_0", + "DPU device", + ucc_offsetof(ucc_cl_doca_urom_context_config_t, device), + UCC_CONFIG_TYPE_STRING}, + + {"PLUGIN_NAME", "libucc_doca_urom_plugin", + "Name of plugin library", + ucc_offsetof(ucc_cl_doca_urom_context_config_t, plugin_name), + UCC_CONFIG_TYPE_STRING}, + + {"DOCA_LOG_LEVEL", "10", + "DOCA log level", + ucc_offsetof(ucc_cl_doca_urom_context_config_t, doca_log_level), + UCC_CONFIG_TYPE_INT}, + + {NULL}}; + +UCC_CLASS_DEFINE_NEW_FUNC(ucc_cl_doca_urom_lib_t, ucc_base_lib_t, + const ucc_base_lib_params_t *, + const ucc_base_config_t *); + +UCC_CLASS_DEFINE_DELETE_FUNC(ucc_cl_doca_urom_lib_t, ucc_base_lib_t); + +UCC_CLASS_DEFINE_NEW_FUNC(ucc_cl_doca_urom_context_t, ucc_base_context_t, + const ucc_base_context_params_t *, + const ucc_base_config_t *); + +UCC_CLASS_DEFINE_DELETE_FUNC(ucc_cl_doca_urom_context_t, ucc_base_context_t); + +UCC_CLASS_DEFINE_NEW_FUNC(ucc_cl_doca_urom_team_t, ucc_base_team_t, + ucc_base_context_t *, const ucc_base_team_params_t *); + +ucc_status_t ucc_cl_doca_urom_team_create_test(ucc_base_team_t *cl_team); + +ucc_status_t ucc_cl_doca_urom_team_destroy(ucc_base_team_t *cl_team); + +ucc_status_t ucc_cl_doca_urom_coll_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task); + +ucc_status_t ucc_cl_doca_urom_team_get_scores(ucc_base_team_t *cl_team, + ucc_coll_score_t **score); + +UCC_CL_IFACE_DECLARE(doca_urom, DOCA_UROM); diff --git a/src/components/cl/doca_urom/cl_doca_urom.h b/src/components/cl/doca_urom/cl_doca_urom.h new file mode 100644 index 0000000000..ac589fd775 --- /dev/null +++ b/src/components/cl/doca_urom/cl_doca_urom.h @@ -0,0 +1,104 @@ +/** + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#ifndef UCC_CL_DOCA_UROM_H_ +#define UCC_CL_DOCA_UROM_H_ + +#include "components/cl/ucc_cl.h" +#include "components/cl/ucc_cl_log.h" +#include "components/tl/ucc_tl.h" +#include "coll_score/ucc_coll_score.h" +#include "utils/ucc_mpool.h" + +#include +#include +#include +#include +#include + +#include "cl_doca_urom_common.h" +#include "cl_doca_urom_worker_ucc.h" + +#include + +#ifndef UCC_CL_DOCA_UROM_DEFAULT_SCORE +#define UCC_CL_DOCA_UROM_DEFAULT_SCORE 100 +#endif + +#define UCC_CL_DOCA_UROM_ADDR_MAX_LEN 1024 +#define UCC_CL_DOCA_UROM_MAX_TEAMS 16 + +typedef struct ucc_cl_doca_urom_iface { + ucc_cl_iface_t super; +} ucc_cl_doca_urom_iface_t; + +// Extern iface should follow the pattern: ucc_cl_ +extern ucc_cl_doca_urom_iface_t ucc_cl_doca_urom; + +typedef struct ucc_cl_doca_urom_lib_config { + ucc_cl_lib_config_t super; +} ucc_cl_doca_urom_lib_config_t; + +typedef struct ucc_cl_doca_urom_context_config { + ucc_cl_context_config_t super; + ucs_config_names_array_t plugin_envs; + char *device; + char *plugin_name; + int doca_log_level; +} ucc_cl_doca_urom_context_config_t; + +typedef struct ucc_cl_doca_urom_ctx { + struct doca_urom_service *urom_service; + struct doca_urom_worker *urom_worker; + struct doca_urom_domain *urom_domain; + struct doca_pe *urom_pe; + const struct doca_urom_service_plugin_info *ucc_info; + void *urom_worker_addr; + size_t urom_worker_len; + uint64_t worker_id; + void *urom_ucc_context; + ucc_rank_t ctx_rank; + struct doca_dev *dev; +} ucc_cl_doca_urom_ctx_t; + +typedef struct ucc_cl_doca_urom_lib { + ucc_cl_lib_t super; + ucc_cl_doca_urom_lib_config_t cfg; + int tl_ucp_index; +} ucc_cl_doca_urom_lib_t; +UCC_CLASS_DECLARE(ucc_cl_doca_urom_lib_t, const ucc_base_lib_params_t *, + const ucc_base_config_t *); + +typedef struct ucc_cl_doca_urom_context { + ucc_cl_context_t super; + void *urom_ucc_ctx_h; + ucc_mpool_t sched_mp; + ucp_context_h ucp_context; + ucc_cl_doca_urom_ctx_t urom_ctx; + ucc_cl_doca_urom_context_config_t cfg; +} ucc_cl_doca_urom_context_t; +UCC_CLASS_DECLARE(ucc_cl_doca_urom_context_t, const ucc_base_context_params_t *, + const ucc_base_config_t *); + +typedef struct ucc_cl_doca_urom_team { + ucc_cl_team_t super; + ucc_team_h **teams; + unsigned n_teams; + ucc_coll_score_t *score; + ucc_score_map_t *score_map; + struct ucc_cl_doca_urom_result res; // used for the cookie +} ucc_cl_doca_urom_team_t; +UCC_CLASS_DECLARE(ucc_cl_doca_urom_team_t, ucc_base_context_t *, + const ucc_base_team_params_t *); + +ucc_status_t ucc_cl_doca_urom_coll_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task); + +#define UCC_CL_DOCA_UROM_TEAM_CTX(_team) \ + (ucc_derived_of((_team)->super.super.context, ucc_cl_doca_urom_context_t)) + +#endif diff --git a/src/components/cl/doca_urom/cl_doca_urom_coll.c b/src/components/cl/doca_urom/cl_doca_urom_coll.c new file mode 100644 index 0000000000..77d85adbc3 --- /dev/null +++ b/src/components/cl/doca_urom/cl_doca_urom_coll.c @@ -0,0 +1,262 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#include "cl_doca_urom.h" +#include "cl_doca_urom_coll.h" +#include "utils/ucc_coll_utils.h" + +#include +#include + +static ucc_status_t ucc_cl_doca_urom_triggered_post_setup(ucc_coll_task_t *task) +{ + return UCC_OK; +} + +static ucc_status_t ucc_cl_doca_urom_coll_full_start(ucc_coll_task_t *task) +{ + ucc_cl_doca_urom_team_t *cl_team = ucc_derived_of(task->team, + ucc_cl_doca_urom_team_t); + ucc_cl_doca_urom_context_t *ctx = UCC_CL_DOCA_UROM_TEAM_CTX(cl_team); + ucc_cl_doca_urom_lib_t *cl_lib = ucc_derived_of(ctx->super.super.lib, + ucc_cl_doca_urom_lib_t); + ucc_coll_args_t *coll_args = &task->bargs.args; + int ucp_index = cl_lib->tl_ucp_index; + ucc_tl_ucp_context_t *tl_ctx = ucc_derived_of( + ctx->super.tl_ctxs[ucp_index], + ucc_tl_ucp_context_t); + union doca_data cookie = {0}; + int use_xgvmi = 0; + int in_place = 0; + ucc_rank_t rank = UCC_CL_TEAM_RANK(cl_team); + ucc_cl_doca_urom_schedule_t *schedule = ucc_derived_of(task, + ucc_cl_doca_urom_schedule_t); + struct export_buf *src_ebuf = &schedule->src_ebuf; + struct export_buf *dst_ebuf = &schedule->dst_ebuf; + doca_error_t result; + ucc_worker_key_buf keys; + + src_ebuf->memh = NULL; + dst_ebuf->memh = NULL; + + cookie.ptr = &schedule->res; + + if ( (coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS ) && + (coll_args->flags & UCC_COLL_ARGS_FLAG_IN_PLACE) ) { + in_place = 1; + } + + if (!in_place) { + ucc_cl_doca_urom_buffer_export_ucc( + tl_ctx->worker.ucp_context, + coll_args->src.info.buffer, + coll_args->src.info.count * + ucc_dt_size(coll_args->src.info.datatype), + src_ebuf); + } + + ucc_cl_doca_urom_buffer_export_ucc( + tl_ctx->worker.ucp_context, + coll_args->dst.info.buffer, + coll_args->dst.info.count * + ucc_dt_size(coll_args->dst.info.datatype), + dst_ebuf); + + switch (coll_args->coll_type) { + case UCC_COLL_TYPE_ALLTOALL: + { + if (!in_place) { + keys.src_len = src_ebuf->packed_key_len; + memcpy(keys.rkeys, src_ebuf->packed_key, keys.src_len); + } else { + keys.src_len = 0; + } + keys.dst_len = dst_ebuf->packed_key_len; + memcpy(keys.rkeys + keys.src_len, + dst_ebuf->packed_key, keys.dst_len); + use_xgvmi = 0; + } break; + case UCC_COLL_TYPE_ALLREDUCE: + { + if (!in_place) { + keys.src_len = src_ebuf->packed_memh_len; + memcpy(keys.rkeys, src_ebuf->packed_memh, keys.src_len); + } else { + keys.src_len = 0; + } + keys.dst_len = dst_ebuf->packed_memh_len; + memcpy(keys.rkeys + keys.src_len, + dst_ebuf->packed_memh, keys.dst_len); + use_xgvmi = 1; + } break; + case UCC_COLL_TYPE_ALLGATHER: + { + if (!in_place) { + keys.src_len = src_ebuf->packed_memh_len; + memcpy(keys.rkeys, src_ebuf->packed_memh, keys.src_len); + } else { + keys.src_len = 0; + } + keys.dst_len = dst_ebuf->packed_memh_len; + memcpy(keys.rkeys + keys.src_len, + dst_ebuf->packed_memh, + keys.dst_len); + use_xgvmi = 1; + } break; + default: + cl_error(&cl_lib->super, "coll_type %s is not supported", + ucc_coll_type_str(coll_args->coll_type)); + } + + coll_args->mask |= UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER; + + result = ucc_cl_doca_urom_task_collective(ctx->urom_ctx.urom_worker, + cookie, + rank, + coll_args, + cl_team->teams[0], + use_xgvmi, + &keys, + sizeof(ucc_worker_key_buf), + 0, + ucc_cl_doca_urom_collective_finished); + if (result != DOCA_SUCCESS) { + cl_error(&cl_lib->super, "Failed to create UCC collective task"); + } + + task->status = UCC_INPROGRESS; + + cl_debug(&cl_lib->super, "pushed the collective to urom"); + return ucc_progress_queue_enqueue(ctx->super.super.ucc_context->pq, task); +} + +static ucc_status_t 
ucc_cl_doca_urom_coll_full_finalize(ucc_coll_task_t *task) +{ + ucc_cl_doca_urom_schedule_t *schedule = ucc_derived_of(task, + ucc_cl_doca_urom_schedule_t); + ucc_cl_doca_urom_team_t *cl_team = ucc_derived_of(task->team, + ucc_cl_doca_urom_team_t); + ucc_cl_doca_urom_context_t *ctx = UCC_CL_DOCA_UROM_TEAM_CTX(cl_team); + ucc_cl_doca_urom_lib_t *cl_lib = ucc_derived_of(ctx->super.super.lib, + ucc_cl_doca_urom_lib_t); + int ucp_index = cl_lib->tl_ucp_index; + ucc_tl_ucp_context_t *tl_ctx = ucc_derived_of( + ctx->super.tl_ctxs[ucp_index], + ucc_tl_ucp_context_t); + struct export_buf *src_ebuf = &schedule->src_ebuf; + struct export_buf *dst_ebuf = &schedule->dst_ebuf; + ucc_status_t status; + + if (src_ebuf->memh) { + ucp_mem_unmap(tl_ctx->worker.ucp_context, src_ebuf->memh); + } + ucp_mem_unmap(tl_ctx->worker.ucp_context, dst_ebuf->memh); + + status = ucc_schedule_finalize(task); + ucc_cl_doca_urom_put_schedule(&schedule->super.super); + + return status; +} + +static void ucc_cl_doca_urom_coll_full_progress(ucc_coll_task_t *ctask) +{ + ucc_cl_doca_urom_team_t *cl_team = ucc_derived_of(ctask->team, + ucc_cl_doca_urom_team_t); + ucc_cl_doca_urom_context_t *ctx = UCC_CL_DOCA_UROM_TEAM_CTX(cl_team); + ucc_cl_doca_urom_lib_t *cl_lib = ucc_derived_of( + ctx->super.super.lib, + ucc_cl_doca_urom_lib_t); + ucc_cl_doca_urom_schedule_t *schedule = ucc_derived_of(ctask, + ucc_cl_doca_urom_schedule_t); + int ucp_index = cl_lib->tl_ucp_index; + ucc_tl_ucp_context_t *tl_ctx = ucc_derived_of( + ctx->super.tl_ctxs[ucp_index], + ucc_tl_ucp_context_t); + struct ucc_cl_doca_urom_result *res = &schedule->res; + int ret; + + if (res == NULL) { + cl_error(cl_lib, "Error in UROM"); + ctask->status = UCC_ERR_NO_MESSAGE; + return; + } + + ucp_worker_progress(tl_ctx->worker.ucp_worker); + + ret = doca_pe_progress(ctx->urom_ctx.urom_pe); + if (ret == 0 && res->result == DOCA_SUCCESS) { + ctask->status = UCC_INPROGRESS; + return; + } + + if (res->result != DOCA_SUCCESS) { + cl_error(&cl_lib->super, "Error in DOCA_UROM, UCC collective task failed"); + } + + ctask->status = res->collective.status; + cl_debug(&cl_lib->super, "completed the collective from urom"); +} + +ucc_status_t ucc_cl_doca_urom_coll_full_init( + ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, + ucc_coll_task_t **task) +{ + ucc_cl_doca_urom_team_t *cl_team = ucc_derived_of(team, + ucc_cl_doca_urom_team_t); + ucc_cl_doca_urom_context_t *ctx = UCC_CL_DOCA_UROM_TEAM_CTX(cl_team); + ucc_cl_doca_urom_lib_t *cl_lib = ucc_derived_of(ctx->super.super.lib, + ucc_cl_doca_urom_lib_t); + ucc_status_t status; + ucc_cl_doca_urom_schedule_t *cl_schedule; + ucc_base_coll_args_t args; + ucc_schedule_t *schedule; + + cl_schedule = ucc_cl_doca_urom_get_schedule(cl_team); + if (ucc_unlikely(!cl_schedule)) { + return UCC_ERR_NO_MEMORY; + } + schedule = &cl_schedule->super.super; + memcpy(&args, coll_args, sizeof(args)); + status = ucc_schedule_init(schedule, &args, team); + if (UCC_OK != status) { + ucc_cl_doca_urom_put_schedule(schedule); + return status; + } + + schedule->super.post = ucc_cl_doca_urom_coll_full_start; + schedule->super.progress = ucc_cl_doca_urom_coll_full_progress; + schedule->super.finalize = ucc_cl_doca_urom_coll_full_finalize; + schedule->super.triggered_post = ucc_triggered_post; + schedule->super.triggered_post_setup = ucc_cl_doca_urom_triggered_post_setup; + + *task = &schedule->super; + cl_debug(cl_lib, "cl doca urom coll initialized"); + return UCC_OK; +} + +ucc_status_t ucc_cl_doca_urom_coll_init(ucc_base_coll_args_t *coll_args, + 
ucc_base_team_t *team, + ucc_coll_task_t **task) +{ + ucc_cl_doca_urom_team_t *cl_team = ucc_derived_of(team, + ucc_cl_doca_urom_team_t); + ucc_cl_doca_urom_context_t *ctx = UCC_CL_DOCA_UROM_TEAM_CTX(cl_team); + ucc_cl_doca_urom_lib_t *doca_urom_lib = ucc_derived_of( + ctx->super.super.lib, + ucc_cl_doca_urom_lib_t); + + switch (coll_args->args.coll_type) { + case UCC_COLL_TYPE_ALLREDUCE: + case UCC_COLL_TYPE_ALLGATHER: + case UCC_COLL_TYPE_ALLTOALL: + return ucc_cl_doca_urom_coll_full_init(coll_args, team, task); + default: + cl_error(doca_urom_lib, "coll_type %s is not supported", + ucc_coll_type_str(coll_args->args.coll_type)); + } + + return UCC_OK; +} diff --git a/src/components/cl/doca_urom/cl_doca_urom_coll.h b/src/components/cl/doca_urom/cl_doca_urom_coll.h new file mode 100644 index 0000000000..344c6684da --- /dev/null +++ b/src/components/cl/doca_urom/cl_doca_urom_coll.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ +#ifndef UCC_CL_DOCA_UROM_COLL_H_ +#define UCC_CL_DOCA_UROM_COLL_H_ + +#include "cl_doca_urom.h" +#include "schedule/ucc_schedule_pipelined.h" +#include "components/mc/ucc_mc.h" + +#include "../../tl/ucp/tl_ucp.h" + +#define UCC_CL_DOCA_UROM_N_DEFAULT_ALG_SELECT_STR 2 + +extern const char + *ucc_cl_doca_urom_default_alg_select_str[UCC_CL_DOCA_UROM_N_DEFAULT_ALG_SELECT_STR]; + +typedef struct ucc_cl_doca_urom_schedule_t { + ucc_schedule_pipelined_t super; + struct ucc_cl_doca_urom_result res; + struct export_buf src_ebuf; + struct export_buf dst_ebuf; +} ucc_cl_doca_urom_schedule_t; + +static inline ucc_cl_doca_urom_schedule_t * +ucc_cl_doca_urom_get_schedule(ucc_cl_doca_urom_team_t *team) +{ + ucc_cl_doca_urom_context_t *ctx = UCC_CL_DOCA_UROM_TEAM_CTX(team); + ucc_cl_doca_urom_schedule_t *schedule = ucc_mpool_get(&ctx->sched_mp); + + return schedule; +} + +static inline void ucc_cl_doca_urom_put_schedule(ucc_schedule_t *schedule) +{ + ucc_mpool_put(schedule); +} + +#endif diff --git a/src/components/cl/doca_urom/cl_doca_urom_common.c b/src/components/cl/doca_urom/cl_doca_urom_common.c new file mode 100644 index 0000000000..97f1c929d6 --- /dev/null +++ b/src/components/cl/doca_urom/cl_doca_urom_common.c @@ -0,0 +1,304 @@ +/* + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED. + * + * This software product is a proprietary product of NVIDIA CORPORATION & + * AFFILIATES (the "Company") and all right, title, and interest in and to the + * software product, including all associated intellectual property rights, are + * and shall remain exclusively with the Company. + * + * This software product is governed by the End User License Agreement + * provided with the software product. 
+ * + */ + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "cl_doca_urom_common.h" + +DOCA_LOG_REGISTER(UCC::DOCA_CL : UROM_COMMON); + +doca_error_t ucc_cl_doca_urom_start_urom_service( + struct doca_pe *pe, struct doca_dev *dev, uint64_t nb_workers, + struct doca_urom_service **service) +{ + enum doca_ctx_states state; + struct doca_urom_service *inst; + doca_error_t result, tmp_result; + + /* Create service context */ + result = doca_urom_service_create(&inst); + if (result != DOCA_SUCCESS) + return result; + + result = doca_pe_connect_ctx(pe, doca_urom_service_as_ctx(inst)); + if (result != DOCA_SUCCESS) + goto service_cleanup; + + result = doca_urom_service_set_max_workers(inst, nb_workers); + if (result != DOCA_SUCCESS) + goto service_cleanup; + + result = doca_urom_service_set_dev(inst, dev); + if (result != DOCA_SUCCESS) + goto service_cleanup; + + result = doca_ctx_start(doca_urom_service_as_ctx(inst)); + if (result != DOCA_SUCCESS) + goto service_cleanup; + + result = doca_ctx_get_state(doca_urom_service_as_ctx(inst), &state); + if (result != DOCA_SUCCESS || state != DOCA_CTX_STATE_RUNNING) + goto service_stop; + + *service = inst; + return DOCA_SUCCESS; + +service_stop: + tmp_result = doca_ctx_stop(doca_urom_service_as_ctx(inst)); + if (tmp_result != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to stop UROM service"); + DOCA_ERROR_PROPAGATE(result, tmp_result); + } + +service_cleanup: + tmp_result = doca_urom_service_destroy(inst); + if (tmp_result != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to destroy UROM service"); + DOCA_ERROR_PROPAGATE(result, tmp_result); + } + return result; +} + +doca_error_t ucc_cl_doca_urom_start_urom_worker( + struct doca_pe *pe, struct doca_urom_service *service, + uint64_t worker_id, uint32_t *gid, uint64_t nb_tasks, + doca_cpu_set_t *cpuset, char **env, size_t env_count, uint64_t plugins, + struct doca_urom_worker **worker) +{ + enum doca_ctx_states state; + struct doca_urom_worker *inst; + doca_error_t result, tmp_result; + + result = doca_urom_worker_create(&inst); + if (result != DOCA_SUCCESS) + return result; + + result = doca_urom_worker_set_service(inst, service); + if (result != DOCA_SUCCESS) + goto worker_cleanup; + + result = doca_pe_connect_ctx(pe, doca_urom_worker_as_ctx(inst)); + if (result != DOCA_SUCCESS) + goto worker_cleanup; + + result = doca_urom_worker_set_id(inst, worker_id); + if (result != DOCA_SUCCESS) + goto worker_cleanup; + + if (gid != NULL) { + result = doca_urom_worker_set_gid(inst, *gid); + if (result != DOCA_SUCCESS) + goto worker_cleanup; + } + + if (env != NULL) { + result = doca_urom_worker_set_env(inst, env, env_count); + if (result != DOCA_SUCCESS) + goto worker_cleanup; + } + + result = doca_urom_worker_set_max_inflight_tasks(inst, nb_tasks); + if (result != DOCA_SUCCESS) + goto worker_cleanup; + + result = doca_urom_worker_set_plugins(inst, plugins); + if (result != DOCA_SUCCESS) + goto worker_cleanup; + + if (cpuset != NULL) { + result = doca_urom_worker_set_cpuset(inst, *cpuset); + if (result != DOCA_SUCCESS) + goto worker_cleanup; + } + + result = doca_ctx_start(doca_urom_worker_as_ctx(inst)); + if (result != DOCA_ERROR_IN_PROGRESS) + goto worker_cleanup; + + result = doca_ctx_get_state(doca_urom_worker_as_ctx(inst), &state); + if (result != DOCA_SUCCESS) + goto worker_stop; + + if (state != DOCA_CTX_STATE_STARTING) { + result = DOCA_ERROR_BAD_STATE; + goto worker_stop; + } + + *worker = inst; + return 
DOCA_SUCCESS; + +worker_stop: + tmp_result = doca_ctx_stop(doca_urom_worker_as_ctx(inst)); + if (tmp_result != DOCA_SUCCESS && tmp_result != DOCA_ERROR_IN_PROGRESS) { + DOCA_LOG_ERR("Failed to request stop UROM worker"); + DOCA_ERROR_PROPAGATE(result, tmp_result); + } + + do { + doca_pe_progress(pe); + doca_ctx_get_state(doca_urom_worker_as_ctx(inst), &state); + } while (state != DOCA_CTX_STATE_IDLE); + +worker_cleanup: + tmp_result = doca_urom_worker_destroy(inst); + if (tmp_result != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to destroy UROM worker"); + DOCA_ERROR_PROPAGATE(result, tmp_result); + } + + return result; +} + +doca_error_t ucc_cl_doca_urom_start_urom_domain( + struct doca_pe *pe, struct doca_urom_domain_oob_coll *oob, + uint64_t *worker_ids, struct doca_urom_worker **workers, + size_t nb_workers, struct ucc_cl_doca_urom_domain_buffer_attrs *buffers, + size_t nb_buffers, struct doca_urom_domain **domain) +{ + struct doca_urom_domain *inst; + enum doca_ctx_states state; + doca_error_t result, tmp_result; + size_t i; + + result = doca_urom_domain_create(&inst); + if (result != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to create domain"); + return result; + } + + result = doca_pe_connect_ctx(pe, doca_urom_domain_as_ctx(inst)); + if (result != DOCA_SUCCESS) + goto domain_destroy; + + result = doca_urom_domain_set_oob(inst, oob); + if (result != DOCA_SUCCESS) + goto domain_destroy; + + result = doca_urom_domain_set_workers(inst, worker_ids, workers, nb_workers); + if (result != DOCA_SUCCESS) + goto domain_destroy; + + if (nb_workers != 0 && buffers != NULL) { + result = doca_urom_domain_set_buffers_count(inst, nb_buffers); + if (result != DOCA_SUCCESS) + goto domain_destroy; + + for (i = 0; i < nb_buffers; i++) { + result = doca_urom_domain_add_buffer(inst, + buffers[i].buffer, + buffers[i].buf_len, + buffers[i].memh, + buffers[i].memh_len, + buffers[i].mkey, + buffers[i].mkey_len); + if (result != DOCA_SUCCESS) + goto domain_destroy; + } + } + + result = doca_ctx_start(doca_urom_domain_as_ctx(inst)); + if (result != DOCA_ERROR_IN_PROGRESS) + goto domain_stop; + + result = doca_ctx_get_state(doca_urom_domain_as_ctx(inst), &state); + if (result != DOCA_SUCCESS) + goto domain_stop; + + if (state != DOCA_CTX_STATE_STARTING) { + result = DOCA_ERROR_BAD_STATE; + goto domain_stop; + } + + *domain = inst; + return DOCA_SUCCESS; + +domain_stop: + tmp_result = doca_ctx_stop(doca_urom_domain_as_ctx(inst)); + if (tmp_result != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to stop UROM domain"); + DOCA_ERROR_PROPAGATE(result, tmp_result); + } + +domain_destroy: + tmp_result = doca_urom_domain_destroy(inst); + if (tmp_result != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to destroy UROM domain"); + DOCA_ERROR_PROPAGATE(result, tmp_result); + } + return result; +} + + +doca_error_t ucc_cl_doca_urom_open_doca_device_with_ibdev_name( + const uint8_t *value, size_t val_size, + ucc_cl_doca_urom_tasks_check func, struct doca_dev **retval) +{ + char buf[DOCA_DEVINFO_IBDEV_NAME_SIZE] = {}; + char val_copy[DOCA_DEVINFO_IBDEV_NAME_SIZE] = {}; + struct doca_devinfo **dev_list; + uint32_t nb_devs; + int res; + size_t i; + + /* Set default return value */ + *retval = NULL; + + /* Setup */ + if (val_size > DOCA_DEVINFO_IBDEV_NAME_SIZE) { + DOCA_LOG_ERR("Value size too large. 
Failed to locate device"); + return DOCA_ERROR_INVALID_VALUE; + } + memcpy(val_copy, value, val_size); + + res = doca_devinfo_create_list(&dev_list, &nb_devs); + if (res != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to load doca devices list: %s", doca_error_get_descr(res)); + return res; + } + + /* Search */ + for (i = 0; i < nb_devs; i++) { + res = doca_devinfo_get_ibdev_name(dev_list[i], buf, DOCA_DEVINFO_IBDEV_NAME_SIZE); + if (res == DOCA_SUCCESS && strncmp(buf, val_copy, val_size) == 0) { + /* If any special capabilities are needed */ + if (func != NULL && func(dev_list[i]) != DOCA_SUCCESS) + continue; + + /* if device can be opened */ + res = doca_dev_open(dev_list[i], retval); + if (res == DOCA_SUCCESS) { + doca_devinfo_destroy_list(dev_list); + return res; + } + } + } + + DOCA_LOG_WARN("Matching device not found"); + res = DOCA_ERROR_NOT_FOUND; + + doca_devinfo_destroy_list(dev_list); + return res; +} diff --git a/src/components/cl/doca_urom/cl_doca_urom_common.h b/src/components/cl/doca_urom/cl_doca_urom_common.h new file mode 100644 index 0000000000..20335cf7d3 --- /dev/null +++ b/src/components/cl/doca_urom/cl_doca_urom_common.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED. + * + * This software product is a proprietary product of NVIDIA CORPORATION & + * AFFILIATES (the "Company") and all right, title, and interest in and to the + * software product, including all associated intellectual property rights, are + * and shall remain exclusively with the Company. + * + * This software product is governed by the End User License Agreement + * provided with the software product. + * + */ + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#ifndef UCC_CL_DOCA_UROM_COMMON_H_ +#define UCC_CL_DOCA_UROM_COMMON_H_ + +#include +#include +#include +#include + +/* Function to check if a given device is capable of executing some task */ +typedef doca_error_t (*ucc_cl_doca_urom_tasks_check)(struct doca_devinfo *); + +/* + * Struct contains domain shared buffer details + */ +struct ucc_cl_doca_urom_domain_buffer_attrs { + void *buffer; /* Buffer address */ + size_t buf_len; /* Buffer length */ + void *memh; /* Buffer packed memory handle */ + size_t memh_len; /* Buffer packed memory handle length */ + void *mkey; /* Buffer packed memory key */ + size_t mkey_len; /* Buffer packed memory key length*/ +}; + +/* + * Open a DOCA device according to a given IB device name + * + * @value [in]: IB device name + * @val_size [in]: input length, in bytes + * @func [in]: pointer to a function that checks if the device have some task capabilities (Ignored if set to NULL) + * @retval [out]: pointer to doca_dev struct, NULL if not found + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_open_doca_device_with_ibdev_name( + const uint8_t *value, + size_t val_size, + ucc_cl_doca_urom_tasks_check func, + struct doca_dev **retval); + +/* + * Start UROM service context + * + * @pe [in]: Progress engine + * @dev [in]: service DOCA device + * @nb_workers [in]: number of workers + * @service [out]: service context + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_start_urom_service(struct doca_pe *pe, + struct doca_dev *dev, + uint64_t nb_workers, + struct doca_urom_service **service); + +/* + * Start UROM worker context + * + * @pe [in]: Progress engine + * @service [in]: service context + * @worker_id [in]: Worker id + * @gid [in]: worker group id (optional attribute) + * 
@nb_tasks [in]: number of tasks + * @cpuset [in]: worker CPU affinity to set + * @env [in]: worker environment variables array + * @env_count [in]: worker environment variables array size + * @plugins [in]: worker plugins + * @worker [out]: set worker context + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_start_urom_worker(struct doca_pe *pe, + struct doca_urom_service *service, + uint64_t worker_id, + uint32_t *gid, + uint64_t nb_tasks, + doca_cpu_set_t *cpuset, + char **env, + size_t env_count, + uint64_t plugins, + struct doca_urom_worker **worker); + +/* + * Start UROM domain context + * + * @pe [in]: Progress engine + * @oob [in]: OOB allgather operations + * @worker_ids [in]: workers ids participate in domain + * @workers [in]: workers participate in domain + * @nb_workers [in]: number of workers in domain + * @buffers [in]: shared buffers + * @nb_buffers [out]: number of shared buffers + * @domain [out]: domain context + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_start_urom_domain(struct doca_pe *pe, + struct doca_urom_domain_oob_coll *oob, + uint64_t *worker_ids, + struct doca_urom_worker **workers, + size_t nb_workers, + struct ucc_cl_doca_urom_domain_buffer_attrs *buffers, + size_t nb_buffers, + struct doca_urom_domain **domain); + +#endif /* UCC_CL_DOCA_UROM_COMMON_H_ */ diff --git a/src/components/cl/doca_urom/cl_doca_urom_context.c b/src/components/cl/doca_urom/cl_doca_urom_context.c new file mode 100644 index 0000000000..5497b41184 --- /dev/null +++ b/src/components/cl/doca_urom/cl_doca_urom_context.c @@ -0,0 +1,512 @@ +/** + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "cl_doca_urom.h" +#include "cl_doca_urom_coll.h" +#include "utils/ucc_malloc.h" + +// Convert the ucc oob allgather test to work with doca_error_t. +// The problem this solves is that DOCA_ERROR_IN_PROGRESS is numerically +// equivalent to 26 while UCC_INPROGRESS is equal to 1 +ucc_status_t (*params_oob_allgather_test)(void *req); +static doca_error_t oob_allgather_test_docafied(void *req) +{ + ucc_status_t ucc_status = params_oob_allgather_test(req); + return ucc_status == UCC_OK ? 
DOCA_SUCCESS : DOCA_ERROR_IN_PROGRESS; +} + +ucc_status_t (*params_oob_allgather_free)(void *req); +static doca_error_t oob_allgather_free_docafied(void *req) +{ + params_oob_allgather_free(req); + return DOCA_SUCCESS; +} + +ucc_status_t (*params_oob_allgather)(void *, void *, size_t, void *, void **); +static doca_error_t oob_allgather_docafied(void * s, void * r, size_t z, + void * i, void **req_p) +{ + params_oob_allgather(s,r,z,i,req_p); + return DOCA_SUCCESS; +} + +UCC_CLASS_INIT_FUNC(ucc_cl_doca_urom_context_t, + const ucc_base_context_params_t *params, + const ucc_base_config_t *config) +{ + struct ucc_cl_doca_urom_domain_buffer_attrs buf_attrs = {0}; + struct doca_urom_domain_oob_coll oob_coll = {0}; + doca_error_t tmp_result = DOCA_SUCCESS; + union doca_data cookie = {0}; + struct ucc_cl_doca_urom_result res = {0}; + doca_error_t result = DOCA_SUCCESS; + size_t length = 4096; + int ucp_index = -1; + int num_envs = 0; + char **envs = NULL; + const ucc_cl_doca_urom_context_config_t *cl_config = + ucc_derived_of(config, ucc_cl_doca_urom_context_config_t); + ucc_cl_doca_urom_lib_t *doca_urom_lib = + ucc_derived_of(cl_config->super.cl_lib, ucc_cl_doca_urom_lib_t); + ucc_config_names_array_t *tls = + &cl_config->super.cl_lib->tls.array; + ucc_lib_params_t lib_params = { + .mask = UCC_LIB_PARAM_FIELD_THREAD_MODE, + .thread_mode = UCC_THREAD_SINGLE, + }; + ucc_tl_ucp_context_t *tl_ctx; + enum doca_ctx_states state; + struct export_buf ebuf; + ucc_status_t status; + ucs_status_t ucs_status; + ucc_rank_t rank; + uint64_t rank_u64; + void *buffer; + int ret; + char *plugin_name; + char *device; + const struct doca_urom_service_plugin_info *plugins; + size_t plugins_count = 0; + size_t i; + struct doca_log_backend *sdk_log = NULL; + + UCC_CLASS_CALL_SUPER_INIT(ucc_cl_context_t, &cl_config->super, + params->context); + memcpy(&self->cfg, cl_config, sizeof(*cl_config)); + + if (tls->count == 1 && !strcmp(tls->names[0], "all")) { + tls = ¶ms->context->all_tls; + } + self->super.tl_ctxs = ucc_malloc(sizeof(ucc_tl_context_t*) * tls->count, + "cl_doca_urom_tl_ctxs"); + if (!self->super.tl_ctxs) { + cl_error(cl_config->super.cl_lib, + "failed to allocate %zd bytes for tl_ctxs", + sizeof(ucc_tl_context_t**) * tls->count); + return UCC_ERR_NO_MEMORY; + } + self->super.n_tl_ctxs = 0; + for (i = 0; i < tls->count; i++) { + ucc_debug("TL NAME[%zu]: %s", i, tls->names[i]); + status = ucc_tl_context_get(params->context, tls->names[i], + &self->super.tl_ctxs[self->super.n_tl_ctxs]); + if (UCC_OK != status) { + cl_debug(cl_config->super.cl_lib, + "TL %s context is not available, skipping", tls->names[i]); + } else { + if (strcmp(tls->names[i], "ucp") == 0) { + ucp_index = self->super.n_tl_ctxs; + doca_urom_lib->tl_ucp_index = ucp_index; + } + self->super.n_tl_ctxs++; + } + } + if (0 == self->super.n_tl_ctxs) { + cl_error(cl_config->super.cl_lib, "no TL contexts are available"); + ucc_free(self->super.tl_ctxs); + self->super.tl_ctxs = NULL; + return UCC_ERR_NOT_FOUND; + } + + tl_ctx = ucc_derived_of(self->super.tl_ctxs[ucp_index], + ucc_tl_ucp_context_t); + self->ucp_context = tl_ctx->worker.ucp_context; + + memset(&self->urom_ctx, 0, sizeof(ucc_cl_doca_urom_ctx_t)); + + self->urom_ctx.ctx_rank = params->params.oob.oob_ep; + rank = self->urom_ctx.ctx_rank; + + if (self->cfg.plugin_envs.count > 0) { + num_envs = self->cfg.plugin_envs.count; + envs = self->cfg.plugin_envs.names; + } + + plugin_name = self->cfg.plugin_name; + device = self->cfg.device; + + result = doca_log_backend_create_with_file_sdk(stderr, 
&sdk_log); + if (result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to create DOCA log backend\n"); + return UCC_ERR_NO_RESOURCE; + } + result = doca_log_backend_set_sdk_level(sdk_log, cl_config->doca_log_level); + if (result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to set backend sdk level\n"); + return UCC_ERR_NO_RESOURCE; + } + + result = ucc_cl_doca_urom_open_doca_device_with_ibdev_name((uint8_t *)device, strlen(device), + NULL, &self->urom_ctx.dev); + if (result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "failed to open device %s\n", device); + return UCC_ERR_NO_RESOURCE; + } + + result = doca_pe_create(&self->urom_ctx.urom_pe); + if (result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to create DOCA PE\n"); + goto dev_close; + } + + result = ucc_cl_doca_urom_start_urom_service(self->urom_ctx.urom_pe, self->urom_ctx.dev, 2, + &self->urom_ctx.urom_service); + if (result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to create UROM service context"); + goto pe_destroy; + } + + result = doca_urom_service_get_plugins_list(self->urom_ctx.urom_service, + &plugins, &plugins_count); + if (result != DOCA_SUCCESS || plugins_count == 0) { + cl_error(cl_config->super.cl_lib, + "Failed to get UROM plugins list. plugins_count: %ld\n", + plugins_count); + goto service_stop; + } + + for (i = 0; i < plugins_count; i++) { + if (strcmp(plugin_name, plugins[i].plugin_name) == 0) { + self->urom_ctx.ucc_info = &plugins[i]; + break; + } + } + + if (self->urom_ctx.ucc_info == NULL) { + cl_error(cl_config->super.cl_lib, "Failed to match UCC plugin"); + result = DOCA_ERROR_INVALID_VALUE; + goto service_stop; + } + + result = ucc_cl_doca_urom_save_plugin_id(self->urom_ctx.ucc_info->id, + self->urom_ctx.ucc_info->version); + if (result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to init UCC worker plugin"); + goto service_stop; + } + + self->urom_ctx.urom_worker_addr = ucc_calloc(1, UCC_CL_DOCA_UROM_ADDR_MAX_LEN, + "doca_urom worker addr"); + if (!self->urom_ctx.urom_worker_addr) { + cl_error(cl_config->super.cl_lib, "failed to allocate %d bytes", + UCC_CL_DOCA_UROM_ADDR_MAX_LEN); + return UCC_ERR_NO_MEMORY; + } + + /* Create and start worker context */ + result = ucc_cl_doca_urom_start_urom_worker(self->urom_ctx.urom_pe, + self->urom_ctx.urom_service, rank, NULL, + 16, NULL, envs, num_envs, + self->urom_ctx.ucc_info->id, + &self->urom_ctx.urom_worker); + if (result != DOCA_SUCCESS) + cl_error(cl_config->super.cl_lib, "Failed to start urom worker"); + + /* Loop till worker state changes to running */ + do { + doca_pe_progress(self->urom_ctx.urom_pe); + result = doca_ctx_get_state( + doca_urom_worker_as_ctx(self->urom_ctx.urom_worker), + &state); + } while (state == DOCA_CTX_STATE_STARTING && result == DOCA_SUCCESS); + if (state != DOCA_CTX_STATE_RUNNING || result != DOCA_SUCCESS) { + goto worker_cleanup; + } + + /* Start the UROM domain */ + buffer = calloc(1, length); + if (buffer == NULL) { + cl_error(cl_config->super.cl_lib, + "Failed to allocate urom domain buffer"); + result = DOCA_ERROR_NO_MEMORY; + goto worker_cleanup; + } + + params_oob_allgather = params->params.oob.allgather; + oob_coll.allgather = oob_allgather_docafied; + params_oob_allgather_test = params->params.oob.req_test; + oob_coll.req_test = oob_allgather_test_docafied; + params_oob_allgather_free = params->params.oob.req_free; + oob_coll.req_free = oob_allgather_free_docafied; + oob_coll.coll_info = params->params.oob.coll_info; + 
oob_coll.n_oob_indexes = params->params.oob.n_oob_eps; + oob_coll.oob_index = rank; + + ucs_status = ucp_worker_get_address(tl_ctx->worker.ucp_worker, + &tl_ctx->worker.worker_address, + &tl_ctx->worker.ucp_addrlen); + if (ucs_status != UCS_OK) { + cl_error(cl_config->super.cl_lib, "Failed to get ucp worker address"); + goto worker_cleanup; + } + + result = (doca_error_t) ucc_cl_doca_urom_buffer_export_ucc( + self->ucp_context, buffer, length, &ebuf); + if (result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to export buffer"); + goto worker_cleanup; + } + + buf_attrs.buffer = buffer; + buf_attrs.buf_len = length; + buf_attrs.memh = ebuf.packed_memh; + buf_attrs.memh_len = ebuf.packed_memh_len; + buf_attrs.mkey = ebuf.packed_key; + buf_attrs.mkey_len = ebuf.packed_key_len; + + /* Create domain context */ + rank_u64 = (uint64_t)rank; + result = ucc_cl_doca_urom_start_urom_domain(self->urom_ctx.urom_pe, &oob_coll, + &rank_u64, &self->urom_ctx.urom_worker, + 1, &buf_attrs, 1, + &self->urom_ctx.urom_domain); + if (result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to start domain"); + goto worker_unmap; + } + + /* Loop till domain state changes to running */ + do { + doca_pe_progress(self->urom_ctx.urom_pe); + result = doca_ctx_get_state( + doca_urom_domain_as_ctx( + self->urom_ctx.urom_domain), + &state); + } while (state == DOCA_CTX_STATE_STARTING && result == DOCA_SUCCESS); + + if (state != DOCA_CTX_STATE_RUNNING || result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to start domain"); + result = DOCA_ERROR_BAD_STATE; + goto worker_unmap; + } + + /* Create lib */ + cookie.ptr = &res; + result = ucc_cl_doca_urom_task_lib_create(self->urom_ctx.urom_worker, + cookie, rank, &lib_params, + ucc_cl_doca_urom_lib_create_finished); + if (result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to create lib creation task"); + goto domain_stop; + } + do { + ret = doca_pe_progress(self->urom_ctx.urom_pe); + } while (ret == 0 && res.result == DOCA_SUCCESS); + + if (res.result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to finish lib create task"); + result = res.result; + goto domain_stop; + } + cl_debug(cl_config->super.cl_lib, "UCC lib create is done"); + + cl_debug(cl_config->super.cl_lib, "Creating pd channel"); + result = ucc_cl_doca_urom_task_pd_channel(self->urom_ctx.urom_worker, + cookie, + rank, + tl_ctx->worker.worker_address, + tl_ctx->worker.ucp_addrlen, + ucc_cl_doca_urom_pss_dc_finished); + if (result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to create data channel task"); + goto lib_destroy; + } + + do { + ret = doca_pe_progress(self->urom_ctx.urom_pe); + } while (ret == 0 && res.result == DOCA_SUCCESS); + + if (res.result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Passive data channel task failed"); + result = res.result; + goto lib_destroy; + } + cl_debug(cl_config->super.cl_lib, "Passive data channel is done"); + + cl_debug(cl_config->super.cl_lib, "Creating task ctx"); + result = ucc_cl_doca_urom_task_ctx_create(self->urom_ctx.urom_worker, + cookie, rank, 0, NULL, 1, + params->params.oob.n_oob_eps, 0x0, + length, + ucc_cl_doca_urom_ctx_create_finished); + if (result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to create UCC context task"); + goto lib_destroy; + } + + do { + ret = doca_pe_progress(self->urom_ctx.urom_pe); + } while (ret == 0 && res.result == DOCA_SUCCESS); + + if (res.result != DOCA_SUCCESS || res.context_create.context == NULL) { 
+ cl_error(cl_config->super.cl_lib, "UCC context create task failed"); + result = res.result; + goto lib_destroy; + } + cl_debug(cl_config->super.cl_lib, + "UCC context create is done, ucc_context: %p", + res.context_create.context); + self->urom_ctx.urom_ucc_context = res.context_create.context; + + status = ucc_mpool_init(&self->sched_mp, 0, + sizeof(ucc_cl_doca_urom_schedule_t), + 0, UCC_CACHE_LINE_SIZE, 2, UINT_MAX, + &ucc_coll_task_mpool_ops, params->thread_mode, + "cl_doca_urom_sched_mp"); + if (UCC_OK != status) { + cl_error(cl_config->super.cl_lib, + "failed to initialize cl_doca_urom_sched mpool"); + goto lib_destroy; + } + + cl_debug(cl_config->super.cl_lib, "initialized cl context: %p", self); + return UCC_OK; + +lib_destroy: + result = ucc_cl_doca_urom_task_lib_destroy(self->urom_ctx.urom_worker, + cookie, rank, ucc_cl_doca_urom_lib_destroy_finished); + if (result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, + "Failed to create UCC lib destroy task"); + } + + do { + ret = doca_pe_progress(self->urom_ctx.urom_pe); + } while (ret == 0 && res.result == DOCA_SUCCESS); + + if (res.result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, "UCC lib destroy failed"); + result = res.result; + } + +domain_stop: + result = doca_ctx_stop( + doca_urom_domain_as_ctx(self->urom_ctx.urom_domain)); + if (result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, "Failed to stop UROM domain"); + } + + result = doca_urom_domain_destroy(self->urom_ctx.urom_domain); + if (result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, "Failed to destroy UROM domain"); + } + +worker_unmap: + ucs_status = ucp_mem_unmap(self->ucp_context, ebuf.memh); + if (ucs_status != UCS_OK) { + cl_error(cl_config->super.cl_lib, "Failed to unmap memh"); + } + free(buffer); + +worker_cleanup: + tmp_result = doca_urom_worker_destroy(self->urom_ctx.urom_worker); + if (tmp_result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to destroy UROM worker"); + } + +service_stop: + tmp_result = doca_ctx_stop( + doca_urom_service_as_ctx(self->urom_ctx.urom_service)); + if (tmp_result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to stop UROM service"); + } + tmp_result = doca_urom_service_destroy(self->urom_ctx.urom_service); + if (tmp_result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to destroy UROM service"); + } + +pe_destroy: + tmp_result = doca_pe_destroy(self->urom_ctx.urom_pe); + if (tmp_result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to destroy PE"); + } + +dev_close: + tmp_result = doca_dev_close(self->urom_ctx.dev); + if (tmp_result != DOCA_SUCCESS) { + cl_error(cl_config->super.cl_lib, "Failed to close device"); + } + + return UCC_ERR_NO_MESSAGE; +} + +UCC_CLASS_CLEANUP_FUNC(ucc_cl_doca_urom_context_t) +{ + struct ucc_cl_doca_urom_result res = {0}; + union doca_data cookie = {0}; + doca_error_t result = DOCA_SUCCESS; + int i, ret; + ucc_rank_t rank; + + rank = self->urom_ctx.ctx_rank; + cookie.ptr = &res; + + result = ucc_cl_doca_urom_task_lib_destroy(self->urom_ctx.urom_worker, + cookie, rank, ucc_cl_doca_urom_lib_destroy_finished); + if (result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, + "Failed to create UCC lib destroy task"); + } + + do { + ret = doca_pe_progress(self->urom_ctx.urom_pe); + } while (ret == 0 && res.result == DOCA_SUCCESS); + + if (res.result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, "UCC lib destroy failed"); + result = res.result; + } + + result = doca_ctx_stop( + 
doca_urom_domain_as_ctx(self->urom_ctx.urom_domain)); + if (result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, "Failed to stop UROM domain"); + } + + result = doca_urom_domain_destroy(self->urom_ctx.urom_domain); + if (result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, "Failed to destroy UROM domain"); + } + + result = doca_ctx_stop( + doca_urom_service_as_ctx(self->urom_ctx.urom_service)); + if (result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, "Failed to stop UROM service"); + } + result = doca_urom_service_destroy(self->urom_ctx.urom_service); + if (result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, "Failed to destroy UROM service"); + } + + result = doca_pe_destroy(self->urom_ctx.urom_pe); + if (result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, "Failed to destroy PE"); + } + + result = doca_dev_close(self->urom_ctx.dev); + if (result != DOCA_SUCCESS) { + cl_error(self->super.super.lib, "Failed to close device"); + } + + cl_debug(self->super.super.lib, "finalizing cl context: %p", self); + for (i = 0; i < self->super.n_tl_ctxs; i++) { + ucc_tl_context_put(self->super.tl_ctxs[i]); + } + ucc_free(self->super.tl_ctxs); +} + +UCC_CLASS_DEFINE(ucc_cl_doca_urom_context_t, ucc_cl_context_t); + +ucc_status_t +ucc_cl_doca_urom_get_context_attr(const ucc_base_context_t *context, + ucc_base_ctx_attr_t *attr) +{ + if (attr->attr.mask & UCC_CONTEXT_ATTR_FIELD_CTX_ADDR_LEN) { + attr->attr.ctx_addr_len = 0; + } + + return UCC_OK; +} diff --git a/src/components/cl/doca_urom/cl_doca_urom_lib.c b/src/components/cl/doca_urom/cl_doca_urom_lib.c new file mode 100644 index 0000000000..356d91b337 --- /dev/null +++ b/src/components/cl/doca_urom/cl_doca_urom_lib.c @@ -0,0 +1,115 @@ +/** + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#include "cl_doca_urom.h" +#include "utils/ucc_malloc.h" +#include "components/tl/ucc_tl.h" +#include "core/ucc_global_opts.h" +#include "utils/ucc_math.h" + +#include + +/* NOLINTNEXTLINE TODO params is not used*/ +UCC_CLASS_INIT_FUNC(ucc_cl_doca_urom_lib_t, const ucc_base_lib_params_t *params, + const ucc_base_config_t *config) +{ + const ucc_cl_doca_urom_lib_config_t *cl_config = + ucc_derived_of(config, ucc_cl_doca_urom_lib_config_t); + + UCC_CLASS_CALL_SUPER_INIT(ucc_cl_lib_t, &ucc_cl_doca_urom.super, + &cl_config->super); + memcpy(&self->cfg, cl_config, sizeof(*cl_config)); + + cl_debug(&self->super, "initialized lib object: %p", self); + + return UCC_OK; +} + +UCC_CLASS_CLEANUP_FUNC(ucc_cl_doca_urom_lib_t) +{ + cl_debug(&self->super, "finalizing lib object: %p", self); +} + +UCC_CLASS_DEFINE(ucc_cl_doca_urom_lib_t, ucc_cl_lib_t); +static inline ucc_status_t check_tl_lib_attr(const ucc_base_lib_t *lib, + ucc_tl_iface_t *tl_iface, + ucc_cl_lib_attr_t *attr) +{ + ucc_tl_lib_attr_t tl_attr; + ucc_status_t status; + + memset(&tl_attr, 0, sizeof(tl_attr)); + status = tl_iface->lib.get_attr(NULL, &tl_attr.super); + if (UCC_OK != status) { + cl_error(lib, "failed to query tl %s lib attributes", + tl_iface->super.name); + return status; + } + + attr->super.attr.thread_mode = + ucc_min(attr->super.attr.thread_mode, tl_attr.super.attr.thread_mode); + attr->super.attr.coll_types |= tl_attr.super.attr.coll_types; + attr->super.flags |= tl_attr.super.flags; + + return UCC_OK; +} + +ucc_status_t ucc_cl_doca_urom_get_lib_attr(const ucc_base_lib_t *lib, + ucc_base_lib_attr_t *base_attr) +{ + ucc_cl_doca_urom_lib_t *cl_lib = ucc_derived_of(lib, ucc_cl_doca_urom_lib_t); + ucc_cl_lib_attr_t *attr = ucc_derived_of(base_attr, ucc_cl_lib_attr_t); + ucc_config_names_list_t *tls = &cl_lib->super.tls; + ucc_tl_iface_t *tl_iface; + ucc_status_t status; + int i; + + attr->tls = &cl_lib->super.tls.array; + + if (cl_lib->super.tls.requested) { + status = ucc_config_names_array_dup(&cl_lib->super.tls_forced, + &cl_lib->super.tls.array); + if (UCC_OK != status) { + return status; + } + } + + attr->tls_forced = &cl_lib->super.tls_forced; + attr->super.attr.thread_mode = UCC_THREAD_MULTIPLE; + attr->super.attr.coll_types = 0; + attr->super.flags = 0; + + ucc_assert(tls->array.count >= 1); + + for (i = 0; i < tls->array.count; i++) { + /* Check TLs provided in CL_DOCA_UROM_TLS. Not all of them could be + available, check for NULL. */ + tl_iface = + ucc_derived_of(ucc_get_component(&ucc_global_config.tl_framework, + tls->array.names[i]), + ucc_tl_iface_t); + + if (!tl_iface) { + cl_warn(lib, "tl %s is not available", tls->array.names[i]); + continue; + } + + if (UCC_OK != (status = check_tl_lib_attr(lib, tl_iface, attr))) { + return status; + } + } + + return UCC_OK; +} + +ucc_status_t ucc_cl_doca_urom_get_lib_properties(ucc_base_lib_properties_t *prop) +{ + prop->default_team_size = 2; + prop->min_team_size = 2; + prop->max_team_size = UCC_RANK_MAX; + + return UCC_OK; +} diff --git a/src/components/cl/doca_urom/cl_doca_urom_team.c b/src/components/cl/doca_urom/cl_doca_urom_team.c new file mode 100644 index 0000000000..035b04eaa7 --- /dev/null +++ b/src/components/cl/doca_urom/cl_doca_urom_team.c @@ -0,0 +1,163 @@ +/** + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#include "cl_doca_urom.h" +#include "utils/ucc_malloc.h" +#include "core/ucc_team.h" + +UCC_CLASS_INIT_FUNC(ucc_cl_doca_urom_team_t, ucc_base_context_t *cl_context, + const ucc_base_team_params_t *params) +{ + union doca_data cookie = {0}; + doca_error_t result = DOCA_SUCCESS; + ucc_cl_doca_urom_context_t *ctx = + ucc_derived_of(cl_context, ucc_cl_doca_urom_context_t); + ucc_status_t status; + + UCC_CLASS_CALL_SUPER_INIT(ucc_cl_team_t, &ctx->super, params); + + self->teams = (ucc_team_h **) ucc_malloc( + sizeof(ucc_team_h *) * UCC_CL_DOCA_UROM_MAX_TEAMS); + + if (!self->teams) { + cl_error(cl_context->lib, + "failed to allocate %zd bytes for doca_urom teams", + sizeof(ucc_team_h *) * UCC_CL_DOCA_UROM_MAX_TEAMS); + status = UCC_ERR_NO_MEMORY; + return status; + } + + self->n_teams = 0; + self->score_map = NULL; + + cookie.ptr = &self->res; + + result = ucc_cl_doca_urom_task_team_create(ctx->urom_ctx.urom_worker, + cookie, + ctx->urom_ctx.ctx_rank, + 0 /* start */, + 1 /* stride */, + params->params.oob.n_oob_eps /* size */, + ctx->urom_ctx.urom_ucc_context, + ucc_cl_doca_urom_team_create_finished); + + if (result != DOCA_SUCCESS) { + cl_error(cl_context->lib, "Failed to create UCC team task"); + return UCC_ERR_NO_RESOURCE; + } + + self->res.team_create.status = 1; // set in progress + + cl_debug(cl_context->lib, "posted cl team: %p", self); + return UCC_OK; +} + +UCC_CLASS_CLEANUP_FUNC(ucc_cl_doca_urom_team_t) +{ + cl_debug(self->super.super.context->lib, "finalizing cl team: %p", self); +} + +UCC_CLASS_DEFINE_DELETE_FUNC(ucc_cl_doca_urom_team_t, ucc_base_team_t); +UCC_CLASS_DEFINE(ucc_cl_doca_urom_team_t, ucc_cl_team_t); + +ucc_status_t ucc_cl_doca_urom_team_destroy(ucc_base_team_t *cl_team) +{ + return UCC_OK; +} + +ucc_status_t ucc_cl_doca_urom_team_create_test(ucc_base_team_t *cl_team) +{ + ucc_status_t ucc_status; + int ret; + ucc_cl_doca_urom_team_t *team = + ucc_derived_of(cl_team, ucc_cl_doca_urom_team_t); + ucc_cl_doca_urom_context_t *ctx = + UCC_CL_DOCA_UROM_TEAM_CTX(team); + ucc_memory_type_t mem_types[2] = {UCC_MEMORY_TYPE_HOST, + UCC_MEMORY_TYPE_CUDA}; + struct ucc_cl_doca_urom_team_create_result *team_create = &team->res.team_create; + struct ucc_cl_doca_urom_result res = {0}; + ucc_coll_score_t *score = NULL; + int mt_n = 2; + + ret = doca_pe_progress(ctx->urom_ctx.urom_pe); + if (ret == 0 && res.result == DOCA_SUCCESS) { + return UCC_INPROGRESS; + } + + if (res.result != DOCA_SUCCESS) { + cl_error(ctx->super.super.lib, + "UCC team create task failed: DOCA status %d\n", res.result); + return UCC_ERR_NO_MESSAGE; + } + + if (team_create->status == 2) { // 2=done + team->teams[team->n_teams] = team_create->team; + ++team->n_teams; + ucc_status = ucc_coll_score_build_default( + cl_team, UCC_CL_DOCA_UROM_DEFAULT_SCORE, + ucc_cl_doca_urom_coll_init, + UCC_COLL_TYPE_ALLREDUCE | + UCC_COLL_TYPE_ALLGATHER | + UCC_COLL_TYPE_ALLTOALL, + mem_types, mt_n, &score); + if (UCC_OK != ucc_status) { + return ucc_status; + } + + ucc_status = ucc_coll_score_build_map(score, &team->score_map); + if (UCC_OK != ucc_status) { + cl_error(ctx->super.super.lib, "failed to build score map"); + } + team->score = score; + ucc_coll_score_set(team->score, UCC_CL_DOCA_UROM_DEFAULT_SCORE); + + return UCC_OK; + } + + return UCC_INPROGRESS; // 1=in progress +} + +ucc_status_t ucc_cl_doca_urom_team_get_scores(ucc_base_team_t *cl_team, + ucc_coll_score_t **score) +{ + ucc_cl_doca_urom_team_t *team = ucc_derived_of(cl_team, + ucc_cl_doca_urom_team_t); + ucc_base_context_t *ctx = UCC_CL_TEAM_CTX(team); + 
ucc_coll_score_team_info_t team_info; + ucc_status_t status; + + status = ucc_coll_score_dup(team->score, score); + if (UCC_OK != status) { + return status; + } + + if (strlen(ctx->score_str) > 0) { + team_info.alg_fn = NULL; + team_info.default_score = UCC_CL_DOCA_UROM_DEFAULT_SCORE; + team_info.init = NULL; + team_info.num_mem_types = 0; + team_info.supported_mem_types = NULL; /* all memory types supported*/ + team_info.supported_colls = UCC_COLL_TYPE_ALL; + team_info.size = UCC_CL_TEAM_SIZE(team); + + status = ucc_coll_score_update_from_str(ctx->score_str, &team_info, + &team->super.super, *score); + + /* If INVALID_PARAM - User provided incorrect input - try to proceed */ + if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) && + (status != UCC_ERR_NOT_SUPPORTED)) { + goto err; + } + } + + return UCC_OK; + +err: + ucc_coll_score_free(*score); + *score = NULL; + return status; +} diff --git a/src/components/cl/doca_urom/cl_doca_urom_worker_ucc.c b/src/components/cl/doca_urom/cl_doca_urom_worker_ucc.c new file mode 100644 index 0000000000..654de4845b --- /dev/null +++ b/src/components/cl/doca_urom/cl_doca_urom_worker_ucc.c @@ -0,0 +1,970 @@ +/* + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED. + * + * This software product is a proprietary product of NVIDIA CORPORATION & + * AFFILIATES (the "Company") and all right, title, and interest in and to the + * software product, including all associated intellectual property rights, are + * and shall remain exclusively with the Company. + * + * This software product is governed by the End User License Agreement + * provided with the software product. + * + */ + +#include +#include + +#include +#include + +#include "cl_doca_urom_worker_ucc.h" + +DOCA_LOG_REGISTER(UCC::DOCA_CL : WORKER_UCC); + +static uint64_t ucc_id; /* UCC plugin id, id is generated by UROM lib and + * will be updated in init function + */ +static uint64_t ucc_version = 0x01; /* UCC plugin host version */ + +/* UCC task metadata */ +struct ucc_cl_doca_urom_task_data { + union doca_data cookie; /* User cookie */ + union { + ucc_cl_doca_urom_lib_create_finished_cb lib_create; /* User lib create task callback */ + ucc_cl_doca_urom_lib_destroy_finished_cb lib_destroy; /* User lib destroy task callback */ + ucc_cl_doca_urom_ctx_create_finished_cb ctx_create; /* User context create task callback */ + ucc_cl_doca_urom_ctx_destroy_finished_cb ctx_destroy; /* User context destroy task callback */ + ucc_cl_doca_urom_team_create_finished_cb team_create; /* User UCC team create task callback */ + ucc_cl_doca_urom_collective_finished_cb collective; /* User UCC collective task callback */ + ucc_cl_doca_urom_pd_channel_finished_cb pd_channel; /* User passive data channel task callback */ + }; +}; + +/* + * UCC notification unpack function + * + * @packed_notif [in]: packed UCC notification buffer + * @ucc_notif [out]: set unpacked UCC notification + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t notif_unpack(void *packed_notif, struct urom_worker_notify_ucc **ucc_notif) +{ + *ucc_notif = packed_notif; + return DOCA_SUCCESS; +} + +/* + * UCC common command's completion callback function + * + * @task [in]: UROM worker task + * @type [in]: UCC task type + */ +static void completion(struct doca_urom_worker_cmd_task *task, enum urom_worker_ucc_notify_type type) +{ + struct urom_worker_notify_ucc notify_error = {0}; + struct urom_worker_notify_ucc *ucc_notify = ¬ify_error; + struct ucc_cl_doca_urom_task_data *task_data; + struct 
doca_buf *response; + doca_error_t result; + size_t data_len; + + notify_error.notify_type = type; + + task_data = (struct ucc_cl_doca_urom_task_data *)doca_urom_worker_cmd_task_get_user_data(task); + if (task_data == NULL) { + DOCA_LOG_ERR("Failed to get task data buffer"); + goto task_release; + } + + response = doca_urom_worker_cmd_task_get_response(task); + if (response == NULL) { + DOCA_LOG_ERR("Failed to get task response buffer"); + result = DOCA_ERROR_INVALID_VALUE; + goto error_exit; + } + + result = doca_buf_get_data(response, (void **)&ucc_notify); + if (result != DOCA_SUCCESS) + goto error_exit; + + result = notif_unpack((void *)ucc_notify, &ucc_notify); + if (result != DOCA_SUCCESS) + goto error_exit; + + result = doca_buf_get_data_len(response, &data_len); + if (result != DOCA_SUCCESS) { + DOCA_LOG_ERR("Failed to get response data length"); + goto error_exit; + } + + result = doca_task_get_status(doca_urom_worker_cmd_task_as_task(task)); + if (result != DOCA_SUCCESS) + goto error_exit; + + if (data_len != sizeof(*ucc_notify)) { + DOCA_LOG_ERR("Task response data length is different from notification expected length"); + result = DOCA_ERROR_INVALID_VALUE; + goto error_exit; + } + +error_exit: + switch (ucc_notify->notify_type) { + case UROM_WORKER_NOTIFY_UCC_LIB_CREATE_COMPLETE: + (task_data->lib_create)(result, task_data->cookie, ucc_notify->dpu_worker_id); + break; + case UROM_WORKER_NOTIFY_UCC_LIB_DESTROY_COMPLETE: + (task_data->lib_destroy)(result, task_data->cookie, ucc_notify->dpu_worker_id); + break; + case UROM_WORKER_NOTIFY_UCC_CONTEXT_CREATE_COMPLETE: + (task_data->ctx_create)(result, + task_data->cookie, + ucc_notify->dpu_worker_id, + ucc_notify->context_create_nqe.context); + break; + case UROM_WORKER_NOTIFY_UCC_CONTEXT_DESTROY_COMPLETE: + (task_data->ctx_destroy)(result, task_data->cookie, ucc_notify->dpu_worker_id); + break; + case UROM_WORKER_NOTIFY_UCC_TEAM_CREATE_COMPLETE: + (task_data->team_create)(result, + task_data->cookie, + ucc_notify->dpu_worker_id, + ucc_notify->team_create_nqe.team); + break; + case UROM_WORKER_NOTIFY_UCC_COLLECTIVE_COMPLETE: + (task_data->collective)(result, + task_data->cookie, + ucc_notify->dpu_worker_id, + ucc_notify->coll_nqe.status); + break; + case UROM_WORKER_NOTIFY_UCC_PASSIVE_DATA_CHANNEL_COMPLETE: + (task_data->pd_channel)(result, + task_data->cookie, + ucc_notify->dpu_worker_id, + ucc_notify->pass_dc_nqe.status); + break; + default: + DOCA_LOG_ERR("Invalid UCC notification type %lu", ucc_notify->notify_type); + break; + } + +task_release: + result = doca_urom_worker_cmd_task_release(task); + if (result != DOCA_SUCCESS) + DOCA_LOG_ERR("Failed to release worker command task %s", doca_error_get_descr(result)); +} + +/* + * Pack UCC command + * + * @ucc_cmd [in]: ucc command + * @packed_cmd_len [in/out]: packed command buffer size + * @packed_cmd [out]: packed command buffer + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +static doca_error_t cmd_pack(struct urom_worker_ucc_cmd *ucc_cmd, + size_t *packed_cmd_len, + void *packed_cmd) +{ + void *pack_tail = packed_cmd; + void *pack_head; + size_t pack_len; + size_t team_size; + size_t disp_pack_size; + size_t count_pack_size; + ucc_coll_args_t *coll_args; + int is_count_64, is_disp_64; + + pack_len = sizeof(struct urom_worker_ucc_cmd); + if (pack_len > *packed_cmd_len) + return DOCA_ERROR_INITIALIZATION; + + /* Pack base command */ + pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); + memcpy(pack_head, ucc_cmd, pack_len); + *packed_cmd_len = 
pack_len; + + switch (ucc_cmd->cmd_type) { + case UROM_WORKER_CMD_UCC_LIB_CREATE: + pack_len = sizeof(ucc_lib_params_t); + pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); + memcpy(pack_head, ucc_cmd->lib_create_cmd.params, pack_len); + *packed_cmd_len += pack_len; + break; + case UROM_WORKER_CMD_UCC_CONTEXT_CREATE: + if (ucc_cmd->context_create_cmd.stride <= 0) { + pack_len = sizeof(int64_t) * ucc_cmd->context_create_cmd.size; + pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); + memcpy(pack_head, ucc_cmd->context_create_cmd.array, pack_len); + *packed_cmd_len += pack_len; + } + break; + case UROM_WORKER_CMD_UCC_COLL: + coll_args = ucc_cmd->coll_cmd.coll_args; + pack_len = sizeof(ucc_coll_args_t); + pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); + memcpy(pack_head, ucc_cmd->coll_cmd.coll_args, pack_len); + *packed_cmd_len += pack_len; + pack_len = ucc_cmd->coll_cmd.work_buffer_size; + if (pack_len > 0 && ucc_cmd->coll_cmd.work_buffer) { + pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); + memcpy(pack_head, ucc_cmd->coll_cmd.work_buffer, pack_len); + *packed_cmd_len += pack_len; + } + + if (coll_args->coll_type == UCC_COLL_TYPE_ALLTOALLV || + coll_args->coll_type == UCC_COLL_TYPE_ALLGATHERV || coll_args->coll_type == UCC_COLL_TYPE_GATHERV || + coll_args->coll_type == UCC_COLL_TYPE_REDUCE_SCATTERV || + coll_args->coll_type == UCC_COLL_TYPE_SCATTERV) { + team_size = ucc_cmd->coll_cmd.team_size; + is_count_64 = ((coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS) && + (coll_args->flags & UCC_COLL_ARGS_FLAG_COUNT_64BIT)); + is_disp_64 = ((coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS) && + (coll_args->flags & UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT)); + count_pack_size = team_size * ((is_count_64) ? sizeof(uint64_t) : sizeof(uint32_t)); + disp_pack_size = team_size * ((is_disp_64) ? 
sizeof(uint64_t) : sizeof(uint32_t)); + pack_len = count_pack_size; + pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); + memcpy(pack_head, coll_args->src.info_v.counts, pack_len); + *packed_cmd_len += pack_len; + pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); + memcpy(pack_head, coll_args->dst.info_v.counts, pack_len); + *packed_cmd_len += pack_len; + + pack_len = disp_pack_size; + pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); + memcpy(pack_head, coll_args->src.info_v.displacements, pack_len); + *packed_cmd_len += pack_len; + pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); + memcpy(pack_head, coll_args->dst.info_v.displacements, pack_len); + *packed_cmd_len += pack_len; + } + break; + case UROM_WORKER_CMD_UCC_CREATE_PASSIVE_DATA_CHANNEL: + pack_len = ucc_cmd->pass_dc_create_cmd.addr_len; + pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); + memcpy(pack_head, ucc_cmd->pass_dc_create_cmd.ucp_addr, pack_len); + *packed_cmd_len += pack_len; + break; + } + return DOCA_SUCCESS; +} + +/* + * UCC library create command completion callback function, user callback will be called inside the function + * + * @task [in]: UROM worker task + * @task_user_data [in]: task user data + * @ctx_user_data [in]: worker context user data + */ +static void lib_create_completed(struct doca_urom_worker_cmd_task *task, + union doca_data task_user_data, + union doca_data ctx_user_data) +{ + (void)task_user_data; + (void)ctx_user_data; + completion(task, UROM_WORKER_NOTIFY_UCC_LIB_CREATE_COMPLETE); +} + +doca_error_t ucc_cl_doca_urom_task_lib_create(struct doca_urom_worker *worker_ctx, + union doca_data cookie, + uint64_t dpu_worker_id, + void *params, + ucc_cl_doca_urom_lib_create_finished_cb cb) +{ + size_t pack_len = 0; + struct doca_buf *payload; + struct doca_urom_worker_cmd_task *task; + struct ucc_cl_doca_urom_task_data *task_data; + struct urom_worker_ucc_cmd *ucc_cmd; + doca_error_t result; + + /* Allocate task */ + result = doca_urom_worker_cmd_task_allocate_init(worker_ctx, ucc_id, &task); + if (result != DOCA_SUCCESS) + return result; + + payload = doca_urom_worker_cmd_task_get_payload(task); + result = doca_buf_get_data(payload, (void **)&ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_get_data_len(payload, &pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + /* Populate commands attributes */ + ucc_cmd->cmd_type = UROM_WORKER_CMD_UCC_LIB_CREATE; + ucc_cmd->dpu_worker_id = dpu_worker_id; + ucc_cmd->lib_create_cmd.params = params; + + result = cmd_pack(ucc_cmd, &pack_len, (void *)ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_set_data(payload, ucc_cmd, pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + task_data = (struct ucc_cl_doca_urom_task_data *)doca_urom_worker_cmd_task_get_user_data(task); + task_data->lib_create = cb; + task_data->cookie = cookie; + + doca_urom_worker_cmd_task_set_cb(task, lib_create_completed); + + result = doca_task_submit(doca_urom_worker_cmd_task_as_task(task)); + if (result != DOCA_SUCCESS) + goto task_destroy; + + return DOCA_SUCCESS; + +task_destroy: + doca_urom_worker_cmd_task_release(task); + return result; +} + +/* + * UCC library destroy command completion callback function, user callback will be called inside the function + * + * @task [in]: UROM worker task + * @task_user_data [in]: task user data + * @ctx_user_data [in]: worker context user data + */ +static void 
lib_destroy_completed(struct doca_urom_worker_cmd_task *task, + union doca_data task_user_data, + union doca_data ctx_user_data) +{ + (void)task_user_data; + (void)ctx_user_data; + completion(task, UROM_WORKER_NOTIFY_UCC_LIB_DESTROY_COMPLETE); +} + +doca_error_t doca_urom_ucc_task_lib_destroy(struct doca_urom_worker *worker_ctx, + union doca_data cookie, + uint64_t dpu_worker_id, + ucc_cl_doca_urom_lib_destroy_finished_cb cb) +{ + size_t pack_len = 0; + struct doca_buf *payload; + struct doca_urom_worker_cmd_task *task; + struct ucc_cl_doca_urom_task_data *task_data; + struct urom_worker_ucc_cmd *ucc_cmd; + doca_error_t result; + + /* Allocate task */ + result = doca_urom_worker_cmd_task_allocate_init(worker_ctx, ucc_id, &task); + if (result != DOCA_SUCCESS) + return result; + + payload = doca_urom_worker_cmd_task_get_payload(task); + result = doca_buf_get_data(payload, (void **)&ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_get_data_len(payload, &pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + /* Populate commands attributes */ + ucc_cmd->cmd_type = UROM_WORKER_CMD_UCC_LIB_DESTROY; + ucc_cmd->dpu_worker_id = dpu_worker_id; + + result = cmd_pack(ucc_cmd, &pack_len, (void *)ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_set_data(payload, ucc_cmd, pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + task_data = (struct ucc_cl_doca_urom_task_data *)doca_urom_worker_cmd_task_get_user_data(task); + task_data->lib_destroy = cb; + task_data->cookie = cookie; + + doca_urom_worker_cmd_task_set_cb(task, lib_destroy_completed); + + result = doca_task_submit(doca_urom_worker_cmd_task_as_task(task)); + if (result != DOCA_SUCCESS) + goto task_destroy; + + return DOCA_SUCCESS; + +task_destroy: + doca_urom_worker_cmd_task_release(task); + return result; +} + +/* + * UCC context create command completion callback function, user callback will be called inside the function + * + * @task [in]: UROM worker task + * @task_user_data [in]: task user data + * @ctx_user_data [in]: worker context user data + */ +static void ctx_create_completed(struct doca_urom_worker_cmd_task *task, + union doca_data task_user_data, + union doca_data ctx_user_data) +{ + (void)task_user_data; + (void)ctx_user_data; + completion(task, UROM_WORKER_NOTIFY_UCC_CONTEXT_CREATE_COMPLETE); +} + +doca_error_t ucc_cl_doca_urom_task_ctx_create(struct doca_urom_worker *worker_ctx, + union doca_data cookie, + uint64_t dpu_worker_id, + int64_t start, + int64_t *array, + int64_t stride, + int64_t size, + void *base_va, + uint64_t len, + ucc_cl_doca_urom_ctx_create_finished_cb cb) +{ + size_t pack_len = 0; + struct ucc_cl_doca_urom_task_data *task_data; + struct doca_urom_worker_cmd_task *task; + struct urom_worker_ucc_cmd *ucc_cmd; + struct doca_buf *payload; + doca_error_t result; + + /* Allocate task */ + result = doca_urom_worker_cmd_task_allocate_init(worker_ctx, ucc_id, &task); + if (result != DOCA_SUCCESS) + return result; + + payload = doca_urom_worker_cmd_task_get_payload(task); + result = doca_buf_get_data(payload, (void **)&ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_get_data_len(payload, &pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + /* Populate commands attributes */ + ucc_cmd->cmd_type = UROM_WORKER_CMD_UCC_CONTEXT_CREATE; + ucc_cmd->dpu_worker_id = dpu_worker_id; + if (array == NULL) + ucc_cmd->context_create_cmd.start = start; + else + ucc_cmd->context_create_cmd.array = 
array; + + ucc_cmd->context_create_cmd.stride = stride; + ucc_cmd->context_create_cmd.size = size; + ucc_cmd->context_create_cmd.base_va = base_va; + ucc_cmd->context_create_cmd.len = len; + + result = cmd_pack(ucc_cmd, &pack_len, (void *)ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_set_data(payload, ucc_cmd, pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + task_data = (struct ucc_cl_doca_urom_task_data *)doca_urom_worker_cmd_task_get_user_data(task); + task_data->ctx_create = cb; + task_data->cookie = cookie; + + doca_urom_worker_cmd_task_set_cb(task, ctx_create_completed); + + result = doca_task_submit(doca_urom_worker_cmd_task_as_task(task)); + if (result != DOCA_SUCCESS) + goto task_destroy; + + return DOCA_SUCCESS; + +task_destroy: + doca_urom_worker_cmd_task_release(task); + return result; +} + +/* + * UCC context destroy command completion callback function, user callback will be called inside the function + * + * @task [in]: UROM worker task + * @task_user_data [in]: task user data + * @ctx_user_data [in]: worker context user data + */ +static void ctx_destroy_completed(struct doca_urom_worker_cmd_task *task, + union doca_data task_user_data, + union doca_data ctx_user_data) +{ + (void)task_user_data; + (void)ctx_user_data; + completion(task, UROM_WORKER_NOTIFY_UCC_CONTEXT_DESTROY_COMPLETE); +} + +doca_error_t ucc_cl_doca_urom_task_ctx_destroy(struct doca_urom_worker *worker_ctx, + union doca_data cookie, + uint64_t dpu_worker_id, + void *context, + ucc_cl_doca_urom_ctx_destroy_finished_cb cb) +{ + size_t pack_len = 0; + struct ucc_cl_doca_urom_task_data *task_data; + struct doca_urom_worker_cmd_task *task; + struct urom_worker_ucc_cmd *ucc_cmd; + struct doca_buf *payload; + doca_error_t result; + + /* Allocate task */ + result = doca_urom_worker_cmd_task_allocate_init(worker_ctx, ucc_id, &task); + if (result != DOCA_SUCCESS) + return result; + + payload = doca_urom_worker_cmd_task_get_payload(task); + result = doca_buf_get_data(payload, (void **)&ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_get_data_len(payload, &pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + /* Populate commands attributes */ + ucc_cmd->cmd_type = UROM_WORKER_CMD_UCC_CONTEXT_DESTROY; + ucc_cmd->dpu_worker_id = dpu_worker_id; + ucc_cmd->context_destroy_cmd.context_h = context; + + result = cmd_pack(ucc_cmd, &pack_len, (void *)ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_set_data(payload, ucc_cmd, pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + task_data = (struct ucc_cl_doca_urom_task_data *)doca_urom_worker_cmd_task_get_user_data(task); + task_data->ctx_destroy = cb; + task_data->cookie = cookie; + + doca_urom_worker_cmd_task_set_cb(task, ctx_destroy_completed); + + result = doca_task_submit(doca_urom_worker_cmd_task_as_task(task)); + if (result != DOCA_SUCCESS) + goto task_destroy; + + return DOCA_SUCCESS; + +task_destroy: + doca_urom_worker_cmd_task_release(task); + return result; +} + +/* + * UCC team create command completion callback function, user callback will be called inside the function + * + * @task [in]: UROM worker task + * @task_user_data [in]: task user data + * @ctx_user_data [in]: worker context user data + */ +static void team_create_completed(struct doca_urom_worker_cmd_task *task, + union doca_data task_user_data, + union doca_data ctx_user_data) +{ + (void)task_user_data; + (void)ctx_user_data; + completion(task, 
UROM_WORKER_NOTIFY_UCC_TEAM_CREATE_COMPLETE); +} + +doca_error_t ucc_cl_doca_urom_task_team_create(struct doca_urom_worker *worker_ctx, + union doca_data cookie, + uint64_t dpu_worker_id, + int64_t start, + int64_t stride, + int64_t size, + void *context, + ucc_cl_doca_urom_team_create_finished_cb cb) +{ + size_t pack_len = 0; + struct ucc_cl_doca_urom_task_data *task_data; + struct doca_urom_worker_cmd_task *task; + struct urom_worker_ucc_cmd *ucc_cmd; + struct doca_buf *payload; + doca_error_t result; + + /* Allocate task */ + result = doca_urom_worker_cmd_task_allocate_init(worker_ctx, ucc_id, &task); + if (result != DOCA_SUCCESS) + return result; + + payload = doca_urom_worker_cmd_task_get_payload(task); + result = doca_buf_get_data(payload, (void **)&ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_get_data_len(payload, &pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + /* Populate commands attributes */ + ucc_cmd->cmd_type = UROM_WORKER_CMD_UCC_TEAM_CREATE; + ucc_cmd->dpu_worker_id = dpu_worker_id; + ucc_cmd->team_create_cmd.start = start; + ucc_cmd->team_create_cmd.stride = stride; + ucc_cmd->team_create_cmd.size = size; + ucc_cmd->team_create_cmd.context_h = context; + + result = cmd_pack(ucc_cmd, &pack_len, (void *)ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_set_data(payload, ucc_cmd, pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + task_data = (struct ucc_cl_doca_urom_task_data *)doca_urom_worker_cmd_task_get_user_data(task); + task_data->team_create = cb; + task_data->cookie = cookie; + + doca_urom_worker_cmd_task_set_cb(task, team_create_completed); + + result = doca_task_submit(doca_urom_worker_cmd_task_as_task(task)); + if (result != DOCA_SUCCESS) + goto task_destroy; + + return DOCA_SUCCESS; + +task_destroy: + doca_urom_worker_cmd_task_release(task); + return result; +} + +/* + * UCC collective command completion callback function, user callback will be called inside the function + * + * @task [in]: UROM worker task + * @task_user_data [in]: task user data + * @ctx_user_data [in]: worker context user data + */ +static void collective_completed(struct doca_urom_worker_cmd_task *task, + union doca_data task_user_data, + union doca_data ctx_user_data) +{ + (void)task_user_data; + (void)ctx_user_data; + completion(task, UROM_WORKER_NOTIFY_UCC_COLLECTIVE_COMPLETE); +} + +doca_error_t ucc_cl_doca_urom_task_collective(struct doca_urom_worker *worker_ctx, + union doca_data cookie, + uint64_t dpu_worker_id, + void *coll_args, + void *team, + int use_xgvmi, + void *work_buffer, + size_t work_buffer_size, + size_t team_size, + ucc_cl_doca_urom_collective_finished_cb cb) +{ + size_t pack_len = 0; + struct ucc_cl_doca_urom_task_data *task_data; + struct doca_urom_worker_cmd_task *task; + struct urom_worker_ucc_cmd *ucc_cmd; + struct doca_buf *payload; + doca_error_t result; + + /* Allocate task */ + result = doca_urom_worker_cmd_task_allocate_init(worker_ctx, ucc_id, &task); + if (result != DOCA_SUCCESS) + return result; + + payload = doca_urom_worker_cmd_task_get_payload(task); + result = doca_buf_get_data(payload, (void **)&ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_get_data_len(payload, &pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + /* Populate commands attributes */ + ucc_cmd->cmd_type = UROM_WORKER_CMD_UCC_COLL; + ucc_cmd->dpu_worker_id = dpu_worker_id; + ucc_cmd->coll_cmd.coll_args = coll_args; + ucc_cmd->coll_cmd.team 
= team; + ucc_cmd->coll_cmd.use_xgvmi = use_xgvmi; + ucc_cmd->coll_cmd.work_buffer = work_buffer; + ucc_cmd->coll_cmd.work_buffer_size = work_buffer_size; + ucc_cmd->coll_cmd.team_size = team_size; + + result = cmd_pack(ucc_cmd, &pack_len, (void *)ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_set_data(payload, ucc_cmd, pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + task_data = (struct ucc_cl_doca_urom_task_data *)doca_urom_worker_cmd_task_get_user_data(task); + task_data->collective = cb; + task_data->cookie = cookie; + + doca_urom_worker_cmd_task_set_cb(task, collective_completed); + + result = doca_task_submit(doca_urom_worker_cmd_task_as_task(task)); + if (result != DOCA_SUCCESS) + goto task_destroy; + + return DOCA_SUCCESS; + +task_destroy: + doca_urom_worker_cmd_task_release(task); + return result; +} + +/* + * UCC passive data channel command completion callback function, user callback will be called inside the function + * + * @task [in]: UROM worker task + * @task_user_data [in]: task user data + * @ctx_user_data [in]: worker context user data + */ +static void pd_channel_completed(struct doca_urom_worker_cmd_task *task, + union doca_data task_user_data, + union doca_data ctx_user_data) +{ + (void)task_user_data; + (void)ctx_user_data; + completion(task, UROM_WORKER_NOTIFY_UCC_PASSIVE_DATA_CHANNEL_COMPLETE); +} + +doca_error_t ucc_cl_doca_urom_task_pd_channel(struct doca_urom_worker *worker_ctx, + union doca_data cookie, + uint64_t dpu_worker_id, + void *ucp_addr, + size_t addr_len, + ucc_cl_doca_urom_pd_channel_finished_cb cb) +{ + size_t pack_len = 0; + struct ucc_cl_doca_urom_task_data *task_data; + struct doca_urom_worker_cmd_task *task; + struct urom_worker_ucc_cmd *ucc_cmd; + struct doca_buf *payload; + doca_error_t result; + + /* Allocate task */ + result = doca_urom_worker_cmd_task_allocate_init(worker_ctx, ucc_id, &task); + if (result != DOCA_SUCCESS) + return result; + + payload = doca_urom_worker_cmd_task_get_payload(task); + result = doca_buf_get_data(payload, (void **)&ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_get_data_len(payload, &pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + /* Populate commands attributes */ + ucc_cmd->cmd_type = UROM_WORKER_CMD_UCC_CREATE_PASSIVE_DATA_CHANNEL; + ucc_cmd->dpu_worker_id = dpu_worker_id; + ucc_cmd->pass_dc_create_cmd.ucp_addr = ucp_addr; + ucc_cmd->pass_dc_create_cmd.addr_len = addr_len; + + result = cmd_pack(ucc_cmd, &pack_len, (void *)ucc_cmd); + if (result != DOCA_SUCCESS) + goto task_destroy; + + result = doca_buf_set_data(payload, ucc_cmd, pack_len); + if (result != DOCA_SUCCESS) + goto task_destroy; + + task_data = (struct ucc_cl_doca_urom_task_data *) + doca_urom_worker_cmd_task_get_user_data(task); + task_data->pd_channel = cb; + task_data->cookie = cookie; + + doca_urom_worker_cmd_task_set_cb(task, pd_channel_completed); + + result = doca_task_submit(doca_urom_worker_cmd_task_as_task(task)); + if (result != DOCA_SUCCESS) + goto task_destroy; + + return DOCA_SUCCESS; + +task_destroy: + doca_urom_worker_cmd_task_release(task); + return result; +} + +doca_error_t ucc_cl_doca_urom_save_plugin_id(uint64_t plugin_id, + uint64_t version) +{ + if (version != ucc_version) + return DOCA_ERROR_UNSUPPORTED_VERSION; + + ucc_id = plugin_id; + return DOCA_SUCCESS; +} + +/* + * UCC lib create callback + * + * @result [in]: task result + * @cookie [in]: program cookie + * @dpu_worker_id [in]: UROM DPU worker id + */ +void 
ucc_cl_doca_urom_lib_create_finished( + doca_error_t result, union doca_data cookie, uint64_t dpu_worker_id) +{ + struct ucc_cl_doca_urom_result *res = + (struct ucc_cl_doca_urom_result *)cookie.ptr; + if (res == NULL) + return; + + res->dpu_worker_id = dpu_worker_id; + res->result = result; +} + +/* + * UCC passive data channel callback + * + * @result [in]: task result + * @cookie [in]: program cookie + * @dpu_worker_id [in]: UROM DPU worker id + * @status [in]: channel creation status + */ +void ucc_cl_doca_urom_pss_dc_finished( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, ucc_status_t status) +{ + struct ucc_cl_doca_urom_result *res = + (struct ucc_cl_doca_urom_result *)cookie.ptr; + if (res == NULL) + return; + + res->dpu_worker_id = dpu_worker_id; + res->result = result; + res->pass_dc.status = status; +} + +/* + * UCC lib destroy callback + * + * @result [in]: task result + * @cookie [in]: program cookie + * @dpu_worker_id [in]: UROM DPU worker id + */ +void ucc_cl_doca_urom_lib_destroy_finished( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id) +{ + struct ucc_cl_doca_urom_result *res = + (struct ucc_cl_doca_urom_result *)cookie.ptr; + if (res == NULL) + return; + + res->dpu_worker_id = dpu_worker_id; + res->result = result; +} + +/* + * UCC context create callback + * + * @result [in]: task result + * @cookie [in]: program cookie + * @dpu_worker_id [in]: UROM DPU worker id + * @context [in]: pointer to UCC context + */ +void ucc_cl_doca_urom_ctx_create_finished( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, void *context) +{ + struct ucc_cl_doca_urom_result *res = + (struct ucc_cl_doca_urom_result *)cookie.ptr; + if (res == NULL) + return; + + res->dpu_worker_id = dpu_worker_id; + res->result = result; + res->context_create.context = context; +} + +/* + * UCC collective callback + * + * @result [in]: task result + * @cookie [in]: program cookie + * @dpu_worker_id [in]: UROM DPU worker id + * @status [in]: collective status + */ +void ucc_cl_doca_urom_collective_finished( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, ucc_status_t status) +{ + struct ucc_cl_doca_urom_result *res = + (struct ucc_cl_doca_urom_result *)cookie.ptr; + if (res == NULL) + return; + + res->dpu_worker_id = dpu_worker_id; + res->result = result; + res->collective.status = status; +} + +/* + * UCC team create callback + * + * @result [in]: task result + * @cookie [in]: program cookie + * @dpu_worker_id [in]: UROM DPU worker id + * @team [in]: pointer to UCC team + */ +void ucc_cl_doca_urom_team_create_finished( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, void *team) +{ + struct ucc_cl_doca_urom_result *res = + (struct ucc_cl_doca_urom_result *)cookie.ptr; + if (res == NULL) + return; + + res->dpu_worker_id = dpu_worker_id; + res->result = result; + res->team_create.team = team; + res->team_create.status = 2; // set done +} + +ucc_status_t ucc_cl_doca_urom_buffer_export_ucc( + ucp_context_h ucp_context, void *buf, + size_t len, struct export_buf *ebuf) +{ + ucs_status_t ucs_status; + ucp_mem_map_params_t params; + ucp_memh_pack_params_t pack_params; + + ebuf->ucp_context = ucp_context; + + params.field_mask = + UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH; + params.address = buf; + params.length = len; + + ucs_status = ucp_mem_map(ucp_context, &params, &ebuf->memh); + assert(ucs_status == UCS_OK); + + pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS; + 
pack_params.flags = UCP_MEMH_PACK_FLAG_EXPORT; + + ucs_status = ucp_memh_pack(ebuf->memh, &pack_params, &ebuf->packed_memh, + &ebuf->packed_memh_len); + if (ucs_status != UCS_OK) { + printf("ucp_memh_pack() returned error: %s\n", + ucs_status_string(ucs_status)); + ebuf->packed_memh = NULL; + ebuf->packed_memh_len = 0; + return UCC_ERR_NO_RESOURCE; + } + ucs_status = ucp_rkey_pack(ucp_context, ebuf->memh, &ebuf->packed_key, + &ebuf->packed_key_len); + if (UCS_OK != ucs_status) { + printf("ucp_rkey_pack() returned error: %s\n", + ucs_status_string(ucs_status)); + return UCC_ERR_NO_RESOURCE; + } + + return UCC_OK; +} diff --git a/src/components/cl/doca_urom/cl_doca_urom_worker_ucc.h b/src/components/cl/doca_urom/cl_doca_urom_worker_ucc.h new file mode 100644 index 0000000000..52c42046eb --- /dev/null +++ b/src/components/cl/doca_urom/cl_doca_urom_worker_ucc.h @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED. + * + * This software product is a proprietary product of NVIDIA CORPORATION & + * AFFILIATES (the "Company") and all right, title, and interest in and to the + * software product, including all associated intellectual property rights, are + * and shall remain exclusively with the Company. + * + * This software product is governed by the End User License Agreement + * provided with the software product. + * + */ + +#ifndef UCC_CL_DOCA_UROM_WORKER_UCC_H_ +#define UCC_CL_DOCA_UROM_WORKER_UCC_H_ + +#include + +#include +#include + +#include + +#include "urom_ucc.h" + +struct export_buf { + ucp_context_h ucp_context; + ucp_mem_h memh; + void *packed_memh; + size_t packed_memh_len; + void *packed_key; + size_t packed_key_len; + uint64_t memh_id; +}; + +/* UCC context create result */ +struct ucc_cl_doca_urom_context_create_result { + void *context; /* Pointer to UCC context */ +}; + +/* UCC team create result */ +struct ucc_cl_doca_urom_team_create_result { + void *team; /* Pointer to UCC team */ + int status; /* 0=nothing, 1=team create in progress, 2=team create done */ +}; + +/* UCC collective result */ +struct ucc_cl_doca_urom_collective_result { + ucc_status_t status; /* UCC collective status */ +}; + +/* UCC passive data channel result */ +struct ucc_cl_doca_urom_pass_dc_result { + ucc_status_t status; /* UCC data channel status */ +}; + +/* UCC task result structure */ +struct ucc_cl_doca_urom_result { + doca_error_t result; /* Task result */ + uint64_t dpu_worker_id; /* DPU worker id */ + union { + struct ucc_cl_doca_urom_context_create_result context_create; /* Context create result */ + struct ucc_cl_doca_urom_team_create_result team_create; /* Team create result */ + struct ucc_cl_doca_urom_collective_result collective; /* Collective result */ + struct ucc_cl_doca_urom_pass_dc_result pass_dc; /* Passive data channel result */ + }; +}; + +void ucc_cl_doca_urom_collective_finished( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, ucc_status_t status); + +ucc_status_t ucc_cl_doca_urom_buffer_export_ucc( + ucp_context_h ucp_context, void *buf, + size_t len, struct export_buf *ebuf); + +/* + * UCC team create callback + * + * @result [in]: task result + * @cookie [in]: program cookie + * @dpu_worker_id [in]: UROM DPU worker id + * @team [in]: pointer to UCC team + */ +void ucc_cl_doca_urom_team_create_finished( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, void *team); + +/* + * UCC lib create callback + * + * @result [in]: task result + * @cookie [in]: program cookie + * @dpu_worker_id 
[in]: UROM DPU worker id + */ +void ucc_cl_doca_urom_lib_create_finished( + doca_error_t result, union doca_data cookie, uint64_t dpu_worker_id); + +/* + * UCC passive data channel callback + * + * @result [in]: task result + * @cookie [in]: program cookie + * @dpu_worker_id [in]: UROM DPU worker id + * @status [in]: channel creation status + */ +void ucc_cl_doca_urom_pss_dc_finished( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, ucc_status_t status); + +/* + * UCC lib destroy callback + * + * @result [in]: task result + * @cookie [in]: program cookie + * @dpu_worker_id [in]: UROM DPU worker id + */ +void ucc_cl_doca_urom_lib_destroy_finished( + doca_error_t result, union doca_data cookie, uint64_t dpu_worker_id); + +/* + * UCC context create callback + * + * @result [in]: task result + * @cookie [in]: program cookie + * @dpu_worker_id [in]: UROM DPU worker id + * @context [in]: pointer to UCC context + */ +void ucc_cl_doca_urom_ctx_create_finished( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, void *context); + +/* + * UCC lib create task callback function, will be called once the task is finished + * + * @result [in]: task status + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + */ +typedef void (*ucc_cl_doca_urom_lib_create_finished_cb)( + doca_error_t result, union doca_data cookie, uint64_t dpu_worker_id); + +/* + * UCC lib destroy task callback function, will be called once the task is finished + * + * @result [in]: task status + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + */ +typedef void (*ucc_cl_doca_urom_lib_destroy_finished_cb)( + doca_error_t result, union doca_data cookie, uint64_t dpu_worker_id); + +/* + * UCC context create task callback function, will be called once the task is finished + * + * @result [in]: task status + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + * @context [in]: pointer to UCC context + */ +typedef void (*ucc_cl_doca_urom_ctx_create_finished_cb)( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, void *context); + +/* + * UCC context destroy task callback function, will be called once the task is finished + * + * @result [in]: task status + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + */ +typedef void (*ucc_cl_doca_urom_ctx_destroy_finished_cb)( + doca_error_t result, union doca_data cookie, uint64_t dpu_worker_id); + +/* + * UCC team create task callback function, will be called once the task is finished + * + * @result [in]: task status + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + * @team [in]: pointer to UCC team + */ +typedef void (*ucc_cl_doca_urom_team_create_finished_cb)( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, void *team); + +/* + * UCC collective task callback function, will be called once the task is finished + * + * @result [in]: task status + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + * @status [in]: UCC status + */ +typedef void (*ucc_cl_doca_urom_collective_finished_cb)( + doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, ucc_status_t status); + +/* + * UCC passive data channel task callback function, will be called once the task is finished + * + * @result [in]: task status + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + * @status [in]: UCC status + */ +typedef void (*ucc_cl_doca_urom_pd_channel_finished_cb)( + 
doca_error_t result, union doca_data cookie, + uint64_t dpu_worker_id, ucc_status_t status); + +/* + * Create UCC library task + * + * @worker_ctx [in]: DOCA UROM worker context + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + * @params [in]: UCC team parameters + * @cb [in]: program callback to call once the task is finished + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_task_lib_create( + struct doca_urom_worker *worker_ctx, union doca_data cookie, + uint64_t dpu_worker_id, void *params, + ucc_cl_doca_urom_lib_create_finished_cb cb); + +/* + * Create UCC library destroy task + * + * @worker_ctx [in]: DOCA UROM worker context + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + * @cb [in]: program callback to call once the task is finished + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_task_lib_destroy( + struct doca_urom_worker *worker_ctx, union doca_data cookie, + uint64_t dpu_worker_id, + ucc_cl_doca_urom_lib_destroy_finished_cb cb); + +/* + * Create UCC context task + * + * @worker_ctx [in]: DOCA UROM worker context + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + * @start [in]: the started index + * @array [in]: array of indexes, set stride to <= 0 if array is used + * @stride [in]: number of strides + * @size [in]: collective context world size + * @base_va [in]: shared buffer address + * @len [in]: buffer length + * @cb [in]: program callback to call once the task is finished + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_task_ctx_create( + struct doca_urom_worker *worker_ctx, union doca_data cookie, + uint64_t dpu_worker_id, int64_t start, int64_t *array, + int64_t stride, int64_t size, void *base_va, uint64_t len, + ucc_cl_doca_urom_ctx_create_finished_cb cb); + +/* + * Create UCC context destroy task + * + * @worker_ctx [in]: DOCA UROM worker context + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + * @context [in]: pointer of UCC context + * @cb [in]: program callback to call once the task is finished + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_task_ctx_destroy( + struct doca_urom_worker *worker_ctx, union doca_data cookie, + uint64_t dpu_worker_id, void *context, + ucc_cl_doca_urom_ctx_destroy_finished_cb cb); + +/* + * Create UCC team task + * + * @worker_ctx [in]: DOCA UROM worker context + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + * @start [in]: team start index + * @stride [in]: number of strides + * @size [in]: stride size + * @context [in]: UCC context + * @cb [in]: program callback to call once the task is finished + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_task_team_create( + struct doca_urom_worker *worker_ctx, union doca_data cookie, + uint64_t dpu_worker_id, int64_t start, int64_t stride, int64_t size, + void *context, ucc_cl_doca_urom_team_create_finished_cb cb); + +/* + * Create UCC collective task + * + * @worker_ctx [in]: DOCA UROM worker context + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + * @coll_args [in]: collective arguments + * @team [in]: UCC team + * @use_xgvmi [in]: if operation uses XGVMI + * @work_buffer [in]: work buffer + * @work_buffer_size [in]: buffer size + * @team_size [in]: team size + * @cb [in]: program callback to call once 
the task is finished + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_task_collective( + struct doca_urom_worker *worker_ctx, union doca_data cookie, + uint64_t dpu_worker_id, void *coll_args, void *team, int use_xgvmi, + void *work_buffer, size_t work_buffer_size, size_t team_size, + ucc_cl_doca_urom_collective_finished_cb cb); + +/* + * Create UCC passive data channel task + * + * @worker_ctx [in]: DOCA UROM worker context + * @cookie [in]: user cookie + * @dpu_worker_id [in]: UCC DPU worker id + * @ucp_addr [in]: UCP worker address on host + * @addr_len [in]: UCP worker address length + * @cb [in]: program callback to call once the task is finished + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_task_pd_channel( + struct doca_urom_worker *worker_ctx, union doca_data cookie, + uint64_t dpu_worker_id, void *ucp_addr, size_t addr_len, + ucc_cl_doca_urom_pd_channel_finished_cb cb); + +/* + * This method inits UCC plugin. + * + * @plugin_id [in]: UROM plugin ID + * @version [in]: plugin version on DPU side + * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise + */ +doca_error_t ucc_cl_doca_urom_save_plugin_id( + uint64_t plugin_id, uint64_t version); + +#endif /* UCC_CL_DOCA_UROM_WORKER_UCC_H_ */ diff --git a/src/components/cl/ucc_cl.c b/src/components/cl/ucc_cl.c index d21a0ea1a3..a1dd0f518f 100644 --- a/src/components/cl/ucc_cl.c +++ b/src/components/cl/ucc_cl.c @@ -30,10 +30,11 @@ ucc_config_field_t ucc_cl_context_config_table[] = { }; const char *ucc_cl_names[] = { - [UCC_CL_BASIC] = "basic", - [UCC_CL_HIER] = "hier", - [UCC_CL_ALL] = "all", - [UCC_CL_LAST] = NULL + [UCC_CL_BASIC] = "basic", + [UCC_CL_HIER] = "hier", + [UCC_CL_DOCA_UROM] = "doca_urom", + [UCC_CL_ALL] = "all", + [UCC_CL_LAST] = NULL }; UCC_CLASS_INIT_FUNC(ucc_cl_lib_t, ucc_cl_iface_t *cl_iface, diff --git a/src/components/cl/ucc_cl_type.h b/src/components/cl/ucc_cl_type.h index 6e7bd43479..0ff2021bdc 100644 --- a/src/components/cl/ucc_cl_type.h +++ b/src/components/cl/ucc_cl_type.h @@ -11,6 +11,7 @@ typedef enum { UCC_CL_BASIC, UCC_CL_HIER, + UCC_CL_DOCA_UROM, UCC_CL_ALL, UCC_CL_LAST } ucc_cl_type_t; diff --git a/src/components/tl/ucp/tl_ucp_context.c b/src/components/tl/ucp/tl_ucp_context.c index 1c7c49b53f..aa7581203c 100644 --- a/src/components/tl/ucp/tl_ucp_context.c +++ b/src/components/tl/ucp/tl_ucp_context.c @@ -168,7 +168,8 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t, UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_TAG_SENDER_MASK | UCP_PARAM_FIELD_NAME; ucp_params.features = UCP_FEATURE_TAG | UCP_FEATURE_AM; if (params->params.mask & UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS) { - ucp_params.features |= UCP_FEATURE_RMA | UCP_FEATURE_AMO64; + ucp_params.features |= UCP_FEATURE_RMA | + UCP_FEATURE_AMO64; } ucp_params.tag_sender_mask = UCC_TL_UCP_TAG_SENDER_MASK; ucp_params.name = "UCC_UCP_CONTEXT";
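
Reviewer note on the host-side flow implied by this API: the caller places a struct ucc_cl_doca_urom_result behind the task cookie, submits a worker command task, and the matching *_finished callback fills the result in when the DPU-side notification arrives. The sketch below is illustrative only and is not part of this patch: example_lib_create, the pe progress-engine handle, and the use of DOCA_ERROR_IN_PROGRESS as a "not finished yet" sentinel are assumptions made for the example; it presumes the worker context is attached to a DOCA progress engine that is driven with doca_pe_progress().

/*
 * Illustrative host-side usage sketch (not part of this patch). Assumptions:
 * the DOCA UROM worker was connected to a progress engine "pe" driven via
 * doca_pe_progress(), and DOCA_ERROR_IN_PROGRESS is used here only as an
 * example sentinel value in the result cookie.
 */
#include <doca_error.h>
#include <doca_pe.h>

#include "cl_doca_urom_worker_ucc.h"

static doca_error_t example_lib_create(struct doca_urom_worker *worker,
                                       struct doca_pe *pe,
                                       uint64_t dpu_worker_id,
                                       void *ucc_lib_params)
{
    struct ucc_cl_doca_urom_result res = {0};
    union doca_data cookie;
    doca_error_t status;

    /* The cookie carries the result storage into the completion callback */
    cookie.ptr  = &res;
    res.result  = DOCA_ERROR_IN_PROGRESS;

    status = ucc_cl_doca_urom_task_lib_create(worker, cookie, dpu_worker_id,
                                              ucc_lib_params,
                                              ucc_cl_doca_urom_lib_create_finished);
    if (status != DOCA_SUCCESS)
        return status;

    /* Drive the progress engine until the DPU-side notification lands */
    while (res.result == DOCA_ERROR_IN_PROGRESS)
        (void)doca_pe_progress(pe);

    return res.result;
}

The same cookie-plus-callback pattern applies to the context, team, collective, and passive data channel tasks; only the task constructor and the result union member consulted afterwards change.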