Skip to content

Commit

Permalink
Address reviewers' comments on PR 826 (second set)
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Aug 28, 2023
1 parent be2ad7f commit 26c4912
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 61 deletions.
18 changes: 9 additions & 9 deletions src/components/tl/mlx5/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ alltoall = \
alltoall/alltoall_inline.h \
alltoall/alltoall_coll.c

mcast = \
mcast/tl_mlx5_mcast_context.c \
mcast/tl_mlx5_mcast.h \
mcast/tl_mlx5_mcast_coll.c \
mcast/tl_mlx5_mcast_coll.h \
mcast/tl_mlx5_mcast_rcache.h \
mcast/tl_mlx5_mcast_rcache.c \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \
mcast = \
mcast/tl_mlx5_mcast_context.c \
mcast/tl_mlx5_mcast.h \
mcast/tl_mlx5_mcast_coll.c \
mcast/tl_mlx5_mcast_coll.h \
mcast/tl_mlx5_mcast_rcache.h \
mcast/tl_mlx5_mcast_rcache.c \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \
mcast/tl_mlx5_mcast_team.c

sources = \
Expand Down
54 changes: 35 additions & 19 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

static inline void ucc_tl_mlx5_mcast_p2p_completion_cb(void* context)
{
ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj = (ucc_tl_mlx5_mcast_p2p_completion_obj_t *)context;
ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj =
(ucc_tl_mlx5_mcast_p2p_completion_obj_t *)context;

ucc_assert(obj != NULL && obj->compl_cb != NULL);

Expand All @@ -24,16 +25,23 @@ void ucc_tl_mlx5_mcast_completion_cb(void* context, ucc_status_t status)
ucc_tl_mlx5_mcast_p2p_completion_cb(context);
}

static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t len, ucc_rank_t my_team_rank, int dest,
ucc_team_h team, ucc_context_h ctx,
ucc_coll_callback_t *callback,
ucc_coll_req_h *p2p_req, int is_send)
static inline ucc_status_t
ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf,
size_t len,
ucc_rank_t my_team_rank,
int dest,
ucc_team_h team,
ucc_context_h ctx,
ucc_coll_callback_t *callback,
ucc_coll_req_h *p2p_req,
int is_send)
{
ucc_status_t status = UCC_OK;
ucc_coll_req_h req = NULL;
ucc_coll_args_t args;

args.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_CB;
args.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET |
UCC_COLL_ARGS_FIELD_CB;
args.coll_type = UCC_COLL_TYPE_BCAST;
args.src.info.buffer = buf;
args.src.info.count = len;
Expand Down Expand Up @@ -65,17 +73,19 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t l
return status;
}

static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t my_team_rank,
int dest, ucc_team_h team, ucc_context_h ctx,
ucc_coll_callback_t *callback, ucc_coll_req_h *req)
/*
 * Send-side wrapper: forwards every argument unchanged to
 * ucc_tl_mlx5_mcast_do_p2p_bcast_nb() with the is_send flag set to 1 and
 * returns that call's status as-is.
 */
static inline ucc_status_t do_send_nb(void                *sbuf,
                                      size_t               len,
                                      ucc_rank_t           my_team_rank,
                                      int                  dest,
                                      ucc_team_h           team,
                                      ucc_context_h        ctx,
                                      ucc_coll_callback_t *callback,
                                      ucc_coll_req_h      *req)
{
    /* 1 selects the send direction in the shared helper (recv passes 0). */
    const int is_send = 1;

    return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest,
                                             team, ctx, callback, req,
                                             is_send);
}

static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t my_team_rank,
int dest, ucc_team_h team, ucc_context_h ctx,
ucc_coll_callback_t *callback, ucc_coll_req_h *req)
static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t
my_team_rank, int dest, ucc_team_h team,
ucc_context_h ctx, ucc_coll_callback_t
*callback, ucc_coll_req_h *req)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest,
team, ctx, callback, req, 0);
Expand All @@ -99,10 +109,13 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_progress(void *context)
return status;
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, int rank, int tag,
void *context, ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj)
ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
rank, int tag, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx = (ucc_tl_mlx5_mcast_oob_p2p_context_t *)context;
ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx =
(ucc_tl_mlx5_mcast_oob_p2p_context_t *)context;
ucc_status_t status = UCC_OK;
ucc_coll_req_h req = NULL;
ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank;
Expand All @@ -123,10 +136,13 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, int rank, int
return status;
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* src, size_t size, int rank, int tag,
void *context, ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj)
ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
rank, int tag, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx = (ucc_tl_mlx5_mcast_oob_p2p_context_t *)context;
ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx =
(ucc_tl_mlx5_mcast_oob_p2p_context_t *)context;
ucc_status_t status = UCC_OK;
ucc_coll_req_h req = NULL;
ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank;
Expand All @@ -137,7 +153,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* src, size_t size, int rank, int
callback.cb = ucc_tl_mlx5_mcast_completion_cb;
callback.data = obj;

