Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/MLX5: multicast design: adding helper.h, progress.h, and mcast.h #838

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/components/tl/mlx5/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ mcast = \
mcast/tl_mlx5_mcast_rcache.c \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \
mcast/tl_mlx5_mcast_progress.h \
mcast/tl_mlx5_mcast_helper.h \
mcast/tl_mlx5_mcast_team.c

sources = \
Expand Down
294 changes: 255 additions & 39 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,40 @@
#include "components/tl/ucc_tl_log.h"
#include "utils/ucc_rcache.h"

#define POLL_PACKED 16
#define REL_DONE ((void*)-1)
#define NB_POLL 8
#define NB_POLL_LARGE 32
#define MULTICAST_QPN 0xFFFFFF
/* default parameters during modify QP */
#define DEF_QKEY 0x1a1a1a1a
#define DEF_PKEY 0xffff
#define DEF_PSN 0
#define DEF_SL 0
#define DEF_SRC_PATH_BITS 0
#define GRH_LENGTH 40
#define DROP_THRESHOLD 1000000
#define MAX_COMM_POW2 32

#define UCC_TL_MLX5_MCAST_ENABLE_BLOCKING true
enum {
MCAST_PROTO_EAGER, /* Internal staging buffers */
MCAST_PROTO_ZCOPY
};

enum {
MCAST_P2P_NACK,
MCAST_P2P_ACK,
MCAST_P2P_NEED_NACK_SEND
};

enum {
MCAST_RECV_WR = 1,
MCAST_WAIT_RECV_WR,
MCAST_SEND_WR,
MCAST_CALC_WR,
MCAST_BCASTRECV_WR,
MCAST_BCASTSEND_WR,
};

struct ucc_tl_mlx5_mcast_p2p_completion_obj;
typedef int (*ucc_tl_mlx5_mcast_p2p_completion_cb_fn_t)(struct ucc_tl_mlx5_mcast_p2p_completion_obj *obj);
Expand All @@ -29,9 +61,6 @@ typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj {
ucc_coll_req_h req;
} ucc_tl_mlx5_mcast_p2p_completion_obj_t;

typedef struct mcast_coll_comm_init_spec {
} mcast_coll_comm_init_spec_t;

typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg);

