diff --git a/src/components/tl/cuda/bcast/bcast_linear.c b/src/components/tl/cuda/bcast/bcast_linear.c index bd9ad3786a..d7fa22b1c9 100644 --- a/src/components/tl/cuda/bcast/bcast_linear.c +++ b/src/components/tl/cuda/bcast/bcast_linear.c @@ -8,20 +8,21 @@ enum { // Barrier setup stages - STAGE_INIT_BAR_ROOT, // Initial stage for the root rank to identify and claim a free barrier - STAGE_FIND_BAR_PEER, // Stage where peer ranks wait while the root rank identifies a free barrier + STAGE_INIT_BAR_ROOT, // Initial stage for the root rank to identify and claim a free barrier + STAGE_FIND_BAR_PEER, // Stage where peer ranks wait while the root rank identifies a free barrier - STAGE_SYNC, // Initialize the barrier and synchronize the segment required for the current task - STAGE_SETUP, // Verify that all ranks are aligned and have reached the barrier + STAGE_SYNC, // Initialize the barrier and synchronize the segment required for the current task + STAGE_SETUP, // Verify that all ranks are aligned and have reached the barrier // Stages specific to the root rank - STAGE_COPY, // Post copy task: copy data block from src to a scratch buffer - STAGE_WAIT_COPY, // The root waits for the completion of its copy operation - STAGE_WAIT_ALL, // The root rank waits until all other ranks have reached the same operational step - STAGE_WAIT_COMPLETION, // The root rank waits for all other ranks to complete the broadcast operation + STAGE_COPY, // Post copy task: copy data block from src to a scratch buffer + STAGE_WAIT_COPY, // The root waits for the completion of its copy operation + STAGE_WAIT_ALL, // The root rank waits until all other ranks have reached the same operational step + STAGE_WAIT_COMPLETION, // The root rank waits for all other ranks to complete the broadcast operation // non-root - STAGE_WAIT_ROOT, // Non-root ranks wait while the root rank writes data to its scratch buffer - STAGE_CLIENT_COPY, // Non-root ranks initiate their own copy tasks after the root's operations - STAGE_CLIENT_COPY_WAIT, // Non-root ranks wait for the completion of the copy operation from the root's scratch buffer + STAGE_WAIT_ROOT, // Wait while the root rank writes data to its scratch buffer + STAGE_CLIENT_COPY, // Initiate their own copy tasks after the root's operations + STAGE_CLIENT_COPY_WAIT, // Wait for the completion of the copy operation from the root's scratch buffer + STAGE_CLIENT_WAIT_COMPLETION, // Wait for the completion of algorithm on all ranks, global sync with root }; static inline ucc_status_t ucc_tl_cuda_bcast_linear_setup_start(ucc_tl_cuda_task_t *task) @@ -149,19 +150,21 @@ static void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) ucc_tl_cuda_team_t *team = TASK_TEAM(task); ucc_rank_t trank = UCC_TL_TEAM_RANK(team); size_t half_scratch_size = get_raw_scratch_size(team) / 2; - ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; + ucc_rank_t tsize = UCC_COLL_ARGS_ACTIVE_SET(&TASK_ARGS(task)) + ? (ucc_rank_t)task->subset.map.ep_num + : UCC_TL_TEAM_SIZE(team); size_t chunk_size = - task->bcast_linear.step < task->bcast_linear.num_steps - ? ucc_min(half_scratch_size, task->bcast_linear.size) - : task->bcast_linear.size - - (task->bcast_linear.step - 1) * half_scratch_size; + task->bcast_linear.step < task->bcast_linear.num_steps + ? ucc_min(half_scratch_size, task->bcast_linear.size) + : task->bcast_linear.size - + (task->bcast_linear.step - 1) * half_scratch_size; size_t offset_buff = task->bcast_linear.step * half_scratch_size; ucc_ee_executor_t *exec; ucc_ee_executor_task_t *etask; - ucc_status_t st; void *sbuf, *dbuf; - int i; ucc_rank_t peer; + ucc_status_t st; + int i; task->super.status = UCC_INPROGRESS; @@ -174,32 +177,23 @@ static void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) switch (task->bcast_linear.stage) { case STAGE_INIT_BAR_ROOT: st = root_find_free_barrier(task); - if (st == UCC_ERR_NOT_FOUND) { - // no free barriers found, try next time - return; - } if (st == UCC_OK) { task->bcast_linear.stage = STAGE_SYNC; - return; - - } else { - task->super.status = UCC_ERR_NO_RESOURCE; - return; + } else if (st != UCC_ERR_NOT_FOUND) { + task->super.status = st; } + // no free barriers found, try next time + return; case STAGE_FIND_BAR_PEER: st = peer_find_free_barrier(task); - if (st == UCC_ERR_NOT_FOUND) { - // no free barriers found, wait for root - return; - } if (st == UCC_OK) { // barrier found, continue to next stages task->bcast_linear.stage = STAGE_SYNC; - return; - } else { - task->super.status = UCC_ERR_NO_RESOURCE; - return; + } else if (st != UCC_ERR_NOT_FOUND) { + task->super.status = st; } + // no free barriers found by root, try next time + return; case STAGE_SYNC: if (ucc_tl_cuda_get_sync_root(task, task->bcast_linear.root) != UCC_OK) { return; @@ -215,8 +209,7 @@ static void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) st = ucc_tl_cuda_bcast_linear_setup_test(task); if (st != UCC_OK) { task->super.status = st; - return; - } + return;the copy operation from the root's scratch buffer if (trank == task->bcast_linear.root) { task->bcast_linear.stage = STAGE_COPY; } else { @@ -264,7 +257,7 @@ static void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) return; } case STAGE_WAIT_ALL: - for (i = 0; i < tsize; ++i) { + for (i = 0; i < tsize; ++i) {the copy operation from the root's scratch buffer if (UCC_COLL_ARGS_ACTIVE_SET(&TASK_ARGS(task))) { // eval phys rank from virt peer = ucc_ep_map_eval(task->subset.map, i); @@ -347,19 +340,26 @@ static void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) task->bcast_linear.stage = STAGE_WAIT_ROOT; return; } else { - // Done - // signal barrier to notify root + // start barrier to sync with root + task->bcast_linear.stage = STAGE_CLIENT_WAIT_COMPLETION; st = ucc_tl_cuda_shm_barrier_start(trank, task->bar); if (ucc_unlikely(st != UCC_OK)) { ucc_error("failed to start barrier from peer rank"); task->super.status = st; return; } - task->super.status = UCC_OK; - break; } } } + break; + case STAGE_CLIENT_WAIT_COMPLETION: + st = ucc_tl_cuda_shm_barrier_test(trank, task->bar); + if (st != UCC_OK) { + // someone still working, lets check next time + task->super.status = st; + return; + } + task->super.status = UCC_OK; default: break; }