Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/MLX5: rcache and p2p lib for mcast #826

Merged
merged 1 commit into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/components/tl/mlx5/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ alltoall = \
alltoall/alltoall_inline.h \
alltoall/alltoall_coll.c

mcast = \
mcast/tl_mlx5_mcast_context.c \
mcast/tl_mlx5_mcast.h \
mcast/tl_mlx5_mcast_coll.c \
mcast/tl_mlx5_mcast_coll.h \
mcast = \
mcast/tl_mlx5_mcast_context.c \
mcast/tl_mlx5_mcast.h \
mcast/tl_mlx5_mcast_coll.c \
mcast/tl_mlx5_mcast_coll.h \
mcast/tl_mlx5_mcast_rcache.h \
mcast/tl_mlx5_mcast_rcache.c \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \
mcast/tl_mlx5_mcast_team.c

sources = \
Expand Down
141 changes: 141 additions & 0 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "ucc_tl_mlx5_mcast_p2p.h"

static inline void ucc_tl_mlx5_mcast_p2p_completion_cb(void* context)
{
ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj =
(ucc_tl_mlx5_mcast_p2p_completion_obj_t *)context;

ucc_assert(obj != NULL && obj->compl_cb != NULL);

obj->compl_cb(obj);

ucc_assert(obj->req != NULL);

ucc_collective_finalize(obj->req);
}

void ucc_tl_mlx5_mcast_completion_cb(void* context, ucc_status_t status) //NOLINT
{
ucc_tl_mlx5_mcast_p2p_completion_cb(context);
}

static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
len, ucc_rank_t my_team_rank, ucc_rank_t dest,
ucc_team_h team, ucc_context_h ctx,
ucc_coll_callback_t *callback,
ucc_coll_req_h *p2p_req, int is_send)
{
ucc_status_t status = UCC_OK;
ucc_coll_req_h req = NULL;
ucc_coll_args_t args;

args.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET |
UCC_COLL_ARGS_FIELD_CB;
args.coll_type = UCC_COLL_TYPE_BCAST;
args.src.info.buffer = buf;
args.src.info.count = len;
args.src.info.datatype = UCC_DT_INT8;
args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
args.root = is_send ? my_team_rank : dest;
args.cb.cb = callback->cb;
args.cb.data = callback->data;
args.active_set.size = 2;
args.active_set.start = my_team_rank;
args.active_set.stride = dest - my_team_rank;

status = ucc_collective_init(&args, &req, team);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(ctx->lib, "nonblocking p2p init failed");
return status;
}

((ucc_tl_mlx5_mcast_p2p_completion_obj_t *)args.cb.data)->req = req;

status = ucc_collective_post(req);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(ctx->lib, "nonblocking p2p post failed");
return status;
}

*p2p_req = req;

return status;
}

static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
ucc_context_h ctx, ucc_coll_callback_t
*callback, ucc_coll_req_h *req)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest,
team, ctx, callback, req, 1);
}

static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
ucc_context_h ctx, ucc_coll_callback_t
*callback, ucc_coll_req_h *req)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest,
team, ctx, callback, req, 0);
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx =
(ucc_tl_mlx5_mcast_oob_p2p_context_t *)context;
ucc_status_t status = UCC_OK;
ucc_coll_req_h req = NULL;
ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank;
ucc_team_h team = oob_p2p_ctx->base_team;
ucc_context_h ctx = oob_p2p_ctx->base_ctx;
ucc_coll_callback_t callback;

callback.cb = ucc_tl_mlx5_mcast_completion_cb;
callback.data = obj;

status = do_send_nb(src, size, my_team_rank, rank, team, ctx, &callback, &req);

if (status < 0) {
tl_error(ctx->lib, "nonblocking p2p send failed");
return status;
}

return status;
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx =
(ucc_tl_mlx5_mcast_oob_p2p_context_t *)context;
ucc_status_t status = UCC_OK;
ucc_coll_req_h req = NULL;
ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank;
ucc_team_h team = oob_p2p_ctx->base_team;
ucc_context_h ctx = oob_p2p_ctx->base_ctx;
ucc_coll_callback_t callback;

callback.cb = ucc_tl_mlx5_mcast_completion_cb;
callback.data = obj;

status = do_recv_nb(dst, size, my_team_rank, rank, team, ctx, &callback, &req);

if (status < 0) {
tl_error(ctx->lib, "nonblocking p2p recv failed");
return status;
}

return status;
}
18 changes: 18 additions & 0 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include <ucc/api/ucc.h>
#include "components/tl/mlx5/mcast/tl_mlx5_mcast.h"

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* dst, size_t size, ucc_rank_t
rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);
82 changes: 78 additions & 4 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,39 @@