status = do_recv_nb(src, size, my_team_rank, rank, team, ctx, &callback, &req);
status = do_recv_nb(dst, size, my_team_rank, rank, team, ctx, &callback, &req);

if (status != UCC_INPROGRESS && status != UCC_OK) {
tl_error(ctx->lib, "nonblocking p2p recv failed");
Expand Down
12 changes: 8 additions & 4 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
#include <ucc/api/ucc.h>
#include "components/tl/mlx5/mcast/tl_mlx5_mcast.h"

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, int rank, int tag, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj);
/**
 * Start a non-blocking point-to-point send.
 *
 * @param src     Buffer holding the data to send.
 * @param size    Length of @a src.
 * @param rank    Peer rank for the transfer.
 * @param tag     Message tag (its use is not visible in this view).
 * @param context Opaque pointer; the definition casts it to
 *                ucc_tl_mlx5_mcast_oob_p2p_context_t *.
 * @param obj     Completion object; presumably handed to the completion
 *                callback as in the recv path -- TODO confirm.
 *
 * @return Status of the operation as a ucc_status_t.
 *
 * NOTE(review): ucc_tl_mlx5_mcast_p2p_send_nb_fn_t in tl_mlx5_mcast.h still
 * declares `int rank` and an `int` return type; confirm the signatures are
 * compatible wherever this function is stored through that pointer type.
 */
ucc_status_t
ucc_tl_mlx5_mcast_p2p_send_nb(void                                   *src,
                              size_t                                  size,
                              ucc_rank_t                              rank,
                              int                                     tag,
                              void                                   *context,
                              ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj);

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* src, size_t size, int rank, int tag, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj);
/**
 * Start a non-blocking point-to-point receive.
 *
 * The definition installs ucc_tl_mlx5_mcast_completion_cb as the request
 * callback with @a obj as its data, then posts the receive into @a dst.
 *
 * @param dst     Buffer that receives the data.
 * @param size    Length of @a dst.
 * @param rank    Peer rank passed to the underlying receive.
 * @param tag     Message tag (its use is not visible in this view).
 * @param context Opaque pointer; the definition casts it to
 *                ucc_tl_mlx5_mcast_oob_p2p_context_t *.
 * @param obj     Completion object stored as the callback data.
 *
 * @return A ucc_status_t. The definition logs "nonblocking p2p recv failed"
 *         when the underlying call returns anything other than UCC_OK or
 *         UCC_INPROGRESS.
 */
ucc_status_t
ucc_tl_mlx5_mcast_p2p_recv_nb(void                                   *dst,
                              size_t                                  size,
                              ucc_rank_t                              rank,
                              int                                     tag,
                              void                                   *context,
                              ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj);

/**
 * Progress outstanding p2p operations.
 *
 * @param ctx Opaque context pointer (its concrete type is not visible in
 *            this view -- TODO confirm against the definition).
 *
 * @return Status of the progress call as a ucc_status_t.
 */
ucc_status_t ucc_tl_mlx5_mcast_p2p_progress(void *ctx);
14 changes: 8 additions & 6 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,22 @@ typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg);

typedef int (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size,
int rank, int tag, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t * compl_obj);
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);


typedef int (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size,
int rank, int tag, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t * compl_obj);
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);

typedef int (*ucc_tl_mlx5_mcast_p2p_send_fn_t)(void* src, size_t size,
int rank, int tag, void *context,
ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t wait_cb, void *wait_arg);
ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t wait_cb,
void *wait_arg);

typedef int (*ucc_tl_mlx5_mcast_p2p_recv_fn_t)(void* src, size_t size,
int rank, int tag, void *context,
ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t wait_cb, void *wait_arg);
ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t wait_cb,
void *wait_arg);

typedef int (*ucc_tl_mlx5_mcast_p2p_progress_fn_t)(void*);

Expand All @@ -71,7 +73,7 @@ UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *,
typedef struct ucc_tl_mlx5_mcast_ctx_params {
} ucc_tl_mlx5_mcast_ctx_params_t;