typedef int (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size,
Expand All @@ -43,6 +72,25 @@ typedef int (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size,
ucc_rank_t rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);

typedef struct ucc_tl_mlx5_mcast_p2p_interface {
ucc_tl_mlx5_mcast_p2p_send_nb_fn_t send_nb;
ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t recv_nb;
} ucc_tl_mlx5_mcast_p2p_interface_t;

typedef struct mcast_coll_comm_init_spec {
ucc_tl_mlx5_mcast_p2p_interface_t p2p_iface;
int sx_depth;
int rx_depth;
int sx_sge;
int rx_sge;
int sx_inline;
int post_recv_thresh;
int scq_moderation;
int wsize;
int max_eager;
void *oob;
} ucc_tl_mlx5_mcast_coll_comm_init_spec_t;

typedef struct ucc_tl_mlx5_mcast_context_config {
ucc_tl_context_config_t super;
char *dev_list;
Expand All @@ -53,12 +101,27 @@ typedef struct ucc_tl_mlx5_mcast_context_config {
int context_per_team;
} ucc_tl_mlx5_mcast_context_config_t;

typedef struct ucc_tl_mlx5_mcast_lib {
} ucc_tl_mlx5_mcast_lib_t;
UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *,
const ucc_base_config_t *);
typedef struct ucc_tl_mlx5_mcast_oob_ctx {
void *ctx;
union {
ucc_oob_coll_t *oob;
ucc_subset_t subset;
};
} ucc_tl_mlx5_mcast_oob_ctx_t;

typedef struct ucc_tl_mlx5_mcast_reg {
void *mr;
} ucc_tl_mlx5_mcast_reg_t;

typedef struct ucc_tl_mlx5_mcast_rcache_region {
ucc_rcache_region_t super;
ucc_tl_mlx5_mcast_reg_t reg;
} ucc_tl_mlx5_mcast_rcache_region_t;

typedef struct ucc_tl_mlx5_mcast_ctx_params {
char *ib_dev_name;
int print_nack_stats;
int timeout;
} ucc_tl_mlx5_mcast_ctx_params_t;

typedef struct ucc_tl_mlx5_mcast_coll_context {
Expand All @@ -79,14 +142,6 @@ typedef struct ucc_tl_mlx5_mcast_coll_context {
ucc_base_lib_t *lib;
} ucc_tl_mlx5_mcast_coll_context_t;

typedef struct ucc_tl_mlx5_mcast_oob_ctx {
void *ctx;
union {
ucc_oob_coll_t *oob;
ucc_subset_t subset;
};
} ucc_tl_mlx5_mcast_oob_ctx_t;

typedef struct ucc_tl_mlx5_mcast_context {
ucc_thread_mode_t tm;
ucc_tl_mlx5_mcast_coll_context_t mcast_context;
Expand All @@ -95,24 +150,145 @@ typedef struct ucc_tl_mlx5_mcast_context {
ucc_tl_mlx5_mcast_oob_ctx_t oob_ctx;
} ucc_tl_mlx5_mcast_context_t;

typedef struct ucc_tl_mlx5_mcast_reg {
void *mr;
} ucc_tl_mlx5_mcast_reg_t;
struct pp_packet {
ucc_list_link_t super;
uint32_t psn;
int length;
uintptr_t context;
uintptr_t buf;
};

typedef struct ucc_tl_mlx5_mcast_rcache_region {
ucc_rcache_region_t super;
ucc_tl_mlx5_mcast_reg_t reg;
} ucc_tl_mlx5_mcast_rcache_region_t;
struct mcast_ctx {
struct ibv_qp *qp;
struct ibv_ah *ah;
struct ibv_send_wr swr;
struct ibv_sge ssg;
};

struct packet {
int type;
ucc_rank_t from;
uint32_t psn;
int comm_id;
};

typedef struct mcast_coll_comm { /* Stuff at a per-communicator sort of level */
} mcast_coll_comm_t;
typedef struct ucc_tl_mlx5_mcast_coll_comm {
struct pp_packet dummy_packet;
ucc_tl_mlx5_mcast_coll_context_t *ctx;
ucc_tl_mlx5_mcast_coll_comm_init_spec_t params;
ucc_tl_mlx5_mcast_p2p_interface_t p2p;
int tx;
struct ibv_cq *scq;
struct ibv_cq *rcq;
ucc_rank_t rank;
ucc_rank_t commsize;
char *grh_buf;
struct ibv_mr *grh_mr;
uint16_t mcast_lid;
union ibv_gid mgid;
unsigned max_inline;
size_t max_eager;
int max_per_packet;
int pending_send;
int pending_recv;
struct ibv_mr *pp_mr;
char *pp_buf;
struct pp_packet *pp;
uint32_t psn;
uint32_t last_psn;
uint32_t racks_n;
uint32_t sacks_n;
uint32_t last_acked;
uint32_t naks_n;
uint32_t child_n;
uint32_t parent_n;
int buf_n;
struct packet p2p_pkt[MAX_COMM_POW2];
struct packet p2p_spkt[MAX_COMM_POW2];
ucc_list_link_t bpool;
ucc_list_link_t pending_q;
struct mcast_ctx mcast;
int reliable_in_progress;
int recv_drop_packet_in_progress;
struct ibv_recv_wr *call_rwr;
struct ibv_sge *call_rsgs;
uint64_t timer;
int stalled;
int comm_id;
void *p2p_ctx;
ucc_base_lib_t *lib;
struct sockaddr_in6 mcast_addr;
int parents[MAX_COMM_POW2];
int children[MAX_COMM_POW2];
int nack_requests;
int nacks_counter;
int n_prep_reliable;
int n_mcast_reliable;
int wsize;
struct pp_packet *r_window[1]; // do not add any new variable after here
} ucc_tl_mlx5_mcast_coll_comm_t;

typedef struct ucc_tl_mlx5_mcast_team {
void *mcast_comm;
ucc_tl_mlx5_mcast_context_t *mcast_context;
struct ucc_tl_mlx5_mcast_coll_comm *mcast_comm;
ucc_tl_mlx5_mcast_oob_ctx_t oob_ctx;
} ucc_tl_mlx5_mcast_team_t;

typedef struct ucc_tl_mlx5_mcast_coll_req { /* Stuff that has to happen per call */

typedef struct ucc_tl_mlx5_mcast_nack_req {
ucc_list_link_t super;
int pkt_id;
ucc_tl_mlx5_mcast_coll_comm_t *comm;
} ucc_tl_mlx5_mcast_nack_req_t;

#define PSN_IS_IN_RANGE(_psn, _call, _comm) \
( \
((_psn >= _call->start_psn) && \
(_psn < _call->start_psn + _call->num_packets) && \
(_psn >= _comm->last_acked) && \
(_psn < _comm->last_acked + _comm->wsize)) \
)

#define PSN_TO_RECV_OFFSET(_psn, _call, _comm) \
( \
((ptrdiff_t)((_psn - _call->start_psn) \
* (_comm->max_per_packet))) \
)

#define PSN_TO_RECV_LEN(_psn, _call, _comm) \
( \
((_psn - _call->start_psn + 1) % \
_call->num_packets == 0 ? _call->last_pkt_len : \
_comm->max_per_packet) \
)

#define PSN_RECEIVED(_psn, _comm) \
( \
(_comm->r_window[(_psn) % \
_comm->wsize]->psn == (_psn)) \
)

typedef struct ucc_tl_mlx5_mcast_coll_req {
ucc_tl_mlx5_mcast_coll_comm_t *comm;
size_t length; /* bcast buffer size */
int proto;
struct ibv_mr *mr;
struct ibv_recv_wr *rwr;
struct ibv_sge *rsgs;
void *rreg;
char *ptr;
int am_root;
ucc_rank_t root;
void **rbufs;
int first_send_psn;
int to_send;
int to_recv;
ucc_rank_t parent;
uint32_t start_psn;
int num_packets;
int last_pkt_len;
int offset;
ucc_memory_type_t buf_mem_type;
} ucc_tl_mlx5_mcast_coll_req_t;
MamziB marked this conversation as resolved.
Show resolved Hide resolved

typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
Expand All @@ -122,19 +298,59 @@ typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
ucc_subset_t subset;
} ucc_tl_mlx5_mcast_oob_p2p_context_t;

#define TASK_TEAM_MCAST(_task) \
(ucc_derived_of((_task)->super.team, ucc_tl_mlx5_mcast_team_t))
#define TASK_CTX_MCAST(_task) \
(ucc_derived_of((_task)->super.team->context, ucc_tl_mlx5_mcast_context_t))
#define TASK_LIB_MCAST(_task) \
(ucc_derived_of((_task)->super.team->context->lib, ucc_tl_mlx5_mcast_lib_t))
#define TASK_ARGS_MCAST(_task) (_task)->super.bargs.args

ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context,
ucc_tl_mlx5_mcast_team_t **mcast_team,
ucc_tl_mlx5_mcast_context_t *ctx,
static inline struct pp_packet* ucc_tl_mlx5_mcast_buf_get_free(ucc_tl_mlx5_mcast_coll_comm_t* comm)
{
struct pp_packet* pp;

pp = ucc_list_extract_head(&comm->bpool, struct pp_packet, super);
MamziB marked this conversation as resolved.
Show resolved Hide resolved

ucc_assert(pp == NULL || pp->context == 0);

return pp;
}

static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast_coll_comm_t* comm)
{
struct ibv_recv_wr *bad_wr = NULL;
struct ibv_recv_wr *rwr = comm->call_rwr;
struct ibv_sge *sge = comm->call_rsgs;
struct pp_packet *pp = NULL;
int count = comm->params.rx_depth - comm->pending_recv;
MamziB marked this conversation as resolved.
Show resolved Hide resolved
int i;

if (count <= comm->params.post_recv_thresh) {
return UCC_OK;
}

for (i = 0; i < count - 1; i++) {
if (NULL == (pp = ucc_tl_mlx5_mcast_buf_get_free(comm))) {
MamziB marked this conversation as resolved.
Show resolved Hide resolved
break;
}

rwr[i].wr_id = ((uint64_t) pp);
rwr[i].next = &rwr[i+1];
MamziB marked this conversation as resolved.
Show resolved Hide resolved
sge[2*i + 1].addr = pp->buf;

ucc_assert((uint64_t)comm->pp <= rwr[i].wr_id
&& ((uint64_t)comm->pp + comm->buf_n * sizeof(struct pp_packet)) > rwr[i].wr_id);
}
if (i != 0) {
rwr[i-1].next = NULL;
if (ibv_post_recv(comm->mcast.qp, &rwr[0], &bad_wr)) {
tl_error(comm->lib, "failed to prepost recvs: errno %d", errno);
return UCC_ERR_NO_RESOURCE;
MamziB marked this conversation as resolved.
Show resolved Hide resolved
}
comm->pending_recv += i;
}

return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context,
ucc_tl_mlx5_mcast_team_t **mcast_team,
ucc_tl_mlx5_mcast_context_t *ctx,
const ucc_base_team_params_t *params,
mcast_coll_comm_init_spec_t *mcast_conf);
ucc_tl_mlx5_mcast_coll_comm_init_spec_t *mcast_conf);

ucc_status_t ucc_tl_mlx5_mcast_coll_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
Expand Down
Loading
Loading