Skip to content

Commit

Permalink
TL/MLX5: rcache and p2p lib for mcast (#826)
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB authored Sep 11, 2023
1 parent 8661b7f commit 745474a
Show file tree
Hide file tree
Showing 7 changed files with 420 additions and 10 deletions.
14 changes: 9 additions & 5 deletions src/components/tl/mlx5/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ alltoall = \
alltoall/alltoall_inline.h \
alltoall/alltoall_coll.c

mcast = \
mcast/tl_mlx5_mcast_context.c \
mcast/tl_mlx5_mcast.h \
mcast/tl_mlx5_mcast_coll.c \
mcast/tl_mlx5_mcast_coll.h \
mcast = \
mcast/tl_mlx5_mcast_context.c \
mcast/tl_mlx5_mcast.h \
mcast/tl_mlx5_mcast_coll.c \
mcast/tl_mlx5_mcast_coll.h \
mcast/tl_mlx5_mcast_rcache.h \
mcast/tl_mlx5_mcast_rcache.c \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \
mcast/tl_mlx5_mcast_team.c

sources = \
Expand Down
141 changes: 141 additions & 0 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "ucc_tl_mlx5_mcast_p2p.h"

static inline void ucc_tl_mlx5_mcast_p2p_completion_cb(void* context)
{
ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj =
(ucc_tl_mlx5_mcast_p2p_completion_obj_t *)context;

ucc_assert(obj != NULL && obj->compl_cb != NULL);

obj->compl_cb(obj);

ucc_assert(obj->req != NULL);

ucc_collective_finalize(obj->req);
}

void ucc_tl_mlx5_mcast_completion_cb(void* context, ucc_status_t status) //NOLINT
{
ucc_tl_mlx5_mcast_p2p_completion_cb(context);
}

static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
len, ucc_rank_t my_team_rank, ucc_rank_t dest,
ucc_team_h team, ucc_context_h ctx,
ucc_coll_callback_t *callback,
ucc_coll_req_h *p2p_req, int is_send)
{
ucc_status_t status = UCC_OK;
ucc_coll_req_h req = NULL;
ucc_coll_args_t args;

args.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET |
UCC_COLL_ARGS_FIELD_CB;
args.coll_type = UCC_COLL_TYPE_BCAST;
args.src.info.buffer = buf;
args.src.info.count = len;
args.src.info.datatype = UCC_DT_INT8;
args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
args.root = is_send ? my_team_rank : dest;
args.cb.cb = callback->cb;
args.cb.data = callback->data;
args.active_set.size = 2;
args.active_set.start = my_team_rank;
args.active_set.stride = dest - my_team_rank;

status = ucc_collective_init(&args, &req, team);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(ctx->lib, "nonblocking p2p init failed");
return status;
}

((ucc_tl_mlx5_mcast_p2p_completion_obj_t *)args.cb.data)->req = req;

status = ucc_collective_post(req);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(ctx->lib, "nonblocking p2p post failed");
return status;
}

*p2p_req = req;

return status;
}

static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
ucc_context_h ctx, ucc_coll_callback_t
*callback, ucc_coll_req_h *req)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest,
team, ctx, callback, req, 1);
}

static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
ucc_context_h ctx, ucc_coll_callback_t
*callback, ucc_coll_req_h *req)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest,
team, ctx, callback, req, 0);
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx =
(ucc_tl_mlx5_mcast_oob_p2p_context_t *)context;
ucc_status_t status = UCC_OK;
ucc_coll_req_h req = NULL;
ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank;
ucc_team_h team = oob_p2p_ctx->base_team;
ucc_context_h ctx = oob_p2p_ctx->base_ctx;
ucc_coll_callback_t callback;

callback.cb = ucc_tl_mlx5_mcast_completion_cb;
callback.data = obj;

status = do_send_nb(src, size, my_team_rank, rank, team, ctx, &callback, &req);

if (status < 0) {
tl_error(ctx->lib, "nonblocking p2p send failed");
return status;
}

return status;
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx =
(ucc_tl_mlx5_mcast_oob_p2p_context_t *)context;
ucc_status_t status = UCC_OK;
ucc_coll_req_h req = NULL;
ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank;
ucc_team_h team = oob_p2p_ctx->base_team;
ucc_context_h ctx = oob_p2p_ctx->base_ctx;
ucc_coll_callback_t callback;

callback.cb = ucc_tl_mlx5_mcast_completion_cb;
callback.data = obj;

status = do_recv_nb(dst, size, my_team_rank, rank, team, ctx, &callback, &req);

if (status < 0) {
tl_error(ctx->lib, "nonblocking p2p recv failed");
return status;
}

