diff --git a/test/gtest/coll/test_allreduce.cc b/test/gtest/coll/test_allreduce.cc
index 957cda6ce1..b37be88974 100644
--- a/test/gtest/coll/test_allreduce.cc
+++ b/test/gtest/coll/test_allreduce.cc
@@ -466,20 +466,72 @@ void ep_err_cb(void *arg, ucp_ep_h ep, ucs_status_t ucs_status)
            ucs_status_string(ucs_status));
 }
 
+struct export_buf { /* UCP registration state of one exported buffer */
+    ucp_context_h ucp_context;     /* context given to buffer_export_ucc() */
+    ucp_mem_h     memh;            /* handle produced by ucp_mem_map() */
+    void         *packed_memh;     /* ucp_memh_pack() output, NULL on failure */
+    size_t        packed_memh_len; /* length of packed_memh, 0 on failure */
+    void         *packed_key;      /* ucp_rkey_pack() output */
+    size_t        packed_key_len;  /* length of packed_key */
+    uint64_t      memh_id;         /* not written by buffer_export_ucc() */
+};
+
+/* Map buf[0,len) on ucp_context; fill ebuf with the memh, its exported packed
+ * form and the packed rkey. Returns 0, or -1 if ucp_rkey_pack() fails; a
+ * ucp_memh_pack() failure is only logged and leaves packed_memh NULL. */
+int buffer_export_ucc(ucp_context_h ucp_context, void *buf, size_t len,
+                      struct export_buf *ebuf)
+{
+    ucs_status_t           ucs_status;
+    ucp_mem_map_params_t   params;
+    ucp_memh_pack_params_t pack_params;
+
+    ebuf->ucp_context = ucp_context;
+    params.field_mask =
+        UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
+    params.address = buf;
+    params.length  = len;
+    ucs_status = ucp_mem_map(ucp_context, &params, &ebuf->memh);
+    assert(ucs_status == UCS_OK); /* no-op when NDEBUG is defined */
+
+    pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS;
+    pack_params.flags      = UCP_MEMH_PACK_FLAG_EXPORT;
+    ucs_status = ucp_memh_pack(ebuf->memh, &pack_params, &ebuf->packed_memh,
+                               &ebuf->packed_memh_len);
+    if (ucs_status != UCS_OK) {
+        printf("ucp_memh_pack() returned error: %s\n",
+               ucs_status_string(ucs_status));
+        ebuf->packed_memh     = NULL;
+        ebuf->packed_memh_len = 0;
+    }
+    ucs_status = ucp_rkey_pack(ucp_context, ebuf->memh, &ebuf->packed_key,
+                               &ebuf->packed_key_len);
+    if (UCS_OK != ucs_status) {
+        printf("ucp_rkey_pack() returned error: %s\n",
+               ucs_status_string(ucs_status));
+        return -1; /* report the failure instead of signalling success */
+    }
+
+    return 0;
+}
+
 // nick test
 void setup_gwbi(int n_procs, UccCollCtxVec &ctxs)
 {
-    const int nthreads = 4;
+    const int nthreads = 1;
     const int num_get_bufs = 4;
     const int window_size = 8;
     const int buf_size = 16384;
-    int i, j;
+    int i, j, k;
+
     typedef struct ucp_info {
         ucp_context_h ucp_ctx;
         ucp_worker_h *ucp_thread_workers;
         ucp_address_t **ucp_thread_worker_addrs;
         size_t *ucp_thread_worker_addr_lens;
+ struct export_buf src_ebuf; + struct export_buf dst_ebuf; } ucp_info_t; ucp_info_t *ucp_infos = (ucp_info_t*) ucc_malloc(sizeof(ucp_info_t) * n_procs); @@ -541,10 +593,12 @@ void setup_gwbi(int n_procs, UccCollCtxVec &ctxs) ucp_info_t ucp_info; ucp_init_ex(nthreads, &ucp_info.ucp_ctx, &ucp_info.ucp_thread_workers, &ucp_info.ucp_thread_worker_addrs, &ucp_info.ucp_thread_worker_addr_lens); memcpy(&ucp_infos[i], &ucp_info, sizeof(ucp_info_t)); + ((ucc_tl_ucp_allreduce_sw_global_work_buf_info*)(ctxs[i]->args->global_work_buffer))->ucp_thread_workers = ucp_infos[i].ucp_thread_workers; } // set up eps - for (auto ctx : ctxs) { + for (k = 0; k < n_procs; k++) { + auto ctx = ctxs[k]; ucc_tl_ucp_allreduce_sw_global_work_buf_info *gwbi = (ucc_tl_ucp_allreduce_sw_global_work_buf_info *) ctx->args->global_work_buffer; @@ -565,7 +619,7 @@ void setup_gwbi(int n_procs, UccCollCtxVec &ctxs) ep_params.address = (ucp_address_t*) ucp_addr; for(j = 0; j < nthreads; j++) { - ucs_status = ucp_ep_create(ucp_infos[i].ucp_thread_workers[j], &ep_params, &new_ep); + ucs_status = ucp_ep_create(ucp_infos[k].ucp_thread_workers[j], &ep_params, &new_ep); if (ucs_status != UCS_OK) { printf("ucp_ep_create() returned: %s\n", ucs_status_string(ucs_status)); @@ -575,13 +629,79 @@ void setup_gwbi(int n_procs, UccCollCtxVec &ctxs) } } } + + // set up sbufs, rbufs, src_rkeys, and dst_rkeys + for (i = 0; i < n_procs; i++) { + // my proc's gwbi + ucc_tl_ucp_allreduce_sw_global_work_buf_info *gwbi = + (ucc_tl_ucp_allreduce_sw_global_work_buf_info *) ctxs[i]->args->global_work_buffer; + // my proc's ucp_info + ucp_info_t *ucp_info = &ucp_infos[i]; + struct export_buf *src_ebuf = &ucp_info->src_ebuf; + struct export_buf *dst_ebuf = &ucp_info->dst_ebuf; + size_t src_len = ctxs[i]->args->src.info.count * ucc_dt_size(ctxs[i]->args->src.info.datatype); + size_t dst_len = ctxs[i]->args->dst.info.count * ucc_dt_size(ctxs[i]->args->dst.info.datatype); + + buffer_export_ucc(ucp_info->ucp_ctx, 
ctxs[i]->args->src.info.buffer, src_len, src_ebuf); + buffer_export_ucc(ucp_info->ucp_ctx, ctxs[i]->args->dst.info.buffer, dst_len, dst_ebuf); + + gwbi->sbufs = (void**) ucc_malloc(sizeof(void*) * n_procs); + gwbi->rbufs = (void**) ucc_malloc(sizeof(void*) * n_procs); + + gwbi->src_rkeys = (ucp_rkey_h *) ucc_malloc(sizeof(ucp_rkey_h) * n_procs); // * nthreads + gwbi->dst_rkeys = (ucp_rkey_h *) ucc_malloc(sizeof(ucp_rkey_h) * n_procs); // * nthreads + + ((void**) gwbi->sbufs)[i] = ctxs[i]->args->src.info.buffer; + ((void**) gwbi->rbufs)[i] = ctxs[i]->args->dst.info.buffer; + } + + // copy sbufs, rbufs, src_rkeys, and dst_rkeys + for (i = 0; i < n_procs; i++) { + + ucc_tl_ucp_allreduce_sw_global_work_buf_info *my_gwbi = + (ucc_tl_ucp_allreduce_sw_global_work_buf_info *) ctxs[i]->args->global_work_buffer; + + for (j = 0; j < n_procs; j++) { + + ucc_tl_ucp_allreduce_sw_global_work_buf_info *their_gwbi = + (ucc_tl_ucp_allreduce_sw_global_work_buf_info *) ctxs[j]->args->global_work_buffer; + ucs_status_t ucs_status = UCS_OK; + + ucs_status = ucp_ep_rkey_unpack(my_gwbi->host_eps[j], + ucp_infos[j].src_ebuf.packed_key, + &my_gwbi->src_rkeys[j]); + if (UCS_OK != ucs_status) { + printf("src rkey unpack failed\n"); + return; + } + ucs_status = ucp_ep_rkey_unpack(my_gwbi->host_eps[j], + ucp_infos[j].dst_ebuf.packed_key, + &my_gwbi->dst_rkeys[j]); + if (UCS_OK != ucs_status) { + printf("dst rkey unpack failed\n"); + return; + } + + ((void**) their_gwbi->sbufs)[i] = ((void**) my_gwbi->sbufs)[i]; + ((void**) their_gwbi->rbufs)[i] = ((void**) my_gwbi->rbufs)[i]; + } + } + + // set the flag that indicates the global work buffer was passed + for (auto ctx : ctxs) { + ctx->args->mask |= UCC_COLL_ARGS_FIELD_FLAGS | + UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER; + ctx->args->flags |= UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS; + } } // nick test TYPED_TEST(test_allreduce_alg, sliding_window) { int n_procs = 15; - ucc_job_env_t env = {{"UCC_CL_HIER_TUNE", "allreduce:@rab:0-inf:inf"}, +/* 
ucc_job_env_t env = {{"UCC_CL_HIER_TUNE", "allreduce:@rab:0-inf:inf"}, {"UCC_CL_HIER_ALLREDUCE_RAB_PIPELINE", "thresh=1024:nfrags=11"}, + {"UCC_CLS", "all"}};*/ + ucc_job_env_t env = {{"UCC_TL_UCP_TUNE", "allreduce:@2"}, {"UCC_CLS", "all"}}; UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL, env); UccTeam_h team = job.create_team(n_procs); @@ -602,10 +722,12 @@ TYPED_TEST(test_allreduce_alg, sliding_window) { SET_MEM_TYPE(m); this->set_inplace(inplace); this->data_init(n_procs, TypeParam::dt, count, ctxs, true); - UccReq req(team, ctxs); setup_gwbi(n_procs, ctxs); + UccReq req(team, ctxs); + repeat = 1; + for (auto i = 0; i < repeat; i++) { req.start(); req.wait();