typedef struct ucc_tl_mlx5_mcast_coll_context_t {
typedef struct ucc_tl_mlx5_mcast_coll_context {
struct ibv_context *ctx;
struct ibv_pd *pd;
char *devname;
Expand All @@ -97,7 +99,7 @@ typedef struct ucc_tl_mlx5_mcast_oob_ctx {
};
} ucc_tl_mlx5_mcast_oob_ctx_t;

typedef struct ucc_tl_mlx5_mcast_context_t {
typedef struct ucc_tl_mlx5_mcast_context {
ucc_thread_mode_t tm;
ucc_tl_mlx5_mcast_coll_context_t mcast_context;
ucc_tl_mlx5_mcast_context_config_t cfg;
Expand Down
49 changes: 28 additions & 21 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,31 @@

#include "tl_mlx5_mcast_rcache.h"

static ucs_status_t ucc_tl_mlx5_mcast_coll_reg_mr(ucc_tl_mlx5_mcast_coll_context_t *ctx,
void *data, size_t data_size, void **mr)
static ucs_status_t ucc_tl_mlx5_mcast_coll_reg_mr(ucc_tl_mlx5_mcast_coll_context_t
*ctx, void *data, size_t data_size,
void **mr)
{
*mr = ibv_reg_mr(ctx->pd, data, data_size, IBV_ACCESS_LOCAL_WRITE |
IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE);

tl_debug(ctx->lib, "external memory register: ptr %p, len %zd, mr %p",
tl_trace(ctx->lib, "external memory register: ptr %p, len %zd, mr %p",
data, data_size, (*mr));
if (!*mr) {
tl_error(ctx->lib, "Failed to register MR\n");
tl_error(ctx->lib, "failed to register MR\n");
return UCS_ERR_NO_MEMORY;
}

return UCS_OK;
}


static ucc_status_t ucc_tl_mlx5_mcast_coll_dereg_mr(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *mr)
static ucc_status_t ucc_tl_mlx5_mcast_coll_dereg_mr(ucc_tl_mlx5_mcast_coll_context_t
*ctx, void *mr)
{
if(mr != NULL) {
tl_debug(ctx->lib, "external memory deregister: mr %p", mr);
if (ibv_dereg_mr(mr)) {
tl_error(ctx->lib, "Couldn't destroy mr %p", mr);
tl_error(ctx->lib, "couldn't destroy mr %p", mr);
return UCC_ERR_NO_RESOURCE;
}
} else {
Expand All @@ -52,8 +54,8 @@ ucc_tl_mlx5_mcast_rcache_mem_reg_cb(void *context, ucs_rcache_t *rcache,
length = (size_t)(rregion->super.end - rregion->super.start);
region = ucc_derived_of(rregion, ucc_tl_mlx5_mcast_rcache_region_t);

ret = ucc_tl_mlx5_mcast_coll_reg_mr((ucc_tl_mlx5_mcast_coll_context_t *)context, address,
length, &region->reg.mr);
ret = ucc_tl_mlx5_mcast_coll_reg_mr((ucc_tl_mlx5_mcast_coll_context_t *)context,
address, length, &region->reg.mr);
if (UCS_OK != ret) {
return ret;
}
Expand All @@ -62,17 +64,20 @@ ucc_tl_mlx5_mcast_rcache_mem_reg_cb(void *context, ucs_rcache_t *rcache,

}

static void ucc_tl_mlx5_mcast_rcache_mem_dereg_cb(void *context, ucc_rcache_t *rcache,
ucc_rcache_region_t *rregion)
static void ucc_tl_mlx5_mcast_rcache_mem_dereg_cb(void *context, ucc_rcache_t
*rcache, ucc_rcache_region_t *rregion)
{
ucc_tl_mlx5_mcast_rcache_region_t *region = ucc_derived_of(rregion,
ucc_tl_mlx5_mcast_rcache_region_t);

ucc_tl_mlx5_mcast_coll_dereg_mr((ucc_tl_mlx5_mcast_coll_context_t *)context, region->reg.mr);
ucc_tl_mlx5_mcast_coll_dereg_mr((ucc_tl_mlx5_mcast_coll_context_t *)context,
region->reg.mr);
}

static void ucc_tl_mlx5_mcast_rcache_dump_region_cb(void *context, ucc_rcache_t *rcache,
ucc_rcache_region_t *rregion, char *buf,
static void ucc_tl_mlx5_mcast_rcache_dump_region_cb(void *context,
ucc_rcache_t *rcache,
ucc_rcache_region_t *rregion,
char *buf,
size_t max)
{
ucc_tl_mlx5_mcast_rcache_region_t *region = ucc_derived_of(rregion,
Expand All @@ -82,8 +87,9 @@ static void ucc_tl_mlx5_mcast_rcache_dump_region_cb(void *context, ucc_rcache_t
}

ucc_status_t
ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *addr,
size_t length, ucc_tl_mlx5_mcast_reg_t **reg)
ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx,
void *addr, size_t length,
ucc_tl_mlx5_mcast_reg_t **reg)
{
ucc_rcache_region_t *rregion;
ucc_tl_mlx5_mcast_rcache_region_t *region;
Expand All @@ -94,8 +100,7 @@ ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *addr

ucc_assert(rcache != NULL);

status = ucc_rcache_get(rcache, (void *)addr, length, NULL,
&rregion);
status = ucc_rcache_get(rcache, (void *)addr, length, NULL, &rregion);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(ctx->lib, "ucc_rcache_get failed");
return status;
Expand All @@ -104,13 +109,14 @@ ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *addr
region = ucc_derived_of(rregion, ucc_tl_mlx5_mcast_rcache_region_t);
*reg = &region->reg;

tl_debug(ctx->lib, "memory register mr %p", (*reg)->mr);
tl_trace(ctx->lib, "memory register mr %p", (*reg)->mr);

return UCC_OK;
}

ucc_status_t
ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_tl_mlx5_mcast_reg_t *reg)
ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx,
ucc_tl_mlx5_mcast_reg_t *reg)
{
ucc_tl_mlx5_mcast_rcache_region_t *region;
ucc_rcache_t *rcache;
Expand All @@ -122,7 +128,7 @@ ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_tl_m
}

ucc_assert(rcache != NULL);
tl_debug(ctx->lib, "memory deregister mr %p", reg->mr);
tl_trace(ctx->lib, "memory deregister mr %p", reg->mr);
region = ucc_container_of(reg, ucc_tl_mlx5_mcast_rcache_region_t, reg);
ucc_rcache_region_put(rcache, &region->super);

Expand All @@ -145,7 +151,8 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_rcache(ucc_tl_mlx5_mcast_coll_context_t *ct
rcache_params.max_size = SIZE_MAX;
rcache_params.region_struct_size = sizeof(ucc_tl_mlx5_mcast_rcache_region_t);
rcache_params.max_alignment = ucc_get_page_size();
rcache_params.ucm_events = UCM_EVENT_VM_UNMAPPED | UCM_EVENT_MEM_TYPE_FREE;
rcache_params.ucm_events = UCM_EVENT_VM_UNMAPPED |
UCM_EVENT_MEM_TYPE_FREE;
rcache_params.context = ctx;
rcache_params.ops = &ucc_rcache_ops;
rcache_params.flags = 0;
Expand Down
5 changes: 3 additions & 2 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

/**
 * Set up the registration cache for a multicast collective context.
 *
 * The visible part of the definition fills an rcache parameter block:
 * SIZE_MAX as the maximum size, the page size as the maximum alignment,
 * UCM_EVENT_VM_UNMAPPED | UCM_EVENT_MEM_TYPE_FREE as the watched events and
 * @a ctx as the callback context. The creation call itself is outside this
 * view.
 *
 * @param ctx Multicast collective context the cache is set up for.
 *
 * @return Status of the setup as a ucc_status_t.
 */
ucc_status_t
ucc_tl_mlx5_mcast_setup_rcache(ucc_tl_mlx5_mcast_coll_context_t *ctx);

ucc_status_t ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *addr,
size_t length, ucc_tl_mlx5_mcast_reg_t **reg);
/**
 * Obtain a memory registration for [addr, addr + length) through the
 * registration cache.
 *
 * The definition calls ucc_rcache_get() and, on success, stores the address
 * of the cached region's registration in @a *reg.
 *
 * @param ctx    Multicast collective context (also used for logging).
 * @param addr   Start of the memory range.
 * @param length Length of the memory range.
 * @param reg    Out: receives a pointer to the registration embedded in the
 *               cached region.
 *
 * @return UCC_OK on success; otherwise the status returned by
 *         ucc_rcache_get(), after logging "ucc_rcache_get failed".
 */
ucc_status_t
ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx,
                               void                             *addr,
                               size_t                            length,
                               ucc_tl_mlx5_mcast_reg_t         **reg);

/**
 * Release a registration previously obtained through the registration
 * cache.
 *
 * The definition recovers the enclosing cache region from @a reg with
 * ucc_container_of() and hands it back via ucc_rcache_region_put().
 *
 * @param ctx Multicast collective context (also used for logging).
 * @param reg Registration to release; must be the one embedded in a cached
 *            region, since the region is derived from its address.
 *
 * @return Status as a ucc_status_t (the return paths are not visible in
 *         this view).
 */
ucc_status_t
ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx,
                                 ucc_tl_mlx5_mcast_reg_t          *reg);

0 comments on commit 26c4912

Please sign in to comment.