#define UCC_TL_MLX5_MCAST_ENABLE_BLOCKING true

struct ucc_tl_mlx5_mcast_p2p_completion_obj;
typedef int (*ucc_tl_mlx5_mcast_p2p_completion_cb_fn_t)(struct ucc_tl_mlx5_mcast_p2p_completion_obj *obj);
typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj {
ucc_list_link_t super;
ucc_tl_mlx5_mcast_p2p_completion_cb_fn_t compl_cb;
uint64_t data[3];
ucc_coll_req_h req;
} ucc_tl_mlx5_mcast_p2p_completion_obj_t;

typedef struct mcast_coll_comm_init_spec {
} mcast_coll_comm_init_spec_t;

typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg);

typedef int (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size,
ucc_rank_t rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);


typedef int (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size,
ucc_rank_t rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);

typedef struct ucc_tl_mlx5_mcast_context_config {
ucc_tl_context_config_t super;
char *dev_list;
int use_rcache;
size_t reg_threshold;
unsigned int rand_seed;
unsigned int uprogress_num_polls;
int context_per_team;
} ucc_tl_mlx5_mcast_context_config_t;

typedef struct ucc_tl_mlx5_mcast_lib {
} ucc_tl_mlx5_mcast_lib_t;
UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *,
Expand All @@ -31,12 +61,49 @@ UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *,
typedef struct ucc_tl_mlx5_mcast_ctx_params {
} ucc_tl_mlx5_mcast_ctx_params_t;

typedef struct mcast_coll_context_t {
} mcast_coll_context_t;

typedef struct ucc_tl_mlx5_mcast_context_t {
typedef struct ucc_tl_mlx5_mcast_coll_context {
struct ibv_context *ctx;
struct ibv_pd *pd;
char *devname;
int max_qp_wr;
int ib_port;
int pkey_index;
int mtu;
struct rdma_cm_id *id;
struct rdma_event_channel *channel;
ucc_mpool_t compl_objects_mp;
ucc_mpool_t nack_reqs_mp;
ucc_list_link_t pending_nacks_list;
ucc_rcache_t *rcache;
ucc_tl_mlx5_mcast_ctx_params_t params;
ucc_base_lib_t *lib;
} ucc_tl_mlx5_mcast_coll_context_t;

typedef struct ucc_tl_mlx5_mcast_oob_ctx {
void *ctx;
union {
ucc_oob_coll_t *oob;
ucc_subset_t subset;
};
} ucc_tl_mlx5_mcast_oob_ctx_t;

typedef struct ucc_tl_mlx5_mcast_context {
ucc_thread_mode_t tm;
ucc_tl_mlx5_mcast_coll_context_t mcast_context;
ucc_tl_mlx5_mcast_context_config_t cfg;
ucc_mpool_t req_mp;
ucc_tl_mlx5_mcast_oob_ctx_t oob_ctx;
} ucc_tl_mlx5_mcast_context_t;

typedef struct ucc_tl_mlx5_mcast_reg {
void *mr;
} ucc_tl_mlx5_mcast_reg_t;

typedef struct ucc_tl_mlx5_mcast_rcache_region {
ucc_rcache_region_t super;
ucc_tl_mlx5_mcast_reg_t reg;
} ucc_tl_mlx5_mcast_rcache_region_t;


typedef struct mcast_coll_comm { /* Stuff at a per-communicator sort of level */
} mcast_coll_comm_t;
Expand All @@ -48,6 +115,13 @@ typedef struct ucc_tl_mlx5_mcast_team {
typedef struct ucc_tl_mlx5_mcast_coll_req { /* Stuff that has to happen per call */
} ucc_tl_mlx5_mcast_coll_req_t;

typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
ucc_context_h base_ctx;
ucc_team_h base_team;
ucc_rank_t my_team_rank;
ucc_subset_t subset;
} ucc_tl_mlx5_mcast_oob_p2p_context_t;

#define TASK_TEAM_MCAST(_task) \
(ucc_derived_of((_task)->super.team, ucc_tl_mlx5_mcast_team_t))
#define TASK_CTX_MCAST(_task) \
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)

status = mcast_coll_do_bcast(buf, data_size, root, NULL, comm,
UCC_TL_MLX5_MCAST_ENABLE_BLOCKING, &task->bcast_mcast.req_handle);
if (UCC_OK != status && UCC_INPROGRESS != status) {
if (status < 0) {
tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status);
coll_task->status = status;
return ucc_task_complete(coll_task);
Expand Down
Loading
Loading