diff --git a/src/components/tl/mlx5/Makefile.am b/src/components/tl/mlx5/Makefile.am index b406feac41..45819da272 100644 --- a/src/components/tl/mlx5/Makefile.am +++ b/src/components/tl/mlx5/Makefile.am @@ -17,7 +17,9 @@ sources = \ tl_mlx5_wqe.c \ tl_mlx5_pd.h \ tl_mlx5_pd.c \ - tl_mlx5_rcache.c + tl_mlx5_rcache.c \ + tl_mlx5_dm.c \ + tl_mlx5_dm.h module_LTLIBRARIES = libucc_tl_mlx5.la libucc_tl_mlx5_la_SOURCES = $(sources) diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 50a30304f4..8c51990289 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -84,15 +84,33 @@ typedef struct ucc_tl_mlx5_context { UCC_CLASS_DECLARE(ucc_tl_mlx5_context_t, const ucc_base_context_params_t *, const ucc_base_config_t *); +typedef struct ucc_tl_mlx5_schedule ucc_tl_mlx5_schedule_t; +typedef struct ucc_tl_mlx5_dm_chunk { + ptrdiff_t offset; /* 0 based offset from the beginning of + memic_mr (obtained with ibv_reg_dm_mr) */ + ucc_tl_mlx5_schedule_t *task; +} ucc_tl_mlx5_dm_chunk_t; + typedef struct ucc_tl_mlx5_a2a ucc_tl_mlx5_a2a_t; + +typedef enum +{ + TL_MLX5_TEAM_STATE_INIT, + TL_MLX5_TEAM_STATE_POSTED, +} ucc_tl_mlx5_team_state_t; + typedef struct ucc_tl_mlx5_team { - ucc_tl_team_t super; - ucc_service_coll_req_t *scoll_req; - void * oob_req; - ucc_mpool_t dm_pool; - struct ibv_dm * dm_ptr; - struct ibv_mr * dm_mr; - ucc_tl_mlx5_a2a_t * a2a; + ucc_tl_team_t super; + ucc_status_t status[2]; + ucc_service_coll_req_t *scoll_req; + ucc_tl_mlx5_team_state_t state; + void *dm_offset; + ucc_mpool_t dm_pool; + struct ibv_dm *dm_ptr; + struct ibv_mr *dm_mr; + ucc_tl_mlx5_a2a_t *a2a; + ucc_topo_t *topo; + ucc_ep_map_t ctx_map; } ucc_tl_mlx5_team_t; UCC_CLASS_DECLARE(ucc_tl_mlx5_team_t, ucc_base_context_t *, const ucc_base_team_params_t *); diff --git a/src/components/tl/mlx5/tl_mlx5_dm.c b/src/components/tl/mlx5/tl_mlx5_dm.c new file mode 100644 index 0000000000..a311e9a4a7 --- /dev/null +++ b/src/components/tl/mlx5/tl_mlx5_dm.c @@ -0,0 +1,172 @@ +/** + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "tl_mlx5_dm.h" + +#define DM_HOST_AUTO_NUM_CHUNKS 8 + +static void ucc_tl_mlx5_dm_chunk_init(ucc_mpool_t *mp, //NOLINT + void *obj, void *chunk) //NOLINT +{ + ucc_tl_mlx5_dm_chunk_t *c = (ucc_tl_mlx5_dm_chunk_t *)obj; + ucc_tl_mlx5_team_t *team = + ucc_container_of(mp, ucc_tl_mlx5_team_t, dm_pool); + + c->offset = (ptrdiff_t)team->dm_offset; + team->dm_offset = PTR_OFFSET(team->dm_offset, + UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size); +} + +static void ucc_tl_mlx5_dm_chunk_release(ucc_mpool_t *mp, void *chunk) //NOLINT +{ + ucc_free(chunk); +} + +static ucc_mpool_ops_t ucc_tl_mlx5_dm_ops = {.chunk_alloc = ucc_mpool_hugetlb_malloc, + .chunk_release = + ucc_tl_mlx5_dm_chunk_release, + .obj_init = ucc_tl_mlx5_dm_chunk_init, + .obj_cleanup = NULL}; + +void ucc_tl_mlx5_dm_cleanup(ucc_tl_mlx5_team_t *team) +{ + if (!team->dm_ptr) { + return; + } + + ucc_mpool_cleanup(&team->dm_pool, 1); + + ibv_dereg_mr(team->dm_mr); + if (UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_host) { + ucc_free(team->dm_ptr); + } else { + ibv_free_dm(team->dm_ptr); + } +} + +ucc_status_t ucc_tl_mlx5_dm_alloc_reg(struct ibv_context *ib_ctx, + struct ibv_pd *pd, int dm_host, + size_t buf_size, size_t *buf_num_p, + struct ibv_dm **ptr, struct ibv_mr **mr, + ucc_base_lib_t *lib) +{ + struct ibv_dm *dm_ptr = NULL; + struct ibv_mr *dm_mr; + struct ibv_device_attr_ex attr; + struct ibv_alloc_dm_attr dm_attr; + int max_chunks_to_alloc, min_chunks_to_alloc, i; + + if (dm_host) { + max_chunks_to_alloc = (*buf_num_p == UCC_ULUNITS_AUTO) + ? DM_HOST_AUTO_NUM_CHUNKS + : *buf_num_p; + dm_attr.length = max_chunks_to_alloc * buf_size; + dm_ptr = ucc_malloc(dm_attr.length, "memic_host"); + if (!dm_ptr) { + tl_error(lib, " memic_host allocation failed"); + return UCC_ERR_NO_MEMORY; + } + + dm_mr = ibv_reg_mr(pd, dm_ptr, dm_attr.length, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!dm_mr) { + tl_error(lib, "failed to reg host memory"); + ucc_free(dm_ptr); + return UCC_ERR_NO_MESSAGE; + } + *buf_num_p = max_chunks_to_alloc; + } else { + attr.comp_mask = 0; + if (ibv_query_device_ex(ib_ctx, NULL, &attr)) { + tl_error(lib, "failed to query device (errno=%d)", errno); + return UCC_ERR_NO_MESSAGE; + } + if (!attr.max_dm_size) { + tl_error(lib, "device doesn't support dm allocation"); + return UCC_ERR_NO_RESOURCE; + } + max_chunks_to_alloc = min_chunks_to_alloc = *buf_num_p; + if (*buf_num_p == UCC_ULUNITS_AUTO) { + max_chunks_to_alloc = + attr.max_dm_size / buf_size - 1; //keep reserved memory + min_chunks_to_alloc = 1; + if (!max_chunks_to_alloc) { + tl_error(lib, + "requested buffer size (=%ld) is too large, " + "should be set to be strictly less than %ld. " + "max allocation size is %ld", + buf_size, attr.max_dm_size / 2, attr.max_dm_size); + return UCC_ERR_NO_RESOURCE; + } + } + if (attr.max_dm_size < buf_size * min_chunks_to_alloc) { + tl_error(lib, + "cannot allocate %i buffer(s) of size %ld, " + "max allocation size is %ld", + min_chunks_to_alloc, buf_size, attr.max_dm_size); + return UCC_ERR_NO_MEMORY; + } + memset(&dm_attr, 0, sizeof(dm_attr)); + for (i = max_chunks_to_alloc; i >= min_chunks_to_alloc; i--) { + dm_attr.length = i * buf_size; + errno = 0; + dm_ptr = ibv_alloc_dm(ib_ctx, &dm_attr); + if (dm_ptr) { + break; + } + } + if (!dm_ptr) { + tl_error(lib, + "dev mem allocation failed, requested %ld, attr.max %zd, " + "errno %d", + dm_attr.length, attr.max_dm_size, errno); + return errno == ENOMEM || errno == ENOSPC ? UCC_ERR_NO_MEMORY + : UCC_ERR_NO_MESSAGE; + } + dm_mr = ibv_reg_dm_mr(pd, dm_ptr, 0, dm_attr.length, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_ZERO_BASED); + if (!dm_mr) { + tl_error(lib, "failed to reg memic"); + ibv_free_dm(dm_ptr); + return UCC_ERR_NO_MESSAGE; + } + *buf_num_p = i; + } + *ptr = dm_ptr; + *mr = dm_mr; + + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_dm_init(ucc_tl_mlx5_team_t *team) +{ + ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(team); + ucc_tl_mlx5_lib_config_t *cfg = &UCC_TL_MLX5_TEAM_LIB(team)->cfg; + ucc_status_t status; + + status = ucc_tl_mlx5_dm_alloc_reg( + ctx->shared_ctx, ctx->shared_pd, cfg->dm_host, cfg->dm_buf_size, + &cfg->dm_buf_num, &team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team)); + if (status != UCC_OK) { + tl_error(UCC_TL_TEAM_LIB(team), + "failed to alloc and register device memory"); + return status; + } + team->dm_offset = NULL; + + status = ucc_mpool_init(&team->dm_pool, 0, sizeof(ucc_tl_mlx5_dm_chunk_t), + 0, UCC_CACHE_LINE_SIZE, cfg->dm_buf_num, + cfg->dm_buf_num, &ucc_tl_mlx5_dm_ops, + ctx->super.super.ucc_context->thread_mode, + "mlx5 dm pool"); + if (status != UCC_OK) { + tl_error(UCC_TL_TEAM_LIB(team), "failed to init dm pool"); + ucc_tl_mlx5_dm_cleanup(team); + return status; + } + return UCC_OK; +} diff --git a/src/components/tl/mlx5/tl_mlx5_dm.h b/src/components/tl/mlx5/tl_mlx5_dm.h new file mode 100644 index 0000000000..3b611e44b3 --- /dev/null +++ b/src/components/tl/mlx5/tl_mlx5_dm.h @@ -0,0 +1,17 @@ +/** + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "tl_mlx5.h" + +ucc_status_t ucc_tl_mlx5_dm_alloc_reg(struct ibv_context *ib_ctx, + struct ibv_pd *pd, int dm_host, + size_t buf_size, size_t *buf_num_p, + struct ibv_dm **ptr, struct ibv_mr **mr, + ucc_base_lib_t *lib); + +void ucc_tl_mlx5_dm_cleanup(ucc_tl_mlx5_team_t *team); + +ucc_status_t ucc_tl_mlx5_dm_init(ucc_tl_mlx5_team_t *team); diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index c1723b52ac..8c1c1dc439 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -5,10 +5,44 @@ */ #include "tl_mlx5.h" +#include "tl_mlx5_dm.h" #include "coll_score/ucc_coll_score.h" #include "core/ucc_team.h" #include +static ucc_status_t ucc_tl_mlx5_topo_init(ucc_tl_mlx5_team_t *team) +{ + ucc_subset_t subset; + ucc_status_t status; + + status = ucc_ep_map_create_nested(&UCC_TL_CORE_TEAM(team)->ctx_map, + &UCC_TL_TEAM_MAP(team), &team->ctx_map); + if (UCC_OK != status) { + tl_error(UCC_TL_TEAM_LIB(team), "failed to create ctx map"); + return status; + } + subset.map = team->ctx_map; + subset.myrank = UCC_TL_TEAM_RANK(team); + + status = ucc_topo_init(subset, UCC_TL_CORE_CTX(team)->topo, &team->topo); + + if (UCC_OK != status) { + tl_error(UCC_TL_TEAM_LIB(team), "failed to init team topo"); + goto err_topo_init; + } + + return UCC_OK; +err_topo_init: + ucc_ep_map_destroy_nested(&team->ctx_map); + return status; +} + +static void ucc_tl_mlx5_topo_cleanup(ucc_tl_mlx5_team_t *team) +{ + ucc_ep_map_destroy_nested(&team->ctx_map); + ucc_topo_cleanup(team->topo); +} + UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context, const ucc_base_team_params_t *params) { @@ -20,12 +54,33 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context, self->a2a = NULL; self->dm_ptr = NULL; - return status; + + status = ucc_tl_mlx5_topo_init(self); + if (status != UCC_OK) { + tl_error(ctx->super.super.lib, "failed to init team topo"); + return status; + } + + if (ucc_topo_get_sbgp(self->topo, UCC_SBGP_NODE)->group_rank == 0) { + status = ucc_tl_mlx5_dm_init(self); + if (UCC_OK != status) { + tl_error(UCC_TL_TEAM_LIB(self), "failed to init device memory"); + } + } + + self->status[0] = status; + self->state = TL_MLX5_TEAM_STATE_INIT; + + tl_debug(tl_context->lib, "posted tl team: %p", self); + return UCC_OK; } UCC_CLASS_CLEANUP_FUNC(ucc_tl_mlx5_team_t) { tl_debug(self->super.super.context->lib, "finalizing tl team: %p", self); + + ucc_tl_mlx5_dm_cleanup(self); + ucc_tl_mlx5_topo_cleanup(self); } UCC_CLASS_DEFINE_DELETE_FUNC(ucc_tl_mlx5_team_t, ucc_base_team_t); @@ -37,8 +92,48 @@ ucc_status_t ucc_tl_mlx5_team_destroy(ucc_base_team_t *tl_team) return UCC_OK; } -ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *tl_team) /* NOLINT */ +ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) { + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team); + ucc_subset_t subset = {.map.type = UCC_EP_MAP_FULL, + .map.ep_num = core_team->size, + .myrank = core_team->rank}; + ucc_status_t status; + + switch (tl_team->state) { + case TL_MLX5_TEAM_STATE_INIT: + status = ucc_service_allreduce( + core_team, &tl_team->status[0], &tl_team->status[1], + UCC_DT_INT32, 1, UCC_OP_MIN, subset, &tl_team->scoll_req); + if (status < 0) { + tl_error(UCC_TL_TEAM_LIB(tl_team), + "failed to collect global status"); + return status; + } + tl_team->state = TL_MLX5_TEAM_STATE_POSTED; + case TL_MLX5_TEAM_STATE_POSTED: + status = ucc_service_coll_test(tl_team->scoll_req); + if (status < 0) { + tl_error(UCC_TL_TEAM_LIB(tl_team), + "failure during service coll exchange: %s", + ucc_status_string(status)); + return status; + } + if (UCC_INPROGRESS == status) { + return status; + } + ucc_assert(status == UCC_OK); + ucc_service_coll_finalize(tl_team->scoll_req); + if (tl_team->status[1] != UCC_OK) { + tl_error(UCC_TL_TEAM_LIB(tl_team), + "node leader failed during device memory init: %s", + ucc_status_string(tl_team->status[1])); + ucc_tl_mlx5_team_destroy(team); + return tl_team->status[1]; + } + } + return UCC_OK; } diff --git a/src/components/tl/mlx5/tl_mlx5_wqe.c b/src/components/tl/mlx5/tl_mlx5_wqe.c index 235bde9855..e8e36a27f6 100644 --- a/src/components/tl/mlx5/tl_mlx5_wqe.c +++ b/src/components/tl/mlx5/tl_mlx5_wqe.c @@ -297,7 +297,7 @@ ucc_status_t ucc_tl_mlx5_post_wait_on_data(struct ibv_qp *qp, uint64_t value, wseg->lkey = htobe32(lkey); wseg->va_fail = htobe64((addr) | (ACTION)); wseg->data = value; - wseg->data_mask = 1; + wseg->data_mask = 0xFFFFFFFF; mlx5dv_wr_raw_wqe(mqp, wqe_desc); if (ibv_wr_complete(qp_ex)) { return UCC_ERR_NO_MESSAGE; diff --git a/test/gtest/tl/mlx5/test_tl_mlx5.h b/test/gtest/tl/mlx5/test_tl_mlx5.h index a6f2e3c7f8..593dc049e8 100644 --- a/test/gtest/tl/mlx5/test_tl_mlx5.h +++ b/test/gtest/tl/mlx5/test_tl_mlx5.h @@ -8,6 +8,7 @@ #include #include "common/test_ucc.h" #include "components/tl/mlx5/tl_mlx5.h" +#include "components/tl/mlx5/tl_mlx5_dm.h" #include "components/tl/mlx5/tl_mlx5_ib.h" typedef ucc_status_t (*ucc_tl_mlx5_create_ibv_ctx_fn_t)( diff --git a/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc b/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc index 4f4ea44ec2..5207d6b111 100644 --- a/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc +++ b/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc @@ -6,8 +6,7 @@ #include "test_tl_mlx5_wqe.h" #include "utils/arch/cpu.h" #include - -#define DT uint8_t +#include // Rounds up a given integer to the closet power of two static int roundUpToPowerOfTwo(int a) @@ -92,54 +91,151 @@ INSTANTIATE_TEST_SUITE_P(, test_tl_mlx5_transpose, ::testing::Values(1, 5, 32, 64), ::testing::Values(1, 3, 8, 128))); -UCC_TEST_P(test_tl_mlx5_rdma_write, rdmaWriteWqe) +UCC_TEST_P(test_tl_mlx5_rdma_write, RdmaWriteWqe) { - int msgsize = GetParam(); - int completions_num = 0; - DT src[msgsize], dst[msgsize]; - struct ibv_wc wcs[1]; - struct ibv_mr *src_mr, *dst_mr; - int i; + struct ibv_sge sg; + struct ibv_send_wr wr; - for (i = 0; i < msgsize; i++) { - src[i] = i % 256; - dst[i] = 0; - } + bufsize = GetParam(); + buffers_init(); - src_mr = ibv_reg_mr(pd, src, msgsize * sizeof(DT), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); - GTEST_ASSERT_NE(nullptr, src_mr); - dst_mr = ibv_reg_mr(pd, dst, msgsize * sizeof(DT), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); - GTEST_ASSERT_NE(nullptr, dst_mr); + memset(&sg, 0, sizeof(sg)); + sg.addr = (uintptr_t)src; + sg.length = bufsize; + sg.lkey = src_mr->lkey; + + memset(&wr, 0, sizeof(wr)); + wr.wr_id = 0; + wr.sg_list = &sg; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_WRITE; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_FENCE; + wr.next = NULL; + wr.wr.rdma.remote_addr = (uintptr_t)dst; + wr.wr.rdma.rkey = dst_mr->rkey; + + // This work request is posted with wr_id = 0 + GTEST_ASSERT_EQ(ibv_post_send(qp.qp, &wr, NULL), 0); + wait_for_completion(); + + validate_buffers(); +} + +UCC_TEST_P(test_tl_mlx5_rdma_write, CustomRdmaWriteWqe) +{ + bufsize = GetParam(); + buffers_init(); ibv_wr_start(qp.qp_ex); - post_rdma_write(qp.qp, qpn, nullptr, (uintptr_t)src, msgsize * sizeof(DT), - src_mr->lkey, (uintptr_t)dst, dst_mr->rkey, + post_rdma_write(qp.qp, qpn, nullptr, (uintptr_t)src, bufsize, src_mr->lkey, + (uintptr_t)dst, dst_mr->rkey, IBV_SEND_SIGNALED | IBV_SEND_FENCE, 0); GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0); + wait_for_completion(); - while (!completions_num) { - completions_num = ibv_poll_cq(cq, 1, wcs); - } - GTEST_ASSERT_EQ(completions_num, 1); - GTEST_ASSERT_EQ(wcs[0].status, IBV_WC_SUCCESS); - GTEST_ASSERT_EQ(wcs[0].wr_id, 0); + validate_buffers(); +} - //validation - for (i = 0; i < msgsize; i++) { - GTEST_ASSERT_EQ(src[i], dst[i]); +INSTANTIATE_TEST_SUITE_P(, test_tl_mlx5_rdma_write, + ::testing::Values(1, 31, 128, 1024)); + +UCC_TEST_P(test_tl_mlx5_dm, MemcpyToDeviceMemory) +{ + bufsize = GetParam(); + buffers_init(); + + if (bufsize % 4 != 0) { + GTEST_SKIP() << "for memcpy involving device memory, buffer size " + << "must be a multiple of 4"; } - GTEST_ASSERT_EQ(ibv_dereg_mr(src_mr), UCC_OK); - GTEST_ASSERT_EQ(ibv_dereg_mr(dst_mr), UCC_OK); + GTEST_ASSERT_EQ(ibv_memcpy_to_dm(dm_ptr, 0, (void *)src, bufsize), 0); + GTEST_ASSERT_EQ(ibv_memcpy_from_dm((void *)dst, dm_ptr, 0, bufsize), 0); + + validate_buffers(); } -INSTANTIATE_TEST_SUITE_P(, test_tl_mlx5_rdma_write, - ::testing::Values(1, 8, 128, 1024)); +UCC_TEST_P(test_tl_mlx5_dm, RdmaToDeviceMemory) +{ + struct ibv_sge sg; + struct ibv_send_wr wr; -UCC_TEST_F(test_tl_mlx5_wait_on_data, waitOnDataWqe) + bufsize = GetParam(); + buffers_init(); + + // RDMA write from host source to device memory + memset(&sg, 0, sizeof(sg)); + sg.addr = (uintptr_t)src; + sg.length = bufsize; + sg.lkey = src_mr->lkey; + + memset(&wr, 0, sizeof(wr)); + wr.wr_id = 0; + wr.sg_list = &sg; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_WRITE; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_FENCE; + wr.next = NULL; + wr.wr.rdma.remote_addr = (uintptr_t)0; + wr.wr.rdma.rkey = dm_mr->rkey; + + GTEST_ASSERT_EQ(ibv_post_send(qp.qp, &wr, NULL), 0); + wait_for_completion(); + + // RDMA write from device memory to host destination + memset(&sg, 0, sizeof(sg)); + sg.addr = (uintptr_t)0; + sg.length = bufsize; + sg.lkey = dm_mr->lkey; + + memset(&wr, 0, sizeof(wr)); + wr.wr_id = 0; + wr.sg_list = &sg; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_WRITE; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_FENCE; + wr.next = NULL; + wr.wr.rdma.remote_addr = (uintptr_t)dst; + wr.wr.rdma.rkey = dst_mr->rkey; + + GTEST_ASSERT_EQ(ibv_post_send(qp.qp, &wr, NULL), 0); + wait_for_completion(); + + validate_buffers(); +} + +UCC_TEST_P(test_tl_mlx5_dm, CustomRdmaToDeviceMemory) +{ + bufsize = GetParam(); + buffers_init(); + + // RDMA write from host source to device memory + ibv_wr_start(qp.qp_ex); + post_rdma_write(qp.qp, qpn, nullptr, (uintptr_t)src, bufsize, src_mr->lkey, + (uintptr_t)0, dm_mr->rkey, + IBV_SEND_SIGNALED | IBV_SEND_FENCE, 0); + GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0); + wait_for_completion(); + + // RDMA write from device memory to host destination + ibv_wr_start(qp.qp_ex); + post_rdma_write(qp.qp, qpn, nullptr, (uintptr_t)0, bufsize, dm_mr->lkey, + (uintptr_t)dst, dst_mr->rkey, + IBV_SEND_SIGNALED | IBV_SEND_FENCE, 0); + GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0); + wait_for_completion(); + + validate_buffers(); +} + +INSTANTIATE_TEST_SUITE_P(, test_tl_mlx5_dm, + ::testing::Values(1, 12, 31, 32, 8192, 8193, 32768, + 65536)); + +UCC_TEST_P(test_tl_mlx5_wait_on_data, waitOnDataWqe) { + uint64_t wait_on_value = std::get<0>(GetParam()); + uint64_t init_ctrl_value = std::get<1>(GetParam()); uint64_t buffer[3]; volatile uint64_t *ctrl, *src, *dst; int completions_num; @@ -156,6 +252,8 @@ UCC_TEST_F(test_tl_mlx5_wait_on_data, waitOnDataWqe) src = &buffer[1]; dst = &buffer[2]; + *ctrl = init_ctrl_value; + memset(&sg, 0, sizeof(sg)); sg.addr = (uintptr_t)src; sg.length = sizeof(uint64_t); @@ -172,16 +270,18 @@ UCC_TEST_F(test_tl_mlx5_wait_on_data, waitOnDataWqe) wr.wr.rdma.rkey = buffer_mr->rkey; // This work request is posted with wr_id = 1 - GTEST_ASSERT_EQ( - post_wait_on_data(qp.qp, 1, buffer_mr->lkey, (uintptr_t)ctrl, nullptr), - UCC_OK); + GTEST_ASSERT_EQ(post_wait_on_data(qp.qp, wait_on_value, buffer_mr->lkey, + (uintptr_t)ctrl, nullptr), + UCC_OK); // This work request is posted with wr_id = 0 GTEST_ASSERT_EQ(ibv_post_send(qp.qp, &wr, NULL), 0); + sleep(1); + *src = 0xdeadbeef; //memory barrier ucc_memory_cpu_fence(); - *ctrl = 1; + *ctrl = wait_on_value; while (1) { completions_num = ibv_poll_cq(cq, 1, wcs); @@ -200,6 +300,11 @@ UCC_TEST_F(test_tl_mlx5_wait_on_data, waitOnDataWqe) GTEST_ASSERT_EQ(ibv_dereg_mr(buffer_mr), UCC_OK); } +INSTANTIATE_TEST_SUITE_P( + , test_tl_mlx5_wait_on_data, + ::testing::Combine(::testing::Values(1, 1024, 1025, 0xF0F30F00, 0xFFFFFFFF), + ::testing::Values(0, 0xF0F30F01))); + UCC_TEST_P(test_tl_mlx5_umr_wqe, umrWqe) { uint16_t nbr_srcs = std::get<0>(GetParam()); @@ -310,3 +415,28 @@ INSTANTIATE_TEST_SUITE_P(, test_tl_mlx5_umr_wqe, ::testing::Values(5, 64), ::testing::Values(1, 3, 16), ::testing::Values(0, 7))); + +UCC_TEST_P(test_tl_mlx5_dm_alloc_reg, DeviceMemoryAllocation) +{ + size_t buf_size = std::get<0>(GetParam()); + size_t buf_num = std::get<1>(GetParam()); + struct ibv_dm *ptr = nullptr; + struct ibv_mr *mr = nullptr; + ucc_status_t status; + + status = dm_alloc_reg(ctx, pd, 0, buf_size, &buf_num, &ptr, &mr, &lib); + if (status == UCC_ERR_NO_MEMORY || status == UCC_ERR_NO_RESOURCE) { + GTEST_SKIP() << "cannot allocate " << buf_num << " chunk(s) of size " + << buf_size << " in device memory"; + } + GTEST_ASSERT_EQ(status, UCC_OK); + + ibv_dereg_mr(mr); + ibv_free_dm(ptr); +} + +INSTANTIATE_TEST_SUITE_P( + , test_tl_mlx5_dm_alloc_reg, + ::testing::Combine(::testing::Values(1, 2, 1024, 8191, 8192, 8193, 32768, + 65536, 262144), + ::testing::Values(UCC_ULUNITS_AUTO, 1, 3, 8))); diff --git a/test/gtest/tl/mlx5/test_tl_mlx5_wqe.h b/test/gtest/tl/mlx5/test_tl_mlx5_wqe.h index 28e13e2c60..e66b8dc2cf 100644 --- a/test/gtest/tl/mlx5/test_tl_mlx5_wqe.h +++ b/test/gtest/tl/mlx5/test_tl_mlx5_wqe.h @@ -7,6 +7,8 @@ #include "test_tl_mlx5_qps.h" #include "components/tl/mlx5/tl_mlx5_wqe.h" +#define DT uint8_t + typedef ucc_status_t (*ucc_tl_mlx5_post_rdma_fn_t)( struct ibv_qp *qp, uint32_t qpn, struct ibv_ah *ah, uintptr_t src_mkey_addr, size_t len, uint32_t src_mr_lkey, uintptr_t dst_addr, uint32_t dst_mr_key, @@ -27,12 +29,23 @@ typedef ucc_status_t (*ucc_tl_mlx5_post_umr_fn_t)( uint32_t repeat_count, uint16_t num_entries, struct mlx5dv_mr_interleaved *data, uint32_t ptr_mkey, void *ptr_address); +typedef ucc_status_t (*ucc_tl_mlx5_dm_alloc_reg_fn_t)( + struct ibv_context *ib_ctx, struct ibv_pd *pd, int dm_host, size_t buf_size, + size_t *buf_num_p, struct ibv_dm **ptr, struct ibv_mr **mr, + ucc_base_lib_t *lib); + // (msgsize) using RdmaWriteParams = int; +// (buf_size) +using DmParams = int; // (nrows, ncols, element_size) using TransposeParams = std::tuple; // (nbr_srcs, bytes_count, repeat_count, bytes_skip) using UmrParams = std::tuple; +// (buffer_size, buffer_nums) +using AllocDmParams = std::tuple; +// (wait_on_value, init_ctrl_value) +using WaitOnDataParams = std::tuple; class test_tl_mlx5_wqe : public test_tl_mlx5_rc_qp { public: @@ -67,19 +80,135 @@ class test_tl_mlx5_wqe : public test_tl_mlx5_rc_qp { } }; -class test_tl_mlx5_rdma_write - : public test_tl_mlx5_wqe, - public ::testing::WithParamInterface { -}; - class test_tl_mlx5_transpose : public test_tl_mlx5_wqe, public ::testing::WithParamInterface { }; -class test_tl_mlx5_wait_on_data : public test_tl_mlx5_wqe { +class test_tl_mlx5_wait_on_data + : public test_tl_mlx5_wqe, + public ::testing::WithParamInterface { }; class test_tl_mlx5_umr_wqe : public test_tl_mlx5_wqe, public ::testing::WithParamInterface { }; + +class test_tl_mlx5_rdma_write + : public test_tl_mlx5_wqe, + public ::testing::WithParamInterface { + public: + int bufsize; + DT * src, *dst; + struct ibv_mr *src_mr, *dst_mr; + + void buffers_init() + { + src = (DT *)malloc(bufsize); + GTEST_ASSERT_NE(src, nullptr); + dst = (DT *)malloc(bufsize); + GTEST_ASSERT_NE(dst, nullptr); + + for (int i = 0; i < bufsize; i++) { + src[i] = i % 256; + dst[i] = 0; + } + + src_mr = ibv_reg_mr(pd, src, bufsize, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + GTEST_ASSERT_NE(nullptr, src_mr); + dst_mr = ibv_reg_mr(pd, dst, bufsize, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + GTEST_ASSERT_NE(nullptr, dst_mr); + } + + void wait_for_completion() + { + int completions_num = 0; + struct ibv_wc wcs[1]; + + while (!completions_num) { + completions_num = ibv_poll_cq(cq, 1, wcs); + } + + GTEST_ASSERT_EQ(completions_num, 1); + GTEST_ASSERT_EQ(wcs[0].status, IBV_WC_SUCCESS); + } + + void validate_buffers() + { + for (int i = 0; i < bufsize; i++) { + GTEST_ASSERT_EQ(src[i], dst[i]); + } + } + + void TearDown() + { + GTEST_ASSERT_EQ(ibv_dereg_mr(src_mr), UCC_OK); + GTEST_ASSERT_EQ(ibv_dereg_mr(dst_mr), UCC_OK); + free(src); + free(dst); + } +}; + +class test_tl_mlx5_dm : public test_tl_mlx5_rdma_write { + public: + struct ibv_dm * dm_ptr; + struct ibv_alloc_dm_attr dm_attr; + struct ibv_mr * dm_mr; + + void buffers_init() + { + test_tl_mlx5_rdma_write::buffers_init(); + + struct ibv_device_attr_ex attr; + memset(&attr, 0, sizeof(attr)); + GTEST_ASSERT_EQ(ibv_query_device_ex(ctx, NULL, &attr), 0); + if (attr.max_dm_size < bufsize) { + if (!attr.max_dm_size) { + GTEST_SKIP() << "device doesn't support dm allocation"; + } else { + GTEST_SKIP() << "the requested buffer size (=" << bufsize + << ") for device memory should be less than " + << attr.max_dm_size; + } + } + + memset(&dm_attr, 0, sizeof(dm_attr)); + dm_attr.length = bufsize; + dm_ptr = ibv_alloc_dm(ctx, &dm_attr); + ASSERT_TRUE(dm_ptr != NULL); + ASSERT_TRUE(errno != 0); + + dm_mr = ibv_reg_dm_mr(pd, dm_ptr, 0, dm_attr.length, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_ZERO_BASED); + GTEST_ASSERT_NE(dm_mr, nullptr); + } + + void TearDown() + { + if (dm_mr) { + ibv_dereg_mr(dm_mr); + } + if (dm_ptr) { + ibv_free_dm(dm_ptr); + } + test_tl_mlx5_rdma_write::TearDown(); + } +}; + +class test_tl_mlx5_dm_alloc_reg + : public test_tl_mlx5_wqe, + public ::testing::WithParamInterface { + public: + ucc_tl_mlx5_dm_alloc_reg_fn_t dm_alloc_reg; + void SetUp() + { + test_tl_mlx5_wqe::SetUp(); + + dm_alloc_reg = (ucc_tl_mlx5_dm_alloc_reg_fn_t)dlsym( + tl_mlx5_so_handle, "ucc_tl_mlx5_dm_alloc_reg"); + ASSERT_EQ(nullptr, dlerror()); + } +};