Skip to content

Commit

Permalink
TL/UCP: make local copy nb in allgather
Browse files: browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Jan 24, 2024
1 parent 75ecf74 commit 02147dd
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 12 deletions.
53 changes: 46 additions & 7 deletions src/components/tl/ucp/allgather/allgather_ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
size_t count = TASK_ARGS(task).dst.info.count;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_status_t status = UCC_OK;
ucc_rank_t sendto, recvfrom, sblock, rblock;
int step;
void *buf;
Expand Down Expand Up @@ -69,7 +70,14 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
}
}
ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
task->super.status = UCC_OK;
if (task->allgather_ring.etask) {
status = ucc_ee_executor_task_test(task->allgather_ring.etask);
if (status == UCC_INPROGRESS) {
return;
}
ucc_ee_executor_task_finalize(task->allgather_ring.etask);
}
task->super.status = status;
out:
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_done", 0);
}
Expand All @@ -88,22 +96,50 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task)
ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num;
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_status_t status;
ucc_rank_t block;
ucc_rank_t sendto, recvfrom, sblock, rblock;
ucc_ee_executor_t *exec;
ucc_ee_executor_task_args_t eargs;
void *buf;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_start", 0);
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);
sblock = task->allgather_ring.get_send_block(&task->subset, trank, tsize, 0);
rblock = task->allgather_ring.get_recv_block(&task->subset, trank, tsize, 0);
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
block = task->allgather_ring.get_send_block(&task->subset, trank, tsize,
0);
status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block),
sbuf, data_size, rmem, smem);
if (ucc_unlikely(UCC_OK != status)) {
status = ucc_coll_task_get_executor(&task->super, &exec);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}

eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.src = sbuf;
eargs.copy.dst = PTR_OFFSET(rbuf, data_size * sblock);
eargs.copy.len = data_size;

status = ucc_ee_executor_task_post(exec, &eargs,
&task->allgather_ring.etask);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}
buf = sbuf;
} else {
task->allgather_ring.etask = NULL;
buf = PTR_OFFSET(rbuf, data_size * sblock);
}

UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buf, data_size, smem, sendto, team, task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, rblock * data_size),
data_size, rmem, recvfrom, team, task),
task, out);

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);

out:
return status;
}

ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task)
Expand All @@ -128,6 +164,9 @@ ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task)
task->allgather_ring.get_recv_block = ucc_tl_ucp_allgather_ring_get_recv_block;
task->super.post = ucc_tl_ucp_allgather_ring_start;
task->super.progress = ucc_tl_ucp_allgather_ring_progress;
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
}

return UCC_OK;
}
Expand Down
1 change: 1 addition & 0 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ typedef struct ucc_tl_ucp_task {
ucc_rank_t trank,
ucc_rank_t tsize,
int step);
ucc_ee_executor_task_t *etask;
} allgather_ring;
struct {
ucc_rank_t dist;
Expand Down
20 changes: 15 additions & 5 deletions src/components/tl/ucp/tl_ucp_service_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ static ucc_status_t ucc_tl_ucp_service_coll_stop_executor(ucc_coll_task_t *task)
ucc_status_t ucc_tl_ucp_service_allreduce(ucc_base_team_t *team, void *sbuf,
void *rbuf, ucc_datatype_t dt,
size_t count, ucc_reduction_op_t op,
ucc_subset_t subset,
ucc_subset_t subset,
ucc_coll_task_t **task_p)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
Expand Down Expand Up @@ -140,7 +140,7 @@ ucc_status_t ucc_tl_ucp_service_allreduce(ucc_base_team_t *team, void *sbuf,

ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf,
void *rbuf, size_t msgsize,
ucc_subset_t subset,
ucc_subset_t subset,
ucc_coll_task_t **task_p)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
Expand Down Expand Up @@ -178,6 +178,14 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf,
task->n_polls = npolls;
task->super.progress = ucc_tl_ucp_allgather_ring_progress;
task->super.finalize = ucc_tl_ucp_coll_finalize;
if (in_place) {
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
}

status = ucc_tl_ucp_service_coll_start_executor(&task->super);
if (status != UCC_OK) {
goto free_task;
}

status = ucc_tl_ucp_allgather_ring_start(&task->super);
if (status != UCC_OK) {
Expand All @@ -187,15 +195,16 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf,
*task_p = &task->super;
return status;
finalize_coll:
ucc_tl_ucp_coll_finalize(*task_p);
ucc_tl_ucp_coll_finalize(&task->super);
ucc_tl_ucp_service_coll_stop_executor(&task->super);
free_task:
ucc_tl_ucp_put_task(task);
return status;
}

ucc_status_t ucc_tl_ucp_service_bcast(ucc_base_team_t *team, void *buf,
size_t msgsize, ucc_rank_t root,
ucc_subset_t subset,
ucc_subset_t subset,
ucc_coll_task_t **task_p)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
Expand Down Expand Up @@ -239,7 +248,8 @@ ucc_status_t ucc_tl_ucp_service_bcast(ucc_base_team_t *team, void *buf,
return status;
}

void ucc_tl_ucp_service_update_id(ucc_base_team_t *team, uint16_t id) {
void ucc_tl_ucp_service_update_id(ucc_base_team_t *team, uint16_t id)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);

tl_team->super.super.params.id = id;
Expand Down

0 comments on commit 02147dd

Please sign in to comment.