From fa4c3254fae23973914784cfb59629561018cb31 Mon Sep 17 00:00:00 2001 From: Mamzi Bayatpour Date: Mon, 11 Sep 2023 08:51:11 -0700 Subject: [PATCH] TL/MLX5: mcast progress and helper header --- src/components/tl/mlx5/Makefile.am | 2 + src/components/tl/mlx5/mcast/tl_mlx5_mcast.h | 294 +++++++++++++--- .../tl/mlx5/mcast/tl_mlx5_mcast_coll.c | 30 +- .../tl/mlx5/mcast/tl_mlx5_mcast_helper.h | 322 ++++++++++++++++++ .../tl/mlx5/mcast/tl_mlx5_mcast_progress.h | 64 ++++ .../tl/mlx5/mcast/tl_mlx5_mcast_team.c | 2 +- src/components/tl/mlx5/tl_mlx5.h | 20 +- 7 files changed, 669 insertions(+), 65 deletions(-) create mode 100644 src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h create mode 100644 src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h diff --git a/src/components/tl/mlx5/Makefile.am b/src/components/tl/mlx5/Makefile.am index 51e67510ac..11aec4e5b6 100644 --- a/src/components/tl/mlx5/Makefile.am +++ b/src/components/tl/mlx5/Makefile.am @@ -21,6 +21,8 @@ mcast = \ mcast/tl_mlx5_mcast_rcache.c \ mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \ mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \ + mcast/tl_mlx5_mcast_progress.h \ + mcast/tl_mlx5_mcast_helper.h \ mcast/tl_mlx5_mcast_team.c sources = \ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index d4b643bd87..a91eafd3a2 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -17,8 +17,40 @@ #include "components/tl/ucc_tl_log.h" #include "utils/ucc_rcache.h" +#define POLL_PACKED 16 +#define REL_DONE ((void*)-1) +#define NB_POLL 8 +#define NB_POLL_LARGE 32 +#define MULTICAST_QPN 0xFFFFFF +/* default parameters during modify QP */ +#define DEF_QKEY 0x1a1a1a1a +#define DEF_PKEY 0xffff +#define DEF_PSN 0 +#define DEF_SL 0 +#define DEF_SRC_PATH_BITS 0 +#define GRH_LENGTH 40 +#define DROP_THRESHOLD 1000000 +#define MAX_COMM_POW2 32 -#define UCC_TL_MLX5_MCAST_ENABLE_BLOCKING true +enum { + MCAST_PROTO_EAGER, /* Internal staging buffers */ + MCAST_PROTO_ZCOPY +}; + +enum { + MCAST_P2P_NACK, + MCAST_P2P_ACK, + MCAST_P2P_NEED_NACK_SEND +}; + +enum { + MCAST_RECV_WR = 1, + MCAST_WAIT_RECV_WR, + MCAST_SEND_WR, + MCAST_CALC_WR, + MCAST_BCASTRECV_WR, + MCAST_BCASTSEND_WR, +}; struct ucc_tl_mlx5_mcast_p2p_completion_obj; typedef int (*ucc_tl_mlx5_mcast_p2p_completion_cb_fn_t)(struct ucc_tl_mlx5_mcast_p2p_completion_obj *obj); @@ -29,9 +61,6 @@ typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj { ucc_coll_req_h req; } ucc_tl_mlx5_mcast_p2p_completion_obj_t; -typedef struct mcast_coll_comm_init_spec { -} mcast_coll_comm_init_spec_t; - typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg); typedef int (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size, @@ -43,6 +72,25 @@ typedef int (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size, ucc_rank_t rank, void *context, ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj); +typedef struct ucc_tl_mlx5_mcast_p2p_interface { + ucc_tl_mlx5_mcast_p2p_send_nb_fn_t send_nb; + ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t recv_nb; +} ucc_tl_mlx5_mcast_p2p_interface_t; + +typedef struct mcast_coll_comm_init_spec { + ucc_tl_mlx5_mcast_p2p_interface_t p2p_iface; + int sx_depth; + int rx_depth; + int sx_sge; + int rx_sge; + int sx_inline; + int post_recv_thresh; + int scq_moderation; + int wsize; + int max_eager; + void *oob; +} ucc_tl_mlx5_mcast_coll_comm_init_spec_t; + typedef struct ucc_tl_mlx5_mcast_context_config { ucc_tl_context_config_t super; char *dev_list; @@ -53,12 +101,27 @@ typedef struct ucc_tl_mlx5_mcast_context_config { int context_per_team; } ucc_tl_mlx5_mcast_context_config_t; -typedef struct ucc_tl_mlx5_mcast_lib { -} ucc_tl_mlx5_mcast_lib_t; -UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *, - const ucc_base_config_t *); +typedef struct ucc_tl_mlx5_mcast_oob_ctx { + void *ctx; + union { + ucc_oob_coll_t *oob; + ucc_subset_t subset; + }; +} ucc_tl_mlx5_mcast_oob_ctx_t; + +typedef struct ucc_tl_mlx5_mcast_reg { + void *mr; +} ucc_tl_mlx5_mcast_reg_t; + +typedef struct ucc_tl_mlx5_mcast_rcache_region { + ucc_rcache_region_t super; + ucc_tl_mlx5_mcast_reg_t reg; +} ucc_tl_mlx5_mcast_rcache_region_t; typedef struct ucc_tl_mlx5_mcast_ctx_params { + char *ib_dev_name; + int print_nack_stats; + int timeout; } ucc_tl_mlx5_mcast_ctx_params_t; typedef struct ucc_tl_mlx5_mcast_coll_context { @@ -79,14 +142,6 @@ typedef struct ucc_tl_mlx5_mcast_coll_context { ucc_base_lib_t *lib; } ucc_tl_mlx5_mcast_coll_context_t; -typedef struct ucc_tl_mlx5_mcast_oob_ctx { - void *ctx; - union { - ucc_oob_coll_t *oob; - ucc_subset_t subset; - }; -} ucc_tl_mlx5_mcast_oob_ctx_t; - typedef struct ucc_tl_mlx5_mcast_context { ucc_thread_mode_t tm; ucc_tl_mlx5_mcast_coll_context_t mcast_context; @@ -95,24 +150,145 @@ typedef struct ucc_tl_mlx5_mcast_context { ucc_tl_mlx5_mcast_oob_ctx_t oob_ctx; } ucc_tl_mlx5_mcast_context_t; -typedef struct ucc_tl_mlx5_mcast_reg { - void *mr; -} ucc_tl_mlx5_mcast_reg_t; +struct pp_packet { + ucc_list_link_t super; + uint32_t psn; + int length; + uintptr_t context; + uintptr_t buf; +}; -typedef struct ucc_tl_mlx5_mcast_rcache_region { - ucc_rcache_region_t super; - ucc_tl_mlx5_mcast_reg_t reg; -} ucc_tl_mlx5_mcast_rcache_region_t; +struct mcast_ctx { + struct ibv_qp *qp; + struct ibv_ah *ah; + struct ibv_send_wr swr; + struct ibv_sge ssg; +}; +struct packet { + int type; + ucc_rank_t from; + uint32_t psn; + int comm_id; +}; -typedef struct mcast_coll_comm { /* Stuff at a per-communicator sort of level */ -} mcast_coll_comm_t; +typedef struct ucc_tl_mlx5_mcast_coll_comm { + struct pp_packet dummy_packet; + ucc_tl_mlx5_mcast_coll_context_t *ctx; + ucc_tl_mlx5_mcast_coll_comm_init_spec_t params; + ucc_tl_mlx5_mcast_p2p_interface_t p2p; + int tx; + struct ibv_cq *scq; + struct ibv_cq *rcq; + ucc_rank_t rank; + ucc_rank_t commsize; + char *grh_buf; + struct ibv_mr *grh_mr; + uint16_t mcast_lid; + union ibv_gid mgid; + unsigned max_inline; + size_t max_eager; + int max_per_packet; + int pending_send; + int pending_recv; + struct ibv_mr *pp_mr; + char *pp_buf; + struct pp_packet *pp; + uint32_t psn; + uint32_t last_psn; + uint32_t racks_n; + uint32_t sacks_n; + uint32_t last_acked; + uint32_t naks_n; + uint32_t child_n; + uint32_t parent_n; + int buf_n; + struct packet p2p_pkt[MAX_COMM_POW2]; + struct packet p2p_spkt[MAX_COMM_POW2]; + ucc_list_link_t bpool; + ucc_list_link_t pending_q; + struct mcast_ctx mcast; + int reliable_in_progress; + int recv_drop_packet_in_progress; + struct ibv_recv_wr *call_rwr; + struct ibv_sge *call_rsgs; + uint64_t timer; + int stalled; + int comm_id; + void *p2p_ctx; + ucc_base_lib_t *lib; + struct sockaddr_in6 mcast_addr; + int parents[MAX_COMM_POW2]; + int children[MAX_COMM_POW2]; + int nack_requests; + int nacks_counter; + int n_prep_reliable; + int n_mcast_reliable; + int wsize; + struct pp_packet *r_window[1]; // do not add any new variable after here +} ucc_tl_mlx5_mcast_coll_comm_t; typedef struct ucc_tl_mlx5_mcast_team { - void *mcast_comm; + ucc_tl_mlx5_mcast_context_t *mcast_context; + struct ucc_tl_mlx5_mcast_coll_comm *mcast_comm; + ucc_tl_mlx5_mcast_oob_ctx_t oob_ctx; } ucc_tl_mlx5_mcast_team_t; -typedef struct ucc_tl_mlx5_mcast_coll_req { /* Stuff that has to happen per call */ + +typedef struct ucc_tl_mlx5_mcast_nack_req { + ucc_list_link_t super; + int pkt_id; + ucc_tl_mlx5_mcast_coll_comm_t *comm; +} ucc_tl_mlx5_mcast_nack_req_t; + +#define PSN_IS_IN_RANGE(_psn, _call, _comm) \ + ( \ + ((_psn >= _call->start_psn) && \ + (_psn < _call->start_psn + _call->num_packets) && \ + (_psn >= _comm->last_acked) && \ + (_psn < _comm->last_acked + _comm->wsize)) \ + ) + +#define PSN_TO_RECV_OFFSET(_psn, _call, _comm) \ + ( \ + ((ptrdiff_t)((_psn - _call->start_psn) \ + * (_comm->max_per_packet))) \ + ) + +#define PSN_TO_RECV_LEN(_psn, _call, _comm) \ + ( \ + ((_psn - _call->start_psn + 1) % \ + _call->num_packets == 0 ? _call->last_pkt_len : \ + _comm->max_per_packet) \ + ) + +#define PSN_RECEIVED(_psn, _comm) \ + ( \ + (_comm->r_window[(_psn) % \ + _comm->wsize]->psn == (_psn)) \ + ) + +typedef struct ucc_tl_mlx5_mcast_coll_req { + ucc_tl_mlx5_mcast_coll_comm_t *comm; + size_t length; /* bcast buffer size */ + int proto; + struct ibv_mr *mr; + struct ibv_recv_wr *rwr; + struct ibv_sge *rsgs; + void *rreg; + char *ptr; + int am_root; + ucc_rank_t root; + void **rbufs; + int first_send_psn; + int to_send; + int to_recv; + ucc_rank_t parent; + uint32_t start_psn; + int num_packets; + int last_pkt_len; + int offset; + ucc_memory_type_t buf_mem_type; } ucc_tl_mlx5_mcast_coll_req_t; typedef struct ucc_tl_mlx5_mcast_oob_p2p_context { @@ -122,19 +298,59 @@ typedef struct ucc_tl_mlx5_mcast_oob_p2p_context { ucc_subset_t subset; } ucc_tl_mlx5_mcast_oob_p2p_context_t; -#define TASK_TEAM_MCAST(_task) \ - (ucc_derived_of((_task)->super.team, ucc_tl_mlx5_mcast_team_t)) -#define TASK_CTX_MCAST(_task) \ - (ucc_derived_of((_task)->super.team->context, ucc_tl_mlx5_mcast_context_t)) -#define TASK_LIB_MCAST(_task) \ - (ucc_derived_of((_task)->super.team->context->lib, ucc_tl_mlx5_mcast_lib_t)) -#define TASK_ARGS_MCAST(_task) (_task)->super.bargs.args - -ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context, - ucc_tl_mlx5_mcast_team_t **mcast_team, - ucc_tl_mlx5_mcast_context_t *ctx, +static inline struct pp_packet* ucc_tl_mlx5_mcast_buf_get_free(ucc_tl_mlx5_mcast_coll_comm_t* comm) +{ + struct pp_packet* pp; + + pp = ucc_list_extract_head(&comm->bpool, struct pp_packet, super); + + ucc_assert(pp == NULL || pp->context == 0); + + return pp; +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast_coll_comm_t* comm) +{ + struct ibv_recv_wr *bad_wr = NULL; + struct ibv_recv_wr *rwr = comm->call_rwr; + struct ibv_sge *sge = comm->call_rsgs; + struct pp_packet *pp = NULL; + int i; + int count = comm->params.rx_depth - comm->pending_recv; + + if (count <= comm->params.post_recv_thresh) { + return UCC_OK; + } + + for (i = 0; i < count - 1; i++) { + if (NULL == (pp = ucc_tl_mlx5_mcast_buf_get_free(comm))) { + break; + } + + rwr[i].wr_id = ((uint64_t) pp); + rwr[i].next = &rwr[i+1]; + sge[2*i + 1].addr = pp->buf; + + ucc_assert((uint64_t)comm->pp <= rwr[i].wr_id + && ((uint64_t)comm->pp + comm->buf_n * sizeof(struct pp_packet)) > rwr[i].wr_id); + } + if (i > 0) { + rwr[i-1].next = NULL; + if (ibv_post_recv(comm->mcast.qp, &rwr[0], &bad_wr)) { + tl_error(comm->lib, "failed to prepost recvs: errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + comm->pending_recv += i; + } + + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context, + ucc_tl_mlx5_mcast_team_t **mcast_team, + ucc_tl_mlx5_mcast_context_t *ctx, const ucc_base_team_params_t *params, - mcast_coll_comm_init_spec_t *mcast_conf); + ucc_tl_mlx5_mcast_coll_comm_init_spec_t *mcast_conf); ucc_status_t ucc_tl_mlx5_mcast_coll_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c index 31b0419f9c..b853858d88 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c @@ -5,15 +5,15 @@ */ #include "tl_mlx5_coll.h" +#include "tl_mlx5_mcast_helper.h" ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* req /* NOLINT */) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t mcast_coll_do_bcast(void* buf, int size, int root, void *mr, /* NOLINT */ - mcast_coll_comm_t *comm, /* NOLINT */ - int is_blocking, /* NOLINT */ +ucc_status_t mcast_coll_do_bcast(void* buf, int size, ucc_rank_t root, void *mr, /* NOLINT */ + ucc_tl_mlx5_mcast_coll_comm_t *comm, /* NOLINT */ ucc_tl_mlx5_mcast_coll_req_t **task_req_handle /* NOLINT */) { return UCC_ERR_NOT_SUPPORTED; @@ -21,22 +21,22 @@ ucc_status_t mcast_coll_do_bcast(void* buf, int size, int root, void *mr, /* NOL ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) { - ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); - ucc_tl_mlx5_team_t *mlx5_team = TASK_TEAM(task); - ucc_tl_mlx5_mcast_team_t *team = mlx5_team->mcast; - ucc_coll_args_t *args = &TASK_ARGS_MCAST(task); - ucc_datatype_t dt = args->src.info.datatype; - size_t count = args->src.info.count; - ucc_rank_t root = args->root; - ucc_status_t status = UCC_OK; - size_t data_size = ucc_dt_size(dt) * count; - void *buf = args->src.info.buffer; - mcast_coll_comm_t *comm = team->mcast_comm; + ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); + ucc_tl_mlx5_team_t *mlx5_team = TASK_TEAM(task); + ucc_tl_mlx5_mcast_team_t *team = mlx5_team->mcast; + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_datatype_t dt = args->src.info.datatype; + size_t count = args->src.info.count; + ucc_rank_t root = args->root; + ucc_status_t status = UCC_OK; + size_t data_size = ucc_dt_size(dt) * count; + void *buf = args->src.info.buffer; + ucc_tl_mlx5_mcast_coll_comm_t *comm = team->mcast_comm; task->bcast_mcast.req_handle = NULL; status = mcast_coll_do_bcast(buf, data_size, root, NULL, comm, - UCC_TL_MLX5_MCAST_ENABLE_BLOCKING, &task->bcast_mcast.req_handle); + &task->bcast_mcast.req_handle); if (status < 0) { tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status); coll_task->status = status; diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h new file mode 100644 index 0000000000..0b428868ea --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h @@ -0,0 +1,322 @@ +/** + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "tl_mlx5.h" + +#ifndef TL_MLX5_MCAST_HELPER_H_ +#define TL_MLX5_MCAST_HELPER_H_ +#include "tl_mlx5_mcast_progress.h" +#include "utils/ucc_math.h" + +/* this function returns the number of signaled sends that + * have been completed or -1 in case of error */ +static inline int ucc_tl_mlx5_mcast_poll_send(ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + struct ibv_wc wc; + int num_comp; + + num_comp = ibv_poll_cq(comm->scq, 1, &wc); + + tl_trace(comm->lib, "polled send completions: %d", num_comp); + + if (num_comp < 0) { + tl_error(comm->lib, "send queue poll completion failed %d", num_comp); + } else if (num_comp > 0) { + if (IBV_WC_SUCCESS != wc.status) { + tl_error(comm->lib, "mcast_poll_send: %s err %d num_comp", + ibv_wc_status_str(wc.status), num_comp); + return -1; + } + comm->pending_send -= num_comp; + } + + return num_comp; +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + int num_packets, const int zcopy) +{ + struct ibv_send_wr *swr = &comm->mcast.swr; + struct ibv_sge *ssg = &comm->mcast.ssg; + int max_per_packet = comm->max_per_packet; + int offset = req->offset, i; + struct ibv_send_wr *bad_wr; + struct pp_packet *pp; + int rc; + int length; + ucc_status_t status; + + for (i = 0; i < num_packets; i++) { + if (comm->params.sx_depth <= + (comm->pending_send * comm->params.scq_moderation + comm->tx)) { + if (ucc_tl_mlx5_mcast_poll_send(comm) < 0) { + return UCC_ERR_NO_MESSAGE; + } + break; + } + + if (NULL == (pp = ucc_tl_mlx5_mcast_buf_get_free(comm))) { + break; + } + + ucc_assert(pp->context == 0); + + __builtin_prefetch((void*) pp->buf); + __builtin_prefetch(PTR_OFFSET(req->ptr, offset)); + + length = req->to_send == 1 ? req->last_pkt_len : max_per_packet; + pp->length = length; + pp->psn = comm->psn; + ssg[0].addr = (uintptr_t) PTR_OFFSET(req->ptr, offset); + + if (zcopy) { + pp->context = (uintptr_t) PTR_OFFSET(req->ptr, offset); + } else { + memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length); + ssg[0].addr = (uint64_t) pp->buf; + } + + ssg[0].length = length; + ssg[0].lkey = req->mr->lkey; + swr[0].wr_id = MCAST_BCASTSEND_WR; + swr[0].imm_data = htonl(pp->psn); + swr[0].send_flags = (length <= comm->max_inline) ? IBV_SEND_INLINE : 0; + + comm->r_window[pp->psn & (comm->wsize-1)] = pp; + comm->psn++; + req->to_send--; + offset += length; + comm->tx++; + + if (comm->tx == comm->params.scq_moderation) { + swr[0].send_flags |= IBV_SEND_SIGNALED; + comm->tx = 0; + comm->pending_send++; + } + + tl_trace(comm->lib, "post_send, psn %d, length %d, zcopy %d, signaled %d", + pp->psn, pp->length, zcopy, swr[0].send_flags & + IBV_SEND_SIGNALED); + + if (0 != (rc = ibv_post_send(comm->mcast.qp, &swr[0], &bad_wr))) { + tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, " + "to_recv %d, length %d, psn %d, inline %d", + rc, req->start_psn, req->to_send, req->to_recv, + length, pp->psn, length <= comm->max_inline); + return UCC_ERR_NO_MESSAGE; + } + + status = ucc_tl_mlx5_mcast_check_nack_requests(comm, pp->psn); + if (UCC_OK != status) { + return status; + } + } + + req->offset = offset; + + return UCC_OK; +} + +/* this function return the number of mcast recv packets that + * have been arrived or -1 in case of error */ +static inline int ucc_tl_mlx5_mcast_recv(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + int num_left, int *pending_q_size) +{ + struct pp_packet *pp; + struct pp_packet *next; + uint64_t id; + struct ibv_wc *wc; + int num_comp; + int i; + int real_num_comp; + + /* check if we have already received something */ + ucc_list_for_each_safe(pp, next, &comm->pending_q, super) { + if (PSN_IS_IN_RANGE(pp->psn, req, comm)) { + __builtin_prefetch(req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm)); + __builtin_prefetch((void*) pp->buf); + ucc_list_del(&pp->super); + ucc_tl_mlx5_mcast_process_packet(comm, req, pp); + num_left--; + } else if (pp->psn < comm->last_acked){ + ucc_list_del(&pp->super); + ucc_list_add_tail(&comm->bpool, &pp->super); + } + + (*pending_q_size)++; + }; + + wc = ucc_malloc(sizeof(struct ibv_wc) * POLL_PACKED, "WC"); + if (!wc) { + tl_error(comm->lib, "ucc_malloc failed"); + return -1; + } + + while (num_left > 0) + { + memset(wc, 0, sizeof(struct ibv_wc) * POLL_PACKED); + num_comp = ibv_poll_cq(comm->rcq, POLL_PACKED, wc); + + if (num_comp < 0) { + tl_error(comm->lib, "recv queue poll completion failed %d", num_comp); + ucc_free(wc); + return -1; + } else if (num_comp == 0) { + break; + } + + real_num_comp = num_comp; + + for (i = 0; i < real_num_comp; i++) { + if (IBV_WC_SUCCESS != wc[i].status) { + tl_error(comm->lib, "mcast_recv: %s err pending_recv %d wr_id %ld" + " num_comp %d byte_len %d", + ibv_wc_status_str(wc[i].status), comm->pending_recv, + wc[i].wr_id, num_comp, wc[i].byte_len); + ucc_free(wc); + return -1; + } + + id = wc[i].wr_id; + pp = (struct pp_packet*) (id); + pp->length = wc[i].byte_len - GRH_LENGTH; + pp->psn = ntohl(wc[i].imm_data); + + tl_trace(comm->lib, "completion: psn %d, length %d, already_received %d, " + " psn in req %d, req_start %d, req_num packets" + " %d, to_send %d, to_recv %d, num_left %d", + pp->psn, pp->length, PSN_RECEIVED(pp->psn, + comm) > 0, PSN_IS_IN_RANGE(pp->psn, req, + comm), req->start_psn, req->num_packets, + req->to_send, req->to_recv, num_left); + + if (PSN_RECEIVED(pp->psn, comm) || pp->psn < comm->last_acked) { + /* This psn was already received */ + ucc_assert(pp->context == 0); + ucc_list_add_tail(&comm->bpool, &pp->super); + } else { + if (num_left > 0 && PSN_IS_IN_RANGE(pp->psn, req, comm)) { + __builtin_prefetch(req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm)); + __builtin_prefetch((void*) pp->buf); + ucc_tl_mlx5_mcast_process_packet(comm, req, pp); + num_left--; + } else { + ucc_list_add_tail(&comm->pending_q, &pp->super); + } + } + } + + comm->pending_recv -= num_comp; + ucc_tl_mlx5_mcast_post_recv_buffers(comm); + } + + ucc_free(wc); + return num_left; +} + +/* this function returns the number of mcast recv packets + * that have been completed or -1 in case of error */ +static inline int ucc_tl_mlx5_mcast_poll_recv(ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + struct pp_packet *pp; + struct ibv_wc wc; + int num_comp; + uint64_t id; + int length; + uint32_t psn; + + do { + num_comp = ibv_poll_cq(comm->rcq, 1, &wc); + + if (num_comp > 0) { + + if (IBV_WC_SUCCESS != wc.status) { + tl_error(comm->lib, "mcast_poll_recv: %s err %d num_comp", + ibv_wc_status_str(wc.status), num_comp); + return -1; + } + + // Make sure we received all in order. + id = wc.wr_id; + length = wc.byte_len - GRH_LENGTH; + psn = ntohl(wc.imm_data); + pp = (struct pp_packet*) id; + + if (psn >= comm->psn) { + ucc_assert(!PSN_RECEIVED(psn, comm)); + pp->psn = psn; + pp->length = length; + ucc_list_add_tail(&comm->pending_q, &pp->super); + } else { + ucc_assert(pp->context == 0); + ucc_list_add_tail(&comm->bpool, &pp->super); + } + + comm->pending_recv--; + ucc_tl_mlx5_mcast_post_recv_buffers(comm); + } else if (num_comp != 0) { + tl_error(comm->lib, "mcast_poll_recv: %d num_comp", num_comp); + return -1; + } + } while (num_comp); + + return num_comp; +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + if (comm->racks_n != comm->child_n || comm->sacks_n != comm->parent_n || + comm->nack_requests) { + if (comm->pending_send) { + if (ucc_tl_mlx5_mcast_poll_send(comm) < 0) { + return UCC_ERR_NO_MESSAGE; + } + } + + if (comm->parent_n) { + ucc_tl_mlx5_mcast_poll_recv(comm); + } + + ucc_tl_mlx5_mcast_check_nack_requests_all(comm); + } + + if (comm->parent_n && !comm->reliable_in_progress) { + ucc_tl_mlx5_mcast_reliable_send(comm); + } + + if (!comm->reliable_in_progress) { + comm->reliable_in_progress = 1; + } + + if (comm->racks_n == comm->child_n && comm->sacks_n == comm->parent_n && + 0 == comm->nack_requests) { + // Reset for next round. + memset(comm->parents, 0, sizeof(comm->parents)); + memset(comm->children, 0, sizeof(comm->children)); + + comm->racks_n = comm->child_n = 0; + comm->sacks_n = comm->parent_n = 0; + comm->reliable_in_progress = 0; + + return UCC_OK; + } + + return UCC_INPROGRESS; +} + +ucc_status_t ucc_tl_setup_mcast(ucc_tl_mlx5_mcast_coll_comm_t *comm); + +ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_coll_comm_t *comm); + +ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_coll_comm_t *comm); + +ucc_status_t ucc_tl_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm); + +#endif /* TL_MLX5_MCAST_HELPER_H_ */ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h new file mode 100644 index 0000000000..6211093084 --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h @@ -0,0 +1,64 @@ +/** + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "tl_mlx5_mcast.h" +#include "tl_mlx5_mcast_helper.h" + +#ifndef TL_MLX5_MCAST_PROGRESS_H_ +#define TL_MLX5_MCAST_PROGRESS_H_ + +#define TO_VIRTUAL(_rank, _size, _root) ((_rank + _size - _root) % _size) + +#define TO_ORIGINAL(_rank, _size, _root) ((_rank + _root) % _size) + +#define ACK 1 + +#define GET_COMPL_OBJ(_comm, _compl_fn, _pkt_id, _req) \ + ({ \ + void* item; \ + ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj; \ + item = ucc_mpool_get(&(_comm)->ctx->compl_objects_mp); \ + obj = (ucc_tl_mlx5_mcast_p2p_completion_obj_t *)item; \ + \ + obj->data[0] = (uintptr_t)_comm; \ + obj->compl_cb = _compl_fn; \ + obj->data[1] = (uintptr_t)_pkt_id; \ + obj->data[2] = (uintptr_t)_req; \ + obj; \ + }) + +#define GET_NACK_REQ(_comm, _pkt_id) \ + ({ \ + void* item; \ + ucc_tl_mlx5_mcast_nack_req_t *_req; \ + item = ucc_mpool_get(&(_comm)->ctx->nack_reqs_mp); \ + \ + _req = (ucc_tl_mlx5_mcast_nack_req_t *)item; \ + _req->comm = _comm; \ + _req->pkt_id = _pkt_id; \ + _req; \ + }) + +int ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + ucc_rank_t root); + +ucc_status_t ucc_tl_mlx5_mcast_bcast_check_drop(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req); + +void ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + struct pp_packet* pp); + +int ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t *comm, + uint32_t psn); + +int ucc_tl_mlx5_mcast_reliable_send(ucc_tl_mlx5_mcast_coll_comm_t* comm); + +int ucc_tl_mlx5_mcast_check_nack_requests_all(ucc_tl_mlx5_mcast_coll_comm_t* comm); + +#endif /* ifndef TL_MLX5_MCAST_PROGRESS_H_ */ + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index d716551f67..31044fe8b3 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -13,7 +13,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_cont ucc_tl_mlx5_mcast_team_t **mcast_team, /* NOLINT */ ucc_tl_mlx5_mcast_context_t *ctx, /* NOLINT */ const ucc_base_team_params_t *params, /* NOLINT */ - mcast_coll_comm_init_spec_t *mcast_conf /* NOLINT */) + ucc_tl_mlx5_mcast_coll_comm_init_spec_t *mcast_conf /* NOLINT */) { return UCC_OK; } diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index c3697e1ddb..155e6144af 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -50,16 +50,16 @@ typedef struct ucc_tl_mlx5_ib_qp_conf { } ucc_tl_mlx5_ib_qp_conf_t; typedef struct ucc_tl_mlx5_lib_config { - ucc_tl_lib_config_t super; - int asr_barrier; - int block_size; - int num_dci_qps; - int dc_threshold; - size_t dm_buf_size; - unsigned long dm_buf_num; - int dm_host; - ucc_tl_mlx5_ib_qp_conf_t qp_conf; - mcast_coll_comm_init_spec_t mcast_conf; + ucc_tl_lib_config_t super; + int asr_barrier; + int block_size; + int num_dci_qps; + int dc_threshold; + size_t dm_buf_size; + unsigned long dm_buf_num; + int dm_host; + ucc_tl_mlx5_ib_qp_conf_t qp_conf; + ucc_tl_mlx5_mcast_coll_comm_init_spec_t mcast_conf; } ucc_tl_mlx5_lib_config_t; typedef struct ucc_tl_mlx5_context_config {