TL/MLX5: revise team and ctx init (#815)
* TL/MLX5: revise team and ctx init (intended flow sketched after this list)

* TL/MLX5: change setup/cleanup error log to debug

* TL/MLX5: minor reviews

* CODESTYLE: clang-tidy

* TL/MLX5: alloc dm atomics before mpool

* CODESTYLE: clang-tidy

* TL/MLX5: fix bug with socket

* REVIEW: minor comments

* TL/MLX5: disable coverity issue

* REVIEW: minor comments

* TL/MLX5: fix socket closing

* TL/MLX5: score map if a2a not avail

* CODESTYLE: clang format
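
The bullets above describe splitting a2a setup into a non-fatal local init plus start/progress steps driven by a team-wide status. Below is a minimal sketch of the expected call order, based only on the functions added in this diff (ucc_tl_mlx5_team_init_alltoall, ucc_tl_mlx5_team_test_alltoall_start, ucc_tl_mlx5_team_test_alltoall_progress); the wrapper name team_create_sketch and the step that combines local statuses into a2a_status.global are assumptions, not code from this commit.

    #include "alltoall.h"

    /* Illustrative only: how the new entry points are expected to be driven
     * by the team-create machinery (which is outside the files shown here). */
    static ucc_status_t team_create_sketch(ucc_tl_mlx5_team_t *team)
    {
        ucc_status_t status;

        /* 1. Local, non-fatal setup: records UCC_OK or UCC_ERR_NOT_SUPPORTED
         *    in team->a2a_status.local; only hard failures such as failed
         *    allocations are returned directly. */
        status = ucc_tl_mlx5_team_init_alltoall(team);
        if (UCC_OK != status) {
            return status;
        }

        /* 2. (Assumed, not shown in this commit) the local statuses of all
         *    ranks are combined into team->a2a_status.global. */

        /* 3. Start: on a bad global status the a2a resources are released and
         *    the error is returned, so the team can keep working with a score
         *    map that simply excludes alltoall; otherwise the shmid bcast is
         *    posted. */
        status = ucc_tl_mlx5_team_test_alltoall_start(team);
        if (UCC_OK != status) {
            return status;
        }

        /* 4. Drive the setup state machine (shmid bcast, device memory, QPs,
         *    net exchange) until it stops reporting UCC_INPROGRESS. */
        do {
            status = ucc_tl_mlx5_team_test_alltoall_progress(team);
        } while (UCC_INPROGRESS == status);

        return status;
    }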
samnordmann authored Aug 22, 2023
1 parent b270265 commit cd7175c
Showing 9 changed files with 342 additions and 299 deletions.
207 changes: 85 additions & 122 deletions src/components/tl/mlx5/alltoall/alltoall.c
@@ -65,20 +65,17 @@ static ucc_status_t build_rank_map(ucc_tl_mlx5_alltoall_t *a2a,
return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_team_alltoall_init_start(ucc_tl_mlx5_team_t *team)
ucc_status_t ucc_tl_mlx5_team_init_alltoall(ucc_tl_mlx5_team_t *team)
{
ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(team);
ucc_tl_mlx5_alltoall_t *a2a = NULL;
ucc_tl_mlx5_alltoall_t *a2a;
ucc_sbgp_t *node, *net;
size_t storage_size;
int i, j, node_size, ppn, team_size, nnodes;
ucc_topo_t *topo;
ucc_status_t status;

a2a = ucc_calloc(1, sizeof(*a2a), "mlx5_a2a");
if (!a2a) {
return UCC_ERR_NO_MEMORY;
}
team->a2a = NULL;
team->dm_ptr = NULL;
team->a2a_status.local = UCC_OK;

topo = team->topo;
node = ucc_topo_get_sbgp(topo, UCC_SBGP_NODE);
@@ -92,27 +89,22 @@ ucc_status_t ucc_tl_mlx5_team_alltoall_init_start(ucc_tl_mlx5_team_t *team)
"disabling mlx5 a2a for team with non-uniform ppn, "
"min_ppn %d, max_ppn %d",
ucc_topo_min_ppn(topo), ucc_topo_max_ppn(topo));
status = UCC_ERR_NOT_SUPPORTED;
goto err;
goto non_fatal_error;
}
ppn = ucc_topo_max_ppn(topo);

if (net->status == UCC_SBGP_NOT_EXISTS) {
tl_debug(ctx->super.super.lib,
"disabling mlx5 a2a for single node team");
status = UCC_ERR_NOT_SUPPORTED;
goto err;
goto non_fatal_error;
}

if (nnodes == team_size) {
tl_debug(ctx->super.super.lib,
"disabling mlx5 a2a for ppn=1 case, not supported so far");
status = UCC_ERR_NOT_SUPPORTED;
goto err;
goto non_fatal_error;
}

a2a->node_size = node_size;
ucc_assert(team_size == ppn * nnodes);

for (i = 0; i < nnodes; i++) {
for (j = 1; j < ppn; j++) {
@@ -121,38 +113,75 @@ ucc_status_t ucc_tl_mlx5_team_alltoall_init_start(ucc_tl_mlx5_team_t *team)
tl_debug(ctx->super.super.lib,
"disabling mlx5 a2a for team with non contiguous "
"ranks-per-node placement");
status = UCC_ERR_NOT_SUPPORTED;
goto err;
goto non_fatal_error;
}
}
}

a2a->pd = ctx->shared_pd;
a2a->ctx = ctx->shared_ctx;
a2a->ib_port = ctx->ib_port;
a2a->node.sbgp = node;
a2a->net.sbgp = net;
a2a->node.asr_rank = MLX5_ASR_RANK;
a2a->num_dci_qps = UCC_TL_MLX5_TEAM_LIB(team)->cfg.num_dci_qps;
a2a->sequence_number = 1;
a2a->net.ctrl_mr = NULL;
a2a->net.remote_ctrl = NULL;
a2a->net.rank_map = NULL;
a2a->max_msg_size = MAX_MSG_SIZE;
a2a->max_num_of_columns =
team->a2a = ucc_calloc(1, sizeof(*team->a2a), "mlx5_a2a");
if (!team->a2a) {
return UCC_ERR_NO_MEMORY;
}

a2a = team->a2a;
a2a->node_size = node_size;
a2a->pd = ctx->shared_pd;
a2a->ctx = ctx->shared_ctx;
a2a->ib_port = ctx->ib_port;
a2a->node.sbgp = node;
a2a->net.sbgp = net;
a2a->node.asr_rank = MLX5_ASR_RANK;
a2a->num_dci_qps = UCC_TL_MLX5_TEAM_LIB(team)->cfg.num_dci_qps;
a2a->sequence_number = 1;
a2a->net.atomic.counters = NULL;
a2a->net.ctrl_mr = NULL;
a2a->net.remote_ctrl = NULL;
a2a->net.rank_map = NULL;
a2a->max_msg_size = MAX_MSG_SIZE;
a2a->max_num_of_columns =
ucc_div_round_up(node->group_size, 2 /* todo: there can be an estimation of
minimal possible block size */);

ucc_assert(a2a->net.sbgp->status == UCC_SBGP_ENABLED ||
node->group_rank != 0);

if (a2a->node.asr_rank == node->group_rank) {
team->a2a_status.local = ucc_tl_mlx5_dm_init(team);
if (UCC_OK != team->a2a_status.local) {
tl_debug(UCC_TL_TEAM_LIB(team), "failed to init device memory");
}
}

return UCC_OK;

non_fatal_error:
team->a2a_status.local = UCC_ERR_NOT_SUPPORTED;
return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_team_test_alltoall_start(ucc_tl_mlx5_team_t *team)
{
ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(team);
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
size_t storage_size;

if (team->a2a_status.global != UCC_OK) {
tl_debug(ctx->super.super.lib, "global status in error state: %s",
ucc_status_string(team->a2a_status.global));

ucc_tl_mlx5_dm_cleanup(team);
if (a2a) {
ucc_free(a2a);
team->a2a = NULL;
}
ucc_tl_mlx5_topo_cleanup(team);
return team->a2a_status.global;
}

if (a2a->node.asr_rank == a2a->node.sbgp->group_rank) {
a2a->net.net_size = a2a->net.sbgp->group_size;
storage_size = OP_SEGMENT_SIZE(a2a) * MAX_OUTSTANDING_OPS;
a2a->bcast_data.shmid =
shmget(IPC_PRIVATE, storage_size, IPC_CREAT | 0600);
if (a2a->bcast_data.shmid == -1) {
tl_error(ctx->super.super.lib,
tl_debug(ctx->super.super.lib,
"failed to allocate sysv shm segment for %zd bytes",
storage_size);
} else {
@@ -166,72 +195,10 @@ ucc_status_t ucc_tl_mlx5_team_alltoall_init_start(ucc_tl_mlx5_team_t *team)
a2a->state = TL_MLX5_ALLTOALL_STATE_SHMID;

team->a2a = a2a;
return ucc_service_bcast(UCC_TL_CORE_TEAM(team), &a2a->bcast_data,
sizeof(ucc_tl_mlx5_a2a_bcast_data_t),
a2a->node.asr_rank, ucc_sbgp_to_subset(node),
&team->scoll_req);
err:
if (a2a) {
ucc_free(a2a);
}
return status;
}

static void ucc_tl_mlx5_alltoall_atomic_free(ucc_tl_mlx5_alltoall_t *a2a)
{
ibv_dereg_mr(a2a->net.atomic.mr);
#if ATOMIC_IN_MEMIC
ibv_free_dm(a2a->net.atomic.counters);
#else
ucc_free(a2a->net.atomic.counters);
#endif
}

static ucc_status_t ucc_tl_mlx5_alltoall_atomic_alloc(ucc_tl_mlx5_team_t *team)
{
ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(team);
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
size_t size;

size = sizeof(*a2a->net.atomic.counters) * MAX_OUTSTANDING_OPS;
#if ATOMIC_IN_MEMIC
struct ibv_alloc_dm_attr dm_attr;
memset(&dm_attr, 0, sizeof(dm_attr));
dm_attr.length = size;
a2a->net.atomic.counters = ibv_alloc_dm(ctx->shared_ctx, &dm_attr);
#else
a2a->net.atomic.counters = ucc_malloc(size, "atomic");
#endif

if (!a2a->net.atomic.counters) {
tl_debug(UCC_TL_TEAM_LIB(team),
"failed to allocate %zd bytes for atomic counters array",
size);
return UCC_ERR_NO_MEMORY;
}
#if ATOMIC_IN_MEMIC
a2a->net.atomic.mr =
ibv_reg_dm_mr(ctx->shared_pd, a2a->net.atomic.counters, 0, size,
IBV_ACCESS_REMOTE_ATOMIC | IBV_ACCESS_LOCAL_WRITE |
IBV_ACCESS_ZERO_BASED);

#else
a2a->net.atomic.mr =
ibv_reg_mr(ctx->shared_pd, a2a->net.atomic.counters, size,
IBV_ACCESS_REMOTE_ATOMIC | IBV_ACCESS_LOCAL_WRITE);
#endif

if (!a2a->net.atomic.mr) {
tl_error(UCC_TL_TEAM_LIB(team),
"failed to register atomic couters array");
#if ATOMIC_IN_MEMIC
ibv_free_dm(a2a->net.atomic.counters);
#else
ucc_free(a2a->net.atomic.counters);
#endif
return UCC_ERR_NO_MESSAGE;
}
return UCC_OK;
return ucc_service_bcast(
UCC_TL_CORE_TEAM(team), &a2a->bcast_data,
sizeof(ucc_tl_mlx5_a2a_bcast_data_t), a2a->node.asr_rank,
ucc_sbgp_to_subset(a2a->node.sbgp), &team->scoll_req);
}

static void ucc_tl_mlx5_alltoall_barrier_free(ucc_tl_mlx5_alltoall_t *a2a)
@@ -270,26 +237,29 @@ static ucc_status_t ucc_tl_mlx5_alltoall_barrier_alloc(ucc_tl_mlx5_team_t *team)
return UCC_OK;
}

ucc_status_t
ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *tl_team)
ucc_status_t ucc_tl_mlx5_team_test_alltoall_progress(ucc_tl_mlx5_team_t *team)
{
ucc_tl_mlx5_team_t *team = ucc_derived_of(tl_team,
ucc_tl_mlx5_team_t);
ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(team);
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
ucc_rank_t node_size = a2a->node.sbgp->group_size;
ucc_rank_t node_rank = a2a->node.sbgp->group_rank;
ucc_base_lib_t *lib = UCC_TL_TEAM_LIB(team);
size_t op_seg_size = OP_SEGMENT_SIZE(a2a);
int i = 0;
net_exchange_t *local_data = NULL;
ucc_rank_t node_size, node_rank;
ucc_status_t status;
ucc_tl_mlx5_alltoall_op_t *op;
int j, asr_cq_size, net_size, ret;
struct ibv_port_attr port_attr;
size_t local_data_size, umr_buf_size;
size_t op_seg_size, local_data_size, umr_buf_size;
net_exchange_t *global_data, *remote_data;

if (team->a2a_status.local < 0) {
return team->a2a_status.local;
}

node_size = a2a->node.sbgp->group_size;
node_rank = a2a->node.sbgp->group_rank;
op_seg_size = OP_SEGMENT_SIZE(a2a);

switch (a2a->state) {
case TL_MLX5_ALLTOALL_STATE_SHMID:
status = ucc_service_coll_test(team->scoll_req);
Expand Down Expand Up @@ -335,11 +305,6 @@ ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *tl_team)
return UCC_OK;
}

status = ucc_tl_mlx5_alltoall_atomic_alloc(team);
if (UCC_OK != status) {
goto err_atomic;
}

status = ucc_tl_mlx5_alltoall_barrier_alloc(team);
if (UCC_OK != status) {
goto err_barrier;
@@ -518,12 +483,12 @@ ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *tl_team)
a2a->state = TL_MLX5_ALLTOALL_STATE_EXCHANGE_PROGRESS;

case TL_MLX5_ALLTOALL_STATE_EXCHANGE_PROGRESS:
status = ucc_service_coll_test(tl_team->scoll_req);
status = ucc_service_coll_test(team->scoll_req);
if (status < 0) {
tl_error(UCC_TL_TEAM_LIB(tl_team),
tl_error(UCC_TL_TEAM_LIB(team),
"failure during service coll exchange: %s",
ucc_status_string(status));
ucc_service_coll_finalize(tl_team->scoll_req);
ucc_service_coll_finalize(team->scoll_req);
goto err_service_allgather_progress;
}
if (UCC_INPROGRESS == status) {
@@ -534,7 +499,7 @@ ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *tl_team)

case TL_MLX5_ALLTOALL_STATE_EXCHANGE_DONE:
local_data = team->scoll_req->data;
ucc_service_coll_finalize(tl_team->scoll_req);
ucc_service_coll_finalize(team->scoll_req);

net_size = a2a->net.net_size;
local_data_size = sizeof(net_exchange_t);
@@ -691,8 +656,6 @@ ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *tl_team)
err_blocks_sent:
ucc_tl_mlx5_alltoall_barrier_free(a2a);
err_barrier:
ucc_tl_mlx5_alltoall_atomic_free(a2a);
err_atomic:
return status;
}

@@ -721,6 +684,7 @@ void ucc_tl_mlx5_alltoall_cleanup(ucc_tl_mlx5_team_t *team)
for (i = 0; i < a2a->num_dci_qps; i++) {
ibv_destroy_qp(a2a->net.dcis[i].dci_qp);
}
ucc_free(a2a->net.dcis);
ibv_destroy_qp(a2a->net.dct_qp);
ibv_destroy_srq(a2a->net.srq);
for (i = 0; i < a2a->net.net_size; i++) {
@@ -753,7 +717,6 @@ void ucc_tl_mlx5_alltoall_cleanup(ucc_tl_mlx5_team_t *team)

ucc_free(a2a->net.blocks_sent);
ucc_tl_mlx5_alltoall_barrier_free(a2a);
ucc_tl_mlx5_alltoall_atomic_free(a2a);
}
ucc_free(a2a->net.dcis);
ucc_free(a2a);
}
7 changes: 5 additions & 2 deletions src/components/tl/mlx5/alltoall/alltoall.h
@@ -9,6 +9,7 @@

#include "tl_mlx5.h"
#include "tl_mlx5_ib.h"
#include "tl_mlx5_dm.h"

#define SEQ_INDEX(_seq_num) ((_seq_num) % MAX_OUTSTANDING_OPS)

@@ -136,8 +137,10 @@ typedef struct ucc_tl_mlx5_alltoall {
ucc_tl_mlx5_a2a_bcast_data_t bcast_data;
} ucc_tl_mlx5_alltoall_t;

ucc_status_t ucc_tl_mlx5_team_alltoall_init_start(ucc_tl_mlx5_team_t *team);
ucc_status_t ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *team);
void ucc_tl_mlx5_topo_cleanup(ucc_tl_mlx5_team_t *team);
ucc_status_t ucc_tl_mlx5_team_init_alltoall(ucc_tl_mlx5_team_t *team);
ucc_status_t ucc_tl_mlx5_team_test_alltoall_start(ucc_tl_mlx5_team_t *team);
ucc_status_t ucc_tl_mlx5_team_test_alltoall_progress(ucc_tl_mlx5_team_t *team);
ucc_status_t ucc_tl_mlx5_alltoall_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);
8 changes: 7 additions & 1 deletion src/components/tl/mlx5/tl_mlx5.h
@@ -84,6 +84,7 @@ typedef struct ucc_tl_mlx5_context {
ucc_rcache_t *rcache;
int is_imported;
int ib_port;
int sock;
ucc_mpool_t req_mp;
ucc_tl_mlx5_mcast_context_t mcast;
} ucc_tl_mlx5_context_t;
@@ -108,15 +109,20 @@ typedef enum
TL_MLX5_TEAM_STATE_ALLTOALL_POSTED
} ucc_tl_mlx5_team_state_t;

typedef struct ucc_tl_mlx5_team_status {
ucc_status_t local;
ucc_status_t global;
} ucc_tl_mlx5_team_status_t;

typedef struct ucc_tl_mlx5_team {
ucc_tl_team_t super;
ucc_status_t status[2];
ucc_service_coll_req_t *scoll_req;
ucc_tl_mlx5_team_state_t state;
void *dm_offset;
ucc_mpool_t dm_pool;
struct ibv_dm *dm_ptr;
struct ibv_mr *dm_mr;
ucc_tl_mlx5_team_status_t a2a_status;
ucc_tl_mlx5_alltoall_t *a2a;
ucc_topo_t *topo;
ucc_ep_map_t ctx_map;
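The new ucc_tl_mlx5_team_status_t above keeps the rank-local setup result (a2a_status.local, set by ucc_tl_mlx5_team_init_alltoall) next to the team-wide result checked by ucc_tl_mlx5_team_test_alltoall_start. A minimal sketch of how the gathered local values could be folded into a2a_status.global follows; the helper and the exchange that collects the values are assumptions, not part of this commit.

    /* Hypothetical helper: derive the team-wide a2a status from the local
     * statuses gathered from every rank. ucc_status_t error codes are
     * negative, so taking the minimum keeps the worst result; one rank
     * without a2a support (UCC_ERR_NOT_SUPPORTED) disables it team-wide. */
    static ucc_status_t a2a_reduce_status(const ucc_status_t *locals, int n)
    {
        ucc_status_t global = UCC_OK;
        int          i;

        for (i = 0; i < n; i++) {
            if (locals[i] < global) {
                global = locals[i];
            }
        }
        return global;
    }
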
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_coll.c
@@ -31,7 +31,7 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
if (ucc_unlikely(UCC_OK != status)) {
goto free_task;
}

*task_h = &(task->super);

tl_debug(UCC_TASK_LIB(task), "init coll task %p", task);
