diff --git a/src/components/tl/mlx5/Makefile.am b/src/components/tl/mlx5/Makefile.am
index 793a4fc6dd..51e67510ac 100644
--- a/src/components/tl/mlx5/Makefile.am
+++ b/src/components/tl/mlx5/Makefile.am
@@ -12,11 +12,15 @@ alltoall = \
 	alltoall/alltoall_inline.h \
 	alltoall/alltoall_coll.c
 
-mcast = \
-	mcast/tl_mlx5_mcast_context.c \
-	mcast/tl_mlx5_mcast.h \
-	mcast/tl_mlx5_mcast_coll.c \
-	mcast/tl_mlx5_mcast_coll.h \
+mcast =                                   \
+	mcast/tl_mlx5_mcast_context.c     \
+	mcast/tl_mlx5_mcast.h             \
+	mcast/tl_mlx5_mcast_coll.c        \
+	mcast/tl_mlx5_mcast_coll.h        \
+	mcast/tl_mlx5_mcast_rcache.h      \
+	mcast/tl_mlx5_mcast_rcache.c      \
+	mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \
+	mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \
 	mcast/tl_mlx5_mcast_team.c
 
 sources = \
diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
new file mode 100644
index 0000000000..c8dca25fc3
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
@@ -0,0 +1,141 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "ucc_tl_mlx5_mcast_p2p.h"
+
+static inline void ucc_tl_mlx5_mcast_p2p_completion_cb(void* context)
+{
+    ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj =
+        (ucc_tl_mlx5_mcast_p2p_completion_obj_t *)context;
+
+    ucc_assert(obj != NULL && obj->compl_cb != NULL);
+
+    obj->compl_cb(obj);
+
+    ucc_assert(obj->req != NULL);
+
+    ucc_collective_finalize(obj->req);
+}
+
+void ucc_tl_mlx5_mcast_completion_cb(void* context, ucc_status_t status) //NOLINT
+{
+    ucc_tl_mlx5_mcast_p2p_completion_cb(context);
+}
+
+static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
+                                      len, ucc_rank_t my_team_rank, ucc_rank_t dest,
+                                      ucc_team_h team, ucc_context_h ctx,
+                                      ucc_coll_callback_t *callback,
+                                      ucc_coll_req_h *p2p_req, int is_send)
+{
+    ucc_status_t    status = UCC_OK;
+    ucc_coll_req_h  req    = NULL;
+    ucc_coll_args_t args;
+
+    args.mask              = UCC_COLL_ARGS_FIELD_ACTIVE_SET |
+                             UCC_COLL_ARGS_FIELD_CB;
+    args.coll_type         = UCC_COLL_TYPE_BCAST;
+    args.src.info.buffer   = buf;
+    args.src.info.count    = len;
+    args.src.info.datatype = UCC_DT_INT8;
+    args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
+    args.root              = is_send ? my_team_rank : dest;
+    args.cb.cb             = callback->cb;
+    args.cb.data           = callback->data;
+    args.active_set.size   = 2;
+    args.active_set.start  = my_team_rank;
+    args.active_set.stride = dest - my_team_rank;
+
+    status = ucc_collective_init(&args, &req, team);
+    if (ucc_unlikely(UCC_OK != status)) {
+        tl_error(ctx->lib, "nonblocking p2p init failed");
+        return status;
+    }
+
+    ((ucc_tl_mlx5_mcast_p2p_completion_obj_t *)args.cb.data)->req = req;
+
+    status = ucc_collective_post(req);
+    if (ucc_unlikely(UCC_OK != status)) {
+        tl_error(ctx->lib, "nonblocking p2p post failed");
+        return status;
+    }
+
+    *p2p_req = req;
+
+    return status;
+}
+
+static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t
+                                      my_team_rank, ucc_rank_t dest, ucc_team_h team,
+                                      ucc_context_h ctx, ucc_coll_callback_t
+                                      *callback, ucc_coll_req_h *req)
+{
+    return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest,
+                                             team, ctx, callback, req, 1);
+}
+
+static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t
+                                      my_team_rank, ucc_rank_t dest, ucc_team_h team,
+                                      ucc_context_h ctx, ucc_coll_callback_t
+                                      *callback, ucc_coll_req_h *req)
+{
+    return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest,
+                                             team, ctx, callback, req, 0);
+}
+
+ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
+                                           rank, void *context,
+                                           ucc_tl_mlx5_mcast_p2p_completion_obj_t
+                                           *obj)
+{
+    ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx =
+        (ucc_tl_mlx5_mcast_oob_p2p_context_t *)context;
+    ucc_status_t                         status       = UCC_OK;
+    ucc_coll_req_h                       req          = NULL;
+    ucc_rank_t                           my_team_rank = oob_p2p_ctx->my_team_rank;
+    ucc_team_h                           team         = oob_p2p_ctx->base_team;
+    ucc_context_h                        ctx          = oob_p2p_ctx->base_ctx;
+    ucc_coll_callback_t                  callback;
+
+    callback.cb   = ucc_tl_mlx5_mcast_completion_cb;
+    callback.data = obj;
+
+    status = do_send_nb(src, size, my_team_rank, rank, team, ctx, &callback, &req);
+
+    if (status < 0) {
+        tl_error(ctx->lib, "nonblocking p2p send failed");
+        return status;
+    }
+
+    return status;
+}
+
+ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
+                                           rank, void *context,
+                                           ucc_tl_mlx5_mcast_p2p_completion_obj_t
+                                           *obj)
+{
+    ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx =
+        (ucc_tl_mlx5_mcast_oob_p2p_context_t *)context;
+    ucc_status_t                         status       = UCC_OK;
+    ucc_coll_req_h                       req          = NULL;
+    ucc_rank_t                           my_team_rank = oob_p2p_ctx->my_team_rank;
+    ucc_team_h                           team         = oob_p2p_ctx->base_team;
+    ucc_context_h                        ctx          = oob_p2p_ctx->base_ctx;
+    ucc_coll_callback_t                  callback;
+
+    callback.cb   = ucc_tl_mlx5_mcast_completion_cb;
+    callback.data = obj;
+
+    status = do_recv_nb(dst, size, my_team_rank, rank, team, ctx, &callback, &req);
+
+    if (status < 0) {
+        tl_error(ctx->lib, "nonblocking p2p recv failed");
+        return status;
+    }
+
+    return status;
+}
diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
new file mode 100644
index 0000000000..6e19e59dde
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
@@ -0,0 +1,18 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include <ucc/api/ucc.h>
+#include "components/tl/mlx5/mcast/tl_mlx5_mcast.h"
+
+ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
+                                           rank, void *context,
+                                           ucc_tl_mlx5_mcast_p2p_completion_obj_t
+                                           *obj);
+
+ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* dst, size_t size, ucc_rank_t
+                                           rank, void *context,
+                                           ucc_tl_mlx5_mcast_p2p_completion_obj_t
+                                           *obj);
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
index fa666612d1..d4b643bd87 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
@@ -20,9 +20,39 @@
 
 #define UCC_TL_MLX5_MCAST_ENABLE_BLOCKING true
 
+struct ucc_tl_mlx5_mcast_p2p_completion_obj;
+typedef int (*ucc_tl_mlx5_mcast_p2p_completion_cb_fn_t)(struct ucc_tl_mlx5_mcast_p2p_completion_obj *obj);
+typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj {
+    ucc_list_link_t                          super;
+    ucc_tl_mlx5_mcast_p2p_completion_cb_fn_t compl_cb;
+    uint64_t                                 data[3];
+    ucc_coll_req_h                           req;
+} ucc_tl_mlx5_mcast_p2p_completion_obj_t;
+
 typedef struct mcast_coll_comm_init_spec {
 } mcast_coll_comm_init_spec_t;
 
+typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg);
+
+typedef int (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size,
+                                                  ucc_rank_t rank, void *context,
+                                                  ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);
+
+
+typedef int (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size,
+                                                  ucc_rank_t rank, void *context,
+                                                  ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);
+
+typedef struct ucc_tl_mlx5_mcast_context_config {
+    ucc_tl_context_config_t super;
+    char                   *dev_list;
+    int                     use_rcache;
+    size_t                  reg_threshold;
+    unsigned int            rand_seed;
+    unsigned int            uprogress_num_polls;
+    int                     context_per_team;
+} ucc_tl_mlx5_mcast_context_config_t;
+
 typedef struct ucc_tl_mlx5_mcast_lib {
 } ucc_tl_mlx5_mcast_lib_t;
 UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *,
@@ -31,12 +61,49 @@
 typedef struct ucc_tl_mlx5_mcast_ctx_params {
 } ucc_tl_mlx5_mcast_ctx_params_t;
 
-typedef struct mcast_coll_context_t {
-} mcast_coll_context_t;
-
-typedef struct ucc_tl_mlx5_mcast_context_t {
+typedef struct ucc_tl_mlx5_mcast_coll_context {
+    struct ibv_context             *ctx;
+    struct ibv_pd                  *pd;
+    char                           *devname;
+    int                             max_qp_wr;
+    int                             ib_port;
+    int                             pkey_index;
+    int                             mtu;
+    struct rdma_cm_id              *id;
+    struct rdma_event_channel      *channel;
+    ucc_mpool_t                     compl_objects_mp;
+    ucc_mpool_t                     nack_reqs_mp;
+    ucc_list_link_t                 pending_nacks_list;
+    ucc_rcache_t                   *rcache;
+    ucc_tl_mlx5_mcast_ctx_params_t  params;
+    ucc_base_lib_t                 *lib;
+} ucc_tl_mlx5_mcast_coll_context_t;
+
+typedef struct ucc_tl_mlx5_mcast_oob_ctx {
+    void *ctx;
+    union {
+        ucc_oob_coll_t *oob;
+        ucc_subset_t    subset;
+    };
+} ucc_tl_mlx5_mcast_oob_ctx_t;
+
+typedef struct ucc_tl_mlx5_mcast_context {
+    ucc_thread_mode_t                  tm;
+    ucc_tl_mlx5_mcast_coll_context_t   mcast_context;
+    ucc_tl_mlx5_mcast_context_config_t cfg;
+    ucc_mpool_t                        req_mp;
+    ucc_tl_mlx5_mcast_oob_ctx_t        oob_ctx;
 } ucc_tl_mlx5_mcast_context_t;
 
+typedef struct ucc_tl_mlx5_mcast_reg {
+    void *mr;
+} ucc_tl_mlx5_mcast_reg_t;
+
+typedef struct ucc_tl_mlx5_mcast_rcache_region {
+    ucc_rcache_region_t     super;
+    ucc_tl_mlx5_mcast_reg_t reg;
+} ucc_tl_mlx5_mcast_rcache_region_t;
+
 typedef struct mcast_coll_comm {
     /* Stuff at a per-communicator sort of level */
 } mcast_coll_comm_t;
@@ -48,6 +115,13 @@ typedef struct ucc_tl_mlx5_mcast_team {
 typedef struct ucc_tl_mlx5_mcast_coll_req { /* Stuff that has to happen per call */
 } ucc_tl_mlx5_mcast_coll_req_t;
 
+typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
+    ucc_context_h base_ctx;
+    ucc_team_h    base_team;
+    ucc_rank_t    my_team_rank;
+    ucc_subset_t  subset;
+} ucc_tl_mlx5_mcast_oob_p2p_context_t;
+
 #define TASK_TEAM_MCAST(_task) \
     (ucc_derived_of((_task)->super.team, ucc_tl_mlx5_mcast_team_t))
 #define TASK_CTX_MCAST(_task) \
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
index 54cdfb267d..31b0419f9c 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
@@ -37,7 +37,7 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)
     status = mcast_coll_do_bcast(buf, data_size, root, NULL, comm,
                                  UCC_TL_MLX5_MCAST_ENABLE_BLOCKING,
                                  &task->bcast_mcast.req_handle);
-    if (UCC_OK != status && UCC_INPROGRESS != status) {
+    if (status < 0) {
         tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status);
         coll_task->status = status;
         return ucc_task_complete(coll_task);
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c
new file mode 100644
index 0000000000..c67a2d3179
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c
@@ -0,0 +1,156 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "tl_mlx5_mcast_rcache.h"
+
+static ucs_status_t ucc_tl_mlx5_mcast_coll_reg_mr(ucc_tl_mlx5_mcast_coll_context_t
+                                                 *ctx, void *data, size_t data_size,
+                                                  void **mr)
+{
+    *mr = ibv_reg_mr(ctx->pd, data, data_size, IBV_ACCESS_LOCAL_WRITE |
+                     IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE);
+
+    tl_trace(ctx->lib, "external memory register: ptr %p, len %zd, mr %p",
+             data, data_size, (*mr));
+    if (!*mr) {
+        tl_error(ctx->lib, "failed to register MR");
+        return UCS_ERR_NO_MEMORY;
+    }
+
+    return UCS_OK;
+}
+
+
+static ucc_status_t ucc_tl_mlx5_mcast_coll_dereg_mr(ucc_tl_mlx5_mcast_coll_context_t
+                                                   *ctx, void *mr)
+{
+    if (ucc_unlikely(NULL == mr)) {
+        tl_debug(ctx->lib, "external memory mr %p already deregistered", mr);
+        return UCC_OK;
+    }
+
+    tl_debug(ctx->lib, "external memory deregister: mr %p", mr);
+
+    if (ibv_dereg_mr(mr)) {
+        tl_error(ctx->lib, "couldn't destroy mr %p", mr);
+        return UCC_ERR_NO_RESOURCE;
+    }
+
+    return UCC_OK;
+}
+
+static ucs_status_t
+ucc_tl_mlx5_mcast_rcache_mem_reg_cb(void *context, ucs_rcache_t *rcache, //NOLINT
+                                    void *arg, ucs_rcache_region_t *rregion, //NOLINT
+                                    uint16_t flags) //NOLINT
+{
+    ucc_tl_mlx5_mcast_rcache_region_t *region;
+    void                              *address;
+    size_t                             length;
+
+    address = (void*)rregion->super.start;
+    length  = (size_t)(rregion->super.end - rregion->super.start);
+    region  = ucc_derived_of(rregion, ucc_tl_mlx5_mcast_rcache_region_t);
+
+    return ucc_tl_mlx5_mcast_coll_reg_mr((ucc_tl_mlx5_mcast_coll_context_t *)context,
+                                         address, length, &region->reg.mr);
+}
+
+static void ucc_tl_mlx5_mcast_rcache_mem_dereg_cb(void *context, ucc_rcache_t //NOLINT
+                                                  *rcache, ucc_rcache_region_t *rregion) //NOLINT
+{
+    ucc_tl_mlx5_mcast_rcache_region_t *region = ucc_derived_of(rregion,
+                                                ucc_tl_mlx5_mcast_rcache_region_t);
+
+    ucc_tl_mlx5_mcast_coll_dereg_mr((ucc_tl_mlx5_mcast_coll_context_t *)context,
+                                    region->reg.mr);
+}
+
+static void ucc_tl_mlx5_mcast_rcache_dump_region_cb(void *context, //NOLINT
+                                                    ucc_rcache_t *rcache, //NOLINT
+                                                    ucc_rcache_region_t *rregion, //NOLINT
+                                                    char *buf, //NOLINT
+                                                    size_t max) //NOLINT
+{
+    ucc_tl_mlx5_mcast_rcache_region_t *region = ucc_derived_of(rregion,
+                                                ucc_tl_mlx5_mcast_rcache_region_t);
+
+    snprintf(buf, max, "bar ptr:%p", region->reg.mr);
+}
+
+ucc_status_t
+ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx,
+                               void *addr, size_t length,
+                               ucc_tl_mlx5_mcast_reg_t **reg)
+{
+    ucc_rcache_region_t               *rregion;
+    ucc_tl_mlx5_mcast_rcache_region_t *region;
+    ucc_status_t                       status;
+    ucc_rcache_t                      *rcache;
+
+    rcache = ctx->rcache;
+
+    ucc_assert(rcache != NULL);
+
+    status = ucc_rcache_get(rcache, (void *)addr, length, NULL, &rregion);
+    if (ucc_unlikely(UCC_OK != status)) {
+        tl_error(ctx->lib, "ucc_rcache_get failed");
+        return status;
+    }
+
+    region = ucc_derived_of(rregion, ucc_tl_mlx5_mcast_rcache_region_t);
+    *reg   = &region->reg;
+
+    tl_trace(ctx->lib, "memory register mr %p", (*reg)->mr);
+
+    return UCC_OK;
+}
+
+ucc_status_t
+ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx,
+                                 ucc_tl_mlx5_mcast_reg_t *reg)
+{
+    ucc_tl_mlx5_mcast_rcache_region_t *region;
+    ucc_rcache_t                      *rcache;
+
+    rcache = ctx->rcache;
+
+    if (reg == NULL) {
+        return UCC_OK;
+    }
+
+    ucc_assert(rcache != NULL);
+    tl_trace(ctx->lib, "memory deregister mr %p", reg->mr);
+    region = ucc_container_of(reg, ucc_tl_mlx5_mcast_rcache_region_t, reg);
+    ucc_rcache_region_put(rcache, &region->super);
+
+    return UCC_OK;
+}
+
+static ucc_rcache_ops_t ucc_rcache_ops = {
+    .mem_reg     = ucc_tl_mlx5_mcast_rcache_mem_reg_cb,
+    .mem_dereg   = ucc_tl_mlx5_mcast_rcache_mem_dereg_cb,
+    .dump_region = ucc_tl_mlx5_mcast_rcache_dump_region_cb
+};
+
+ucc_status_t ucc_tl_mlx5_mcast_setup_rcache(ucc_tl_mlx5_mcast_coll_context_t *ctx)
+{
+    ucc_rcache_params_t rcache_params;
+
+    rcache_params.alignment          = 64;
+    rcache_params.ucm_event_priority = 1000;
+    rcache_params.max_regions        = ULONG_MAX;
+    rcache_params.max_size           = SIZE_MAX;
+    rcache_params.region_struct_size = sizeof(ucc_tl_mlx5_mcast_rcache_region_t);
+    rcache_params.max_alignment      = ucc_get_page_size();
+    rcache_params.ucm_events         = UCM_EVENT_VM_UNMAPPED |
+                                       UCM_EVENT_MEM_TYPE_FREE;
+    rcache_params.context            = ctx;
+    rcache_params.ops                = &ucc_rcache_ops;
+    rcache_params.flags              = 0;
+
+    return ucc_rcache_create(&rcache_params, "MCAST", &ctx->rcache);
+}
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h
new file mode 100644
index 0000000000..e1836704ad
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h
@@ -0,0 +1,17 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "tl_mlx5_mcast.h"
+#include "utils/ucc_rcache.h"
+
+ucc_status_t ucc_tl_mlx5_mcast_setup_rcache(ucc_tl_mlx5_mcast_coll_context_t *ctx);
+
+ucc_status_t ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t
+                                            *ctx, void *addr, size_t length,
+                                            ucc_tl_mlx5_mcast_reg_t **reg);
+
+ucc_status_t ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx,
+                                              ucc_tl_mlx5_mcast_reg_t *reg);
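
Usage sketch (editor's note, not part of the patch): the new p2p calls are implemented as two-rank active-set bcasts — send_nb roots the bcast at the local rank, recv_nb at the peer — so the peer must post the matching call, and completion is driven by progressing the UCC context until the completion object's compl_cb fires. Below is a minimal, single-threaded sketch assuming only the types and prototypes added in this diff plus the public UCC API; on_p2p_done, transfer, reg_example, pending, and the include paths are hypothetical illustration names.

/* Hypothetical usage sketch, not part of the patch. */
#include <ucc/api/ucc.h>
#include "components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h"
#include "components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h"

static int pending; /* outstanding p2p operations on this rank (sketch is not thread safe) */

/* compl_cb: invoked from the internal completion path just before the
 * underlying active-set bcast request is finalized */
static int on_p2p_done(struct ucc_tl_mlx5_mcast_p2p_completion_obj *obj)
{
    (void)obj;
    pending--;
    return 0;
}

/* One-directional transfer: the sender posts send_nb while the receiver
 * posts recv_nb for the same peer; both map onto the same 2-rank
 * active-set bcast rooted at the sender. */
static ucc_status_t transfer(ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_ctx,
                             void *buf, size_t len, ucc_rank_t peer,
                             int is_sender)
{
    /* obj must stay alive until compl_cb has run; the internal callback
     * dereferences it to find the request to finalize */
    ucc_tl_mlx5_mcast_p2p_completion_obj_t obj = { .compl_cb = on_p2p_done };
    ucc_status_t                           status;

    pending = 1;
    status  = is_sender ?
        ucc_tl_mlx5_mcast_p2p_send_nb(buf, len, peer, oob_ctx, &obj) :
        ucc_tl_mlx5_mcast_p2p_recv_nb(buf, len, peer, oob_ctx, &obj);
    if (status < 0) {
        return status;
    }
    /* drive the posted bcast to completion */
    while (pending > 0) {
        ucc_context_progress(oob_ctx->base_ctx);
    }
    return UCC_OK;
}

/* Registration-cache pairing: after ucc_tl_mlx5_mcast_setup_rcache(ctx),
 * register/deregister go through the cache instead of raw ibv_reg_mr() */
static ucc_status_t reg_example(ucc_tl_mlx5_mcast_coll_context_t *ctx,
                                void *buf, size_t len)
{
    ucc_tl_mlx5_mcast_reg_t *reg = NULL;
    ucc_status_t             status;

    status = ucc_tl_mlx5_mcast_mem_register(ctx, buf, len, &reg);
    if (status != UCC_OK) {
        return status;
    }
    /* ... reg->mr holds the ibv_reg_mr() result for use in work requests ... */
    return ucc_tl_mlx5_mcast_mem_deregister(ctx, reg);
}

Note that ucc_tl_mlx5_mcast_mem_deregister() only drops the region reference (ucc_rcache_region_put) rather than destroying the MR, so repeated transfers from the same buffer avoid re-registration cost; the MR is actually destroyed when the rcache evicts or invalidates the region.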