Skip to content

Commit

Permalink
Minor fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
nsarka committed Jan 17, 2024
1 parent c02a76e commit 89b96ed
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 23 deletions.
9 changes: 2 additions & 7 deletions src/components/tl/ucp/allreduce/allreduce_sliding_window.c
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,6 @@ void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *coll_task)
int window;
int put_idx;

// nick
ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(ucc_derived_of(coll_task->team, ucc_tl_ucp_team_t));

if (barrier_task != NULL) {
// mark sliding window task complete once barrier finishes
if (barrier_task->super.status == UCC_OK) {
Expand Down Expand Up @@ -440,12 +437,13 @@ void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *coll_task)
}

ucp_worker_fence(task->allreduce_sliding_window.ucp_worker);
task->allreduce_sliding_window.put_requests[put_idx] =
task->allreduce_sliding_window.put_requests[put_idx] =
ucp_put_nbx(
task->allreduce_sliding_window.eps[dst_rank],
src_addr, data_size, (uint64_t)dst_addr,
task->allreduce_sliding_window.dst_rkeys[dst_rank],
&req_param);

pipe->posted_put++;
pipe->dst_rank = (dst_rank + 1) % host_team_size;
}
Expand Down Expand Up @@ -485,9 +483,6 @@ void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *coll_task)
}

ucp_worker_progress(task->allreduce_sliding_window.ucp_worker);

//nick
ucp_worker_progress(tl_ctx->worker.ucp_worker);
}

if (pipe->count_serviced == pipe->my_count) {
Expand Down
28 changes: 12 additions & 16 deletions src/components/tl/ucp/allreduce/allreduce_sliding_window_setup.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
static int
ucc_tl_ucp_allreduce_sliding_window_register(
ucp_context_h ucp_context,
ucc_tl_ucp_team_t *tl_team,
struct ucc_tl_ucp_allreduce_sw_export_buf *ebuf,
void *packed_memh)
{
Expand All @@ -26,15 +27,15 @@ ucc_tl_ucp_allreduce_sliding_window_register(

ucs_status = ucp_mem_map(ucp_context, &params, &ebuf->memh);
if (UCS_OK != ucs_status) {
printf("import using ucp_mem_map() returned error: %s\n",
tl_error(UCC_TL_TEAM_LIB(tl_team), "import using ucp_mem_map() returned error: %s\n",
ucs_status_string(ucs_status));
return 0;
}

ucs_status = ucp_rkey_pack(ucp_context, ebuf->memh, &ebuf->packed_key,
&ebuf->packed_key_len);
if (UCS_OK != ucs_status) {
printf("ucp_rkey_pack() returned error: %s\n",
tl_error(UCC_TL_TEAM_LIB(tl_team), "ucp_rkey_pack() returned error: %s\n",
ucs_status_string(ucs_status));
return 0;
}
Expand All @@ -50,9 +51,6 @@ ucc_tl_ucp_allreduce_sliding_window_task_init(ucc_base_coll_args_t *coll_args,
ucc_status_t status = UCC_OK;
void *src_buf = coll_args->args.src.info.buffer;
void *dst_buf = coll_args->args.dst.info.buffer;
/* size_t src_len = 0;
size_t dst_len = coll_args->args.dst.info.count *
ucc_dt_size(coll_args->args.dst.info.datatype);*/
ucc_rank_t team_size = (ucc_rank_t)team->params.size;
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(tl_team);
Expand Down Expand Up @@ -85,9 +83,6 @@ ucc_tl_ucp_allreduce_sliding_window_task_init(ucc_base_coll_args_t *coll_args,
team_size);
task->allreduce_sliding_window.src_rkeys = ucc_malloc(sizeof(ucp_rkey_h)
* team_size);
/*
src_len = coll_args->args.src.info.count *
ucc_dt_size(coll_args->args.src.info.datatype);*/
}

task->allreduce_sliding_window.rbufs = ucc_malloc(sizeof(void*)
Expand All @@ -97,14 +92,15 @@ ucc_tl_ucp_allreduce_sliding_window_task_init(ucc_base_coll_args_t *coll_args,
task->allreduce_sliding_window.eps = ucc_malloc(sizeof(ucp_ep_h)
* team_size);

task->allreduce_sliding_window.put_requests = task->allreduce_sliding_window.pipe->put_requests;
task->allreduce_sliding_window.put_requests =
task->allreduce_sliding_window.pipe->put_requests;


if (!task->allreduce_sliding_window.inplace)
if (!task->allreduce_sliding_window.inplace) {
task->allreduce_sliding_window.src_ebuf = ucc_malloc(
sizeof(struct ucc_tl_ucp_allreduce_sw_export_buf));
else
} else {
task->allreduce_sliding_window.src_ebuf = NULL;
}

task->allreduce_sliding_window.dst_ebuf = ucc_malloc(
sizeof(struct ucc_tl_ucp_allreduce_sw_export_buf));
Expand All @@ -116,13 +112,13 @@ ucc_tl_ucp_allreduce_sliding_window_task_init(ucc_base_coll_args_t *coll_args,

// Register the src and dst bufs
if (!task->allreduce_sliding_window.inplace) {
ucc_tl_ucp_allreduce_sliding_window_register(tl_ctx->worker.ucp_context, task->allreduce_sliding_window.src_ebuf, gwbi_p->packed_src_memh);
ucc_tl_ucp_allreduce_sliding_window_register(tl_ctx->worker.ucp_context, tl_team, task->allreduce_sliding_window.src_ebuf, gwbi_p->packed_src_memh);
memcpy(allgather_data->packed_src_key,
task->allreduce_sliding_window.src_ebuf->packed_key,
task->allreduce_sliding_window.src_ebuf->packed_key_len);
}

ucc_tl_ucp_allreduce_sliding_window_register(tl_ctx->worker.ucp_context, task->allreduce_sliding_window.dst_ebuf, gwbi_p->packed_dst_memh);
ucc_tl_ucp_allreduce_sliding_window_register(tl_ctx->worker.ucp_context, tl_team, task->allreduce_sliding_window.dst_ebuf, gwbi_p->packed_dst_memh);
memcpy(allgather_data->packed_dst_key,
task->allreduce_sliding_window.dst_ebuf->packed_key,
task->allreduce_sliding_window.dst_ebuf->packed_key_len);
Expand Down Expand Up @@ -169,7 +165,7 @@ ucc_tl_ucp_allreduce_sliding_window_allgather_info_finalize(ucc_service_coll_req
all_host_allgather[i].packed_dst_key,
&dst_unpacked);
if (UCS_OK != ucs_status) {
printf("dst rkey unpack failed\n");
tl_error(UCC_TL_TEAM_LIB(tl_team), "dst rkey unpack failed\n");
return UCC_ERR_NO_RESOURCE;
}

Expand All @@ -182,7 +178,7 @@ ucc_tl_ucp_allreduce_sliding_window_allgather_info_finalize(ucc_service_coll_req
all_host_allgather[i].packed_src_key,
&src_unpacked);
if (UCS_OK != ucs_status) {
printf("src rkey unpack failed\n");
tl_error(UCC_TL_TEAM_LIB(tl_team), "src rkey unpack failed\n");
return UCC_ERR_NO_RESOURCE;
}

Expand Down

0 comments on commit 89b96ed

Please sign in to comment.