Skip to content

Commit

Permalink
TL/MLX5: MCAST HCA selection (openucx#942)
Browse files: browse the repository at this point in the history
  • Loading branch information
MamziB authored Mar 18, 2024
1 parent 0d68445 commit dbc6635
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_context {
struct ibv_pd *pd;
char *devname;
int max_qp_wr;
int user_provided_ib;
int ib_port;
int pkey_index;
int mtu;
Expand Down
47 changes: 31 additions & 16 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,11 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont
int is_ipv4 = 0;
struct sockaddr_in *in_src_addr = NULL;
struct rdma_cm_event *revent = NULL;
char *ib = NULL;
char *ib_name = NULL;
char *port = NULL;
int active_mtu = 4096;
int max_mtu = 4096;
ucc_tl_mlx5_mcast_coll_context_t *ctx = NULL;
char *ib_devname = NULL;
int devname_len = 0, ib_port = 1;
struct ibv_port_attr port_attr;
struct ibv_device_attr device_attr;
struct sockaddr_storage ip_oib_addr;
Expand All @@ -47,9 +46,9 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont
ucc_tl_mlx5_context_t *mlx5_ctx;
ucc_base_lib_t *lib;
int i;
int user_provided_ib;
int ib_valid;
const char *dst;
char tmp[128], *pos, *end_pos;

mlx5_ctx = ucc_container_of(context, ucc_tl_mlx5_context_t, mcast);
lib = mlx5_ctx->super.super.lib;
Expand Down Expand Up @@ -86,12 +85,34 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont
memset(ctx->devname, 0, strlen(devname)+3);
memcpy(ctx->devname, devname, strlen(devname));
strncat(ctx->devname, ":1", 3);
user_provided_ib = 0;
ctx->user_provided_ib = 0;
ctx->ib_port = 1;
} else {
ib_valid = 0;
/* user has provided the devname now make sure it is valid */
/* check if port number is also included and extract devname from user str */
ib_devname = mcast_ctx_conf->ib_dev_name;
pos = strstr(ib_devname, ":");
if (!pos) {
devname_len = sizeof(tmp) - 1;
} else {
devname_len = (int)(pos - ib_devname);
pos++;
errno = 0;
ib_port = (int)strtol(pos, &end_pos, 0);
if (errno != 0 || pos == end_pos || strcmp(end_pos,"\0") || ib_port < 0
|| ib_port > UINT8_MAX ) {
tl_warn(lib, "wrong device's port number");
return UCC_ERR_INVALID_PARAM;
}
}
ctx->ib_port = ib_port;
strncpy(tmp, ib_devname, devname_len);
tmp[devname_len] = '\0';
ib_devname = tmp;

for (i = 0; device_list[i]; ++i) {
if (!strcmp(ibv_get_device_name(device_list[i]), mcast_ctx_conf->ib_dev_name)) {
if (!strcmp(ibv_get_device_name(device_list[i]), ib_devname)) {
ib_valid = 1;
break;
}
Expand All @@ -102,16 +123,16 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont
ibv_free_device_list(device_list);
goto error;
}
ctx->devname = mcast_ctx_conf->ib_dev_name;
user_provided_ib = 1;
ctx->devname = mcast_ctx_conf->ib_dev_name;
ctx->user_provided_ib = 1;
}

ibv_free_device_list(device_list);

status = ucc_tl_mlx5_probe_ip_over_ib(ctx->devname, &ip_oib_addr);
if (UCC_OK != status) {
tl_debug(lib, "failed to get ipoib interface for devname %s", ctx->devname);
if (!user_provided_ib) {
if (!ctx->user_provided_ib) {
ucc_free(ctx->devname);
}
goto error;
Expand Down Expand Up @@ -179,19 +200,13 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont
goto error;
}

ib = strdup(ctx->devname);
ucc_string_split(ib, ":", 2, &ib_name, &port);
ctx->ib_port = atoi(port);
ucc_free(ib);

/* Determine MTU */
if (ibv_query_port(ctx->ctx, ctx->ib_port, &port_attr)) {
tl_error(lib, "couldn't query port in ctx create, errno %d", errno);
status = UCC_ERR_NO_RESOURCE;
goto error;
}


for (i = 0; i < UCC_TL_MLX5_MCAST_MAX_MTU_COUNT; i++) {
if (mtu_lookup[i][1] == port_attr.max_mtu) {
max_mtu = mtu_lookup[i][0];
Expand Down Expand Up @@ -292,7 +307,7 @@ ucc_status_t ucc_tl_mlx5_mcast_clean_ctx(ucc_tl_mlx5_mcast_coll_context_t *ctx)
ctx->channel = NULL;
}

if (ctx->devname && !strcmp(ctx->params.ib_dev_name, "")) {
if (ctx->devname && !ctx->user_provided_ib) {
ucc_free(ctx->devname);
ctx->devname = NULL;
}
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
4 changes: 3 additions & 1 deletion src/components/tl/mlx5/tl_mlx5_coll.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -16,6 +16,7 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
ucc_tl_mlx5_task_t *task = NULL;

if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
tl_trace(team->context->lib, "mcast bcast not supported for active sets");
return UCC_ERR_NOT_SUPPORTED;
}

Expand Down Expand Up @@ -51,6 +52,7 @@ ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task)
if (req != NULL) {
ucc_assert(coll_task->status != UCC_INPROGRESS);
ucc_free(req);
tl_trace(UCC_TASK_LIB(task), "finalizing an mcast task %p", task);
task->bcast_mcast.req_handle = NULL;
}

Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/mlx5/tl_mlx5_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team,
team_info.num_mem_types = 2;
team_info.supported_mem_types = mt;
team_info.supported_colls =
(UCC_COLL_TYPE_ALLTOALL * (team->a2a_status.local == UCC_OK)) |
UCC_COLL_TYPE_BCAST;
(UCC_COLL_TYPE_ALLTOALL * (team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY)) |
UCC_COLL_TYPE_BCAST * (team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY);
team_info.size = UCC_TL_TEAM_SIZE(team);

status = ucc_coll_score_build_default(
Expand Down

0 comments on commit dbc6635

Please sign in to comment.