return status;
}
18 changes: 18 additions & 0 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include <ucc/api/ucc.h>
#include "components/tl/mlx5/mcast/tl_mlx5_mcast.h"

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* dst, size_t size, ucc_rank_t
rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);
82 changes: 78 additions & 4 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,39 @@

#define UCC_TL_MLX5_MCAST_ENABLE_BLOCKING true

struct ucc_tl_mlx5_mcast_p2p_completion_obj;
typedef int (*ucc_tl_mlx5_mcast_p2p_completion_cb_fn_t)(struct ucc_tl_mlx5_mcast_p2p_completion_obj *obj);
typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj {
ucc_list_link_t super;
ucc_tl_mlx5_mcast_p2p_completion_cb_fn_t compl_cb;
uint64_t data[3];
ucc_coll_req_h req;
} ucc_tl_mlx5_mcast_p2p_completion_obj_t;

typedef struct mcast_coll_comm_init_spec {
} mcast_coll_comm_init_spec_t;

typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg);

typedef int (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size,
ucc_rank_t rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);


typedef int (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size,
ucc_rank_t rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);

typedef struct ucc_tl_mlx5_mcast_context_config {
ucc_tl_context_config_t super;
char *dev_list;
int use_rcache;
size_t reg_threshold;
unsigned int rand_seed;
unsigned int uprogress_num_polls;
int context_per_team;
} ucc_tl_mlx5_mcast_context_config_t;

typedef struct ucc_tl_mlx5_mcast_lib {
} ucc_tl_mlx5_mcast_lib_t;
UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *,
Expand All @@ -31,12 +61,49 @@ UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *,
typedef struct ucc_tl_mlx5_mcast_ctx_params {
} ucc_tl_mlx5_mcast_ctx_params_t;

typedef struct mcast_coll_context_t {
} mcast_coll_context_t;

typedef struct ucc_tl_mlx5_mcast_context_t {
typedef struct ucc_tl_mlx5_mcast_coll_context {
struct ibv_context *ctx;
struct ibv_pd *pd;
char *devname;
int max_qp_wr;
int ib_port;
int pkey_index;
int mtu;
struct rdma_cm_id *id;
struct rdma_event_channel *channel;
ucc_mpool_t compl_objects_mp;
ucc_mpool_t nack_reqs_mp;
ucc_list_link_t pending_nacks_list;
ucc_rcache_t *rcache;
ucc_tl_mlx5_mcast_ctx_params_t params;
ucc_base_lib_t *lib;
} ucc_tl_mlx5_mcast_coll_context_t;

typedef struct ucc_tl_mlx5_mcast_oob_ctx {
void *ctx;
union {
ucc_oob_coll_t *oob;
ucc_subset_t subset;
};
} ucc_tl_mlx5_mcast_oob_ctx_t;

typedef struct ucc_tl_mlx5_mcast_context {
ucc_thread_mode_t tm;
ucc_tl_mlx5_mcast_coll_context_t mcast_context;
ucc_tl_mlx5_mcast_context_config_t cfg;
ucc_mpool_t req_mp;
ucc_tl_mlx5_mcast_oob_ctx_t oob_ctx;
} ucc_tl_mlx5_mcast_context_t;

typedef struct ucc_tl_mlx5_mcast_reg {
void *mr;
} ucc_tl_mlx5_mcast_reg_t;

typedef struct ucc_tl_mlx5_mcast_rcache_region {
ucc_rcache_region_t super;
ucc_tl_mlx5_mcast_reg_t reg;
} ucc_tl_mlx5_mcast_rcache_region_t;


typedef struct mcast_coll_comm { /* Stuff at a per-communicator sort of level */
} mcast_coll_comm_t;
Expand All @@ -48,6 +115,13 @@ typedef struct ucc_tl_mlx5_mcast_team {
typedef struct ucc_tl_mlx5_mcast_coll_req { /* Stuff that has to happen per call */
} ucc_tl_mlx5_mcast_coll_req_t;

typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
ucc_context_h base_ctx;
ucc_team_h base_team;
ucc_rank_t my_team_rank;
ucc_subset_t subset;
} ucc_tl_mlx5_mcast_oob_p2p_context_t;

#define TASK_TEAM_MCAST(_task) \
(ucc_derived_of((_task)->super.team, ucc_tl_mlx5_mcast_team_t))
#define TASK_CTX_MCAST(_task) \
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)

status = mcast_coll_do_bcast(buf, data_size, root, NULL, comm,
UCC_TL_MLX5_MCAST_ENABLE_BLOCKING, &task->bcast_mcast.req_handle);
if (UCC_OK != status && UCC_INPROGRESS != status) {
if (status < 0) {
tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status);
coll_task->status = status;
return ucc_task_complete(coll_task);
Expand Down
Loading

0 comments on commit 745474a

Please sign in to comment.