diff --git a/src/components/tl/mlx5/Makefile.am b/src/components/tl/mlx5/Makefile.am index 3cc6f262b4..51e67510ac 100644 --- a/src/components/tl/mlx5/Makefile.am +++ b/src/components/tl/mlx5/Makefile.am @@ -12,15 +12,15 @@ alltoall = \ alltoall/alltoall_inline.h \ alltoall/alltoall_coll.c -mcast = \ - mcast/tl_mlx5_mcast_context.c \ - mcast/tl_mlx5_mcast.h \ - mcast/tl_mlx5_mcast_coll.c \ - mcast/tl_mlx5_mcast_coll.h \ - mcast/tl_mlx5_mcast_rcache.h \ - mcast/tl_mlx5_mcast_rcache.c \ - mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \ - mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \ +mcast = \ + mcast/tl_mlx5_mcast_context.c \ + mcast/tl_mlx5_mcast.h \ + mcast/tl_mlx5_mcast_coll.c \ + mcast/tl_mlx5_mcast_coll.h \ + mcast/tl_mlx5_mcast_rcache.h \ + mcast/tl_mlx5_mcast_rcache.c \ + mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \ + mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \ mcast/tl_mlx5_mcast_team.c sources = \ diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c index d7bd52f188..b19c8a1bfc 100644 --- a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c +++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c @@ -8,7 +8,8 @@ static inline void ucc_tl_mlx5_mcast_p2p_completion_cb(void* context) { - ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj = (ucc_tl_mlx5_mcast_p2p_completion_obj_t *)context; + ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj = + (ucc_tl_mlx5_mcast_p2p_completion_obj_t *)context; ucc_assert(obj != NULL && obj->compl_cb != NULL); @@ -24,16 +25,23 @@ void ucc_tl_mlx5_mcast_completion_cb(void* context, ucc_status_t status) ucc_tl_mlx5_mcast_p2p_completion_cb(context); } -static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t len, ucc_rank_t my_team_rank, int dest, - ucc_team_h team, ucc_context_h ctx, - ucc_coll_callback_t *callback, - ucc_coll_req_h *p2p_req, int is_send) +static inline ucc_status_t + ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, + size_t len, + ucc_rank_t my_team_rank, + 
int dest, + ucc_team_h team, + ucc_context_h ctx, + ucc_coll_callback_t *callback, + ucc_coll_req_h *p2p_req, + int is_send) { ucc_status_t status = UCC_OK; ucc_coll_req_h req = NULL; ucc_coll_args_t args; - args.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_CB; + args.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | + UCC_COLL_ARGS_FIELD_CB; args.coll_type = UCC_COLL_TYPE_BCAST; args.src.info.buffer = buf; args.src.info.count = len; @@ -65,17 +73,19 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t l return status; } -static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t my_team_rank, - int dest, ucc_team_h team, ucc_context_h ctx, - ucc_coll_callback_t *callback, ucc_coll_req_h *req) +static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t + my_team_rank, int dest, ucc_team_h team, + ucc_context_h ctx, ucc_coll_callback_t + *callback, ucc_coll_req_h *req) { return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest, team, ctx, callback, req, 1); } -static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t my_team_rank, - int dest, ucc_team_h team, ucc_context_h ctx, - ucc_coll_callback_t *callback, ucc_coll_req_h *req) +static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t + my_team_rank, int dest, ucc_team_h team, + ucc_context_h ctx, ucc_coll_callback_t + *callback, ucc_coll_req_h *req) { return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest, team, ctx, callback, req, 0); @@ -99,10 +109,13 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_progress(void *context) return status; } -ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, int rank, int tag, - void *context, ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj) +ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t + rank, int tag, void *context, + ucc_tl_mlx5_mcast_p2p_completion_obj_t + *obj) { - ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx = 
(ucc_tl_mlx5_mcast_oob_p2p_context_t *)context; + ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx = + (ucc_tl_mlx5_mcast_oob_p2p_context_t *)context; ucc_status_t status = UCC_OK; ucc_coll_req_h req = NULL; ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank; @@ -123,10 +136,13 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, int rank, int return status; } -ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* src, size_t size, int rank, int tag, - void *context, ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj) +ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t + rank, int tag, void *context, + ucc_tl_mlx5_mcast_p2p_completion_obj_t + *obj) { - ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx = (ucc_tl_mlx5_mcast_oob_p2p_context_t *)context; + ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx = + (ucc_tl_mlx5_mcast_oob_p2p_context_t *)context; ucc_status_t status = UCC_OK; ucc_coll_req_h req = NULL; ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank; @@ -137,7 +153,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* src, size_t size, int rank, int callback.cb = ucc_tl_mlx5_mcast_completion_cb; callback.data = obj; - status = do_recv_nb(src, size, my_team_rank, rank, team, ctx, &callback, &req); + status = do_recv_nb(dst, size, my_team_rank, rank, team, ctx, &callback, &req); if (status != UCC_INPROGRESS && status != UCC_OK) { tl_error(ctx->lib, "nonblocking p2p recv failed"); diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h index 8469087764..e6dc616050 100644 --- a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h +++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h @@ -7,10 +7,14 @@ #include #include "components/tl/mlx5/mcast/tl_mlx5_mcast.h" -ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, int rank, int tag, void *context, - ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj); +ucc_status_t 
ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t + rank, int tag, void *context, + ucc_tl_mlx5_mcast_p2p_completion_obj_t + *obj); -ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* src, size_t size, int rank, int tag, void *context, - ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj); +ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* dst, size_t size, ucc_rank_t + rank, int tag, void *context, + ucc_tl_mlx5_mcast_p2p_completion_obj_t + *obj); ucc_status_t ucc_tl_mlx5_mcast_p2p_progress(void* ctx); diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 96be33a8b0..0d696f91fb 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -36,20 +36,22 @@ typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg); typedef int (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size, int rank, int tag, void *context, - ucc_tl_mlx5_mcast_p2p_completion_obj_t * compl_obj); + ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj); typedef int (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size, int rank, int tag, void *context, - ucc_tl_mlx5_mcast_p2p_completion_obj_t * compl_obj); + ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj); typedef int (*ucc_tl_mlx5_mcast_p2p_send_fn_t)(void* src, size_t size, int rank, int tag, void *context, - ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t wait_cb, void *wait_arg); + ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t wait_cb, + void *wait_arg); typedef int (*ucc_tl_mlx5_mcast_p2p_recv_fn_t)(void* src, size_t size, int rank, int tag, void *context, - ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t wait_cb, void *wait_arg); + ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t wait_cb, + void *wait_arg); typedef int (*ucc_tl_mlx5_mcast_p2p_progress_fn_t)(void*); @@ -71,7 +73,7 @@ UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *, typedef struct ucc_tl_mlx5_mcast_ctx_params { } ucc_tl_mlx5_mcast_ctx_params_t; -typedef struct 
ucc_tl_mlx5_mcast_coll_context_t { +typedef struct ucc_tl_mlx5_mcast_coll_context { struct ibv_context *ctx; struct ibv_pd *pd; char *devname; @@ -97,7 +99,7 @@ typedef struct ucc_tl_mlx5_mcast_oob_ctx { }; } ucc_tl_mlx5_mcast_oob_ctx_t; -typedef struct ucc_tl_mlx5_mcast_context_t { +typedef struct ucc_tl_mlx5_mcast_context { ucc_thread_mode_t tm; ucc_tl_mlx5_mcast_coll_context_t mcast_context; ucc_tl_mlx5_mcast_context_config_t cfg; diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c index 91b1fe325b..7d2c883aa1 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c @@ -6,16 +6,17 @@ #include "tl_mlx5_mcast_rcache.h" -static ucs_status_t ucc_tl_mlx5_mcast_coll_reg_mr(ucc_tl_mlx5_mcast_coll_context_t *ctx, - void *data, size_t data_size, void **mr) +static ucs_status_t ucc_tl_mlx5_mcast_coll_reg_mr(ucc_tl_mlx5_mcast_coll_context_t + *ctx, void *data, size_t data_size, + void **mr) { *mr = ibv_reg_mr(ctx->pd, data, data_size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE); - tl_debug(ctx->lib, "external memory register: ptr %p, len %zd, mr %p", + tl_trace(ctx->lib, "external memory register: ptr %p, len %zd, mr %p", data, data_size, (*mr)); if (!*mr) { - tl_error(ctx->lib, "Failed to register MR\n"); + tl_error(ctx->lib, "failed to register MR\n"); return UCS_ERR_NO_MEMORY; } @@ -23,12 +24,13 @@ static ucs_status_t ucc_tl_mlx5_mcast_coll_reg_mr(ucc_tl_mlx5_mcast_coll_context } -static ucc_status_t ucc_tl_mlx5_mcast_coll_dereg_mr(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *mr) +static ucc_status_t ucc_tl_mlx5_mcast_coll_dereg_mr(ucc_tl_mlx5_mcast_coll_context_t + *ctx, void *mr) { if(mr != NULL) { tl_debug(ctx->lib, "external memory deregister: mr %p", mr); if (ibv_dereg_mr(mr)) { - tl_error(ctx->lib, "Couldn't destroy mr %p", mr); + tl_error(ctx->lib, "couldn't destroy mr %p", mr); return 
UCC_ERR_NO_RESOURCE; } } else { @@ -52,8 +54,8 @@ ucc_tl_mlx5_mcast_rcache_mem_reg_cb(void *context, ucs_rcache_t *rcache, length = (size_t)(rregion->super.end - rregion->super.start); region = ucc_derived_of(rregion, ucc_tl_mlx5_mcast_rcache_region_t); - ret = ucc_tl_mlx5_mcast_coll_reg_mr((ucc_tl_mlx5_mcast_coll_context_t *)context, address, - length, &region->reg.mr); + ret = ucc_tl_mlx5_mcast_coll_reg_mr((ucc_tl_mlx5_mcast_coll_context_t *)context, + address, length, &region->reg.mr); if (UCS_OK != ret) { return ret; } @@ -62,17 +64,20 @@ ucc_tl_mlx5_mcast_rcache_mem_reg_cb(void *context, ucs_rcache_t *rcache, } -static void ucc_tl_mlx5_mcast_rcache_mem_dereg_cb(void *context, ucc_rcache_t *rcache, - ucc_rcache_region_t *rregion) +static void ucc_tl_mlx5_mcast_rcache_mem_dereg_cb(void *context, ucc_rcache_t + *rcache, ucc_rcache_region_t *rregion) { ucc_tl_mlx5_mcast_rcache_region_t *region = ucc_derived_of(rregion, ucc_tl_mlx5_mcast_rcache_region_t); - ucc_tl_mlx5_mcast_coll_dereg_mr((ucc_tl_mlx5_mcast_coll_context_t *)context, region->reg.mr); + ucc_tl_mlx5_mcast_coll_dereg_mr((ucc_tl_mlx5_mcast_coll_context_t *)context, + region->reg.mr); } -static void ucc_tl_mlx5_mcast_rcache_dump_region_cb(void *context, ucc_rcache_t *rcache, - ucc_rcache_region_t *rregion, char *buf, +static void ucc_tl_mlx5_mcast_rcache_dump_region_cb(void *context, + ucc_rcache_t *rcache, + ucc_rcache_region_t *rregion, + char *buf, size_t max) { ucc_tl_mlx5_mcast_rcache_region_t *region = ucc_derived_of(rregion, @@ -82,8 +87,9 @@ static void ucc_tl_mlx5_mcast_rcache_dump_region_cb(void *context, ucc_rcache_t } ucc_status_t -ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *addr, - size_t length, ucc_tl_mlx5_mcast_reg_t **reg) +ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, + void *addr, size_t length, + ucc_tl_mlx5_mcast_reg_t **reg) { ucc_rcache_region_t *rregion; ucc_tl_mlx5_mcast_rcache_region_t *region; @@ -94,8 +100,7 @@ 
ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *addr ucc_assert(rcache != NULL); - status = ucc_rcache_get(rcache, (void *)addr, length, NULL, - &rregion); + status = ucc_rcache_get(rcache, (void *)addr, length, NULL, &rregion); if (ucc_unlikely(UCC_OK != status)) { tl_error(ctx->lib, "ucc_rcache_get failed"); return status; @@ -104,13 +109,14 @@ ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *addr region = ucc_derived_of(rregion, ucc_tl_mlx5_mcast_rcache_region_t); *reg = &region->reg; - tl_debug(ctx->lib, "memory register mr %p", (*reg)->mr); + tl_trace(ctx->lib, "memory register mr %p", (*reg)->mr); return UCC_OK; } ucc_status_t -ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_tl_mlx5_mcast_reg_t *reg) +ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_reg_t *reg) { ucc_tl_mlx5_mcast_rcache_region_t *region; ucc_rcache_t *rcache; @@ -122,7 +128,7 @@ ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_tl_m } ucc_assert(rcache != NULL); - tl_debug(ctx->lib, "memory deregister mr %p", reg->mr); + tl_trace(ctx->lib, "memory deregister mr %p", reg->mr); region = ucc_container_of(reg, ucc_tl_mlx5_mcast_rcache_region_t, reg); ucc_rcache_region_put(rcache, &region->super); @@ -145,7 +151,8 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_rcache(ucc_tl_mlx5_mcast_coll_context_t *ct rcache_params.max_size = SIZE_MAX; rcache_params.region_struct_size = sizeof(ucc_tl_mlx5_mcast_rcache_region_t); rcache_params.max_alignment = ucc_get_page_size(); - rcache_params.ucm_events = UCM_EVENT_VM_UNMAPPED | UCM_EVENT_MEM_TYPE_FREE; + rcache_params.ucm_events = UCM_EVENT_VM_UNMAPPED | + UCM_EVENT_MEM_TYPE_FREE; rcache_params.context = ctx; rcache_params.ops = &ucc_rcache_ops; rcache_params.flags = 0; diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h index 4ec51df984..71fda190e4 100644 
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h @@ -8,8 +8,9 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_rcache(ucc_tl_mlx5_mcast_coll_context_t *ctx); -ucc_status_t ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *addr, - size_t length, ucc_tl_mlx5_mcast_reg_t **reg); +ucc_status_t ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t + *ctx, void *addr, size_t length, + ucc_tl_mlx5_mcast_reg_t **reg); ucc_status_t ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_tl_mlx5_mcast_reg_t *reg);