diff --git a/src/components/tl/ucp/tl_ucp_dpu_offload.c b/src/components/tl/ucp/tl_ucp_dpu_offload.c index ba9a37520c..b6a2a5ef12 100644 --- a/src/components/tl/ucp/tl_ucp_dpu_offload.c +++ b/src/components/tl/ucp/tl_ucp_dpu_offload.c @@ -201,6 +201,8 @@ ucc_tl_ucp_dpu_xgvmi_free_task(ucc_coll_task_t *coll_task) int inplace = UCC_IS_INPLACE(coll_task->bargs.args); ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(tl_team); + ucc_free(task->dpu_xgvmi.requests); + if (task->dpu_xgvmi.bufs) { if (!inplace) { if (task->dpu_xgvmi.bufs->src_ebuf->memh != NULL) { @@ -451,7 +453,7 @@ ucc_status_t ucc_tl_ucp_dpu_xgvmi_init(ucc_base_coll_args_t *coll_args, } status = ucc_tl_ucp_dpu_xgvmi_task_init(coll_args, team, - rdma_task); + rdma_task); if (status != UCC_OK) { tl_error(UCC_TL_TEAM_LIB(tl_team), "failed to init task: %s", ucc_status_string(status)); @@ -479,10 +481,14 @@ ucc_status_t ucc_tl_ucp_dpu_xgvmi_init(ucc_base_coll_args_t *coll_args, } rdma_task->dpu_xgvmi.requests = ucc_malloc(sizeof(ucs_status_ptr_t) * size); + if (rdma_task->dpu_xgvmi.requests == NULL) { + tl_error(UCC_TL_TEAM_LIB(tl_team), "failed to alloc requests"); + goto free_rdma_task; + } UCC_CHECK_GOTO(ucc_tl_ucp_allgather_ring_init(&bargs, team, &rdma_task->dpu_xgvmi.allgather_task), - free_rdma_task, status); + free_requests, status); status = ucc_tl_ucp_coll_init(&barrier_coll_args, team, &barrier_task); @@ -513,6 +519,9 @@ ucc_status_t ucc_tl_ucp_dpu_xgvmi_init(ucc_base_coll_args_t *coll_args, ucc_tl_ucp_coll_finalize(barrier_task); free_allgather_task: ucc_tl_ucp_coll_finalize(rdma_task->dpu_xgvmi.allgather_task); +free_requests: + ucc_free(rdma_task->dpu_xgvmi.requests); + rdma_task->dpu_xgvmi.requests = NULL; free_rdma_task: ucc_tl_ucp_dpu_xgvmi_free_task(&rdma_task->super); out: diff --git a/test/gtest/coll/test_allgather.cc b/test/gtest/coll/test_allgather.cc index c48bb8303d..0cd88a6263 100644 --- a/test/gtest/coll/test_allgather.cc +++ b/test/gtest/coll/test_allgather.cc @@ -7,6 +7,9 @@ #include "common/test_ucc.h" #include "utils/ucc_math.h" +// For linear xgvmi allgather +#include "test_allreduce_sliding_window.h" + using Param_0 = std::tuple; using Param_1 = std::tuple; using Param_2 = std::tuple; @@ -265,22 +268,57 @@ UCC_TEST_P(test_allgather_alg, alg) const gtest_ucc_inplace_t inplace = std::get<3>(GetParam()); int n_procs = 5; char tune[32]; - sprintf(tune, "allgather:@%s:inf", std::get<4>(GetParam()).c_str()); - ucc_job_env_t env = {{"UCC_CL_BASIC_TUNE", "inf"}, - {"UCC_TL_UCP_TUNE", tune}}; - UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL, env); - UccTeam_h team = job.create_team(n_procs); - UccCollCtxVec ctxs; + ucc_job_env_t env = {{"UCC_CL_BASIC_TUNE", "inf"}, + {"UCC_TL_UCP_TUNE", tune}}; +#ifdef HAVE_UCX + ucs_status_t ucs_st = UCS_OK; + test_ucp_info_t *ucp_info = NULL; + // ONESIDED in order to ucp_init inside of ucc with the RMA feature enabled + UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL_ONESIDED, env); +#else + UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL, env); +#endif + UccTeam_h team = job.create_team(n_procs); + UccCollCtxVec ctxs; set_inplace(inplace); SET_MEM_TYPE(mem_type); - data_init(n_procs, dtype, count, ctxs, false); + + if (!std::get<4>(GetParam()).compare("linear_xgvmi")) { + if (inplace == TEST_INPLACE || mem_type != UCC_MEMORY_TYPE_HOST) { + data_fini(ctxs); + GTEST_SKIP() << "linear xgvmi must be mt host and not in place"; + } +#ifdef HAVE_UCX + // Algorithm is linear_xgvmi, set up gwbi + ucs_st = setup_gwbi(n_procs, ctxs, &ucp_info, inplace == TEST_INPLACE); + if (ucs_st != UCS_OK) { + free_gwbi(n_procs, ctxs, ucp_info, inplace == TEST_INPLACE); + data_fini(ctxs); + if (ucs_st == UCS_ERR_UNSUPPORTED) { + GTEST_SKIP() << "Exported memory key not supported"; + } else { + GTEST_FAIL() << ucs_status_string(ucs_st); + } + } +#else + GTEST_SKIP() << "linear xgvmi not supported"; +#endif + } + UccReq req(team, ctxs); req.start(); req.wait(); EXPECT_EQ(true, data_validate(ctxs)); + +#ifdef HAVE_UCX + if (!std::get<4>(GetParam()).compare("linear_xgvmi")) { + free_gwbi(n_procs, ctxs, ucp_info, inplace == TEST_INPLACE); + } +#endif + data_fini(ctxs); } @@ -296,7 +334,7 @@ INSTANTIATE_TEST_CASE_P( #endif ::testing::Values(1,3,8192), // count ::testing::Values(TEST_INPLACE, TEST_NO_INPLACE), - ::testing::Values("knomial", "ring", "neighbor", "bruck", "sparbit")), + ::testing::Values("knomial", "ring", "neighbor", "bruck", "sparbit", "linear_xgvmi")), [](const testing::TestParamInfo& info) { std::string name; name += ucc_datatype_str(std::get<0>(info.param)); diff --git a/test/gtest/coll/test_alltoall.cc b/test/gtest/coll/test_alltoall.cc index fadce3ba5c..179cc2c67a 100644 --- a/test/gtest/coll/test_alltoall.cc +++ b/test/gtest/coll/test_alltoall.cc @@ -6,6 +6,9 @@ #include "common/test_ucc.h" #include "utils/ucc_math.h" +// For linear xgvmi alltoall +#include "test_allreduce_sliding_window.h" + using Param_0 = std::tuple; using Param_1 = std::tuple; @@ -332,3 +335,54 @@ INSTANTIATE_TEST_CASE_P( #endif ::testing::Values(/*TEST_INPLACE,*/ TEST_NO_INPLACE), ::testing::Values(1,3,8192))); // count + + +#ifdef HAVE_UCX +class test_alltoall_2 : public test_alltoall, + public ::testing::WithParamInterface {}; + +UCC_TEST_P(test_alltoall_2, linear_xgvmi) +{ + const int n_procs = std::get<0>(GetParam()); + const ucc_datatype_t dtype = std::get<1>(GetParam()); + ucc_memory_type_t mem_type = std::get<2>(GetParam()); + gtest_ucc_inplace_t inplace = std::get<3>(GetParam()); + const int count = std::get<4>(GetParam()); + ucc_job_env_t env = {{"UCC_TL_UCP_TUNE", "alltoall:0-inf:@linear_xgvmi"}}; + UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL_ONESIDED, env); + UccTeam_h team = job.create_team(n_procs); + ucs_status_t ucs_st = UCS_OK; + test_ucp_info_t *ucp_info = NULL; + std::vector reference_ranks; + UccCollCtxVec ctxs; + + this->set_inplace(inplace); + SET_MEM_TYPE(mem_type); + data_init(n_procs, dtype, count, ctxs, NULL, false); + ucs_st = setup_gwbi(n_procs, ctxs, &ucp_info, inplace == TEST_INPLACE); + if (ucs_st != UCS_OK) { + free_gwbi(n_procs, ctxs, ucp_info, inplace == TEST_INPLACE); + data_fini(ctxs); + if (ucs_st == UCS_ERR_UNSUPPORTED) { + GTEST_SKIP() << "Exported memory key not supported"; + } else { + GTEST_FAIL() << ucs_status_string(ucs_st); + } + } + UccReq req(team, ctxs); + req.start(); + req.wait(); + EXPECT_EQ(true, data_validate(ctxs)); + data_fini(ctxs); +} + +INSTANTIATE_TEST_CASE_P( + , test_alltoall_2, + ::testing::Combine( + ::testing::Values(2, 4, 6), + ::testing::Values(UCC_DT_INT16), + ::testing::Values(UCC_MEMORY_TYPE_HOST), + ::testing::Values(/*TEST_INPLACE,*/ TEST_NO_INPLACE), + ::testing::Values(1,3))); + +#endif