diff --git a/src/components/tl/nccl/tl_nccl_coll.c b/src/components/tl/nccl/tl_nccl_coll.c index fdb1ade4a0..8a225c268b 100644 --- a/src/components/tl/nccl/tl_nccl_coll.c +++ b/src/components/tl/nccl/tl_nccl_coll.c @@ -135,6 +135,7 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args, ucc_tl_nccl_context_t); ucc_tl_nccl_task_t *task; ucc_status_t status; + ucc_coll_progress_fn_t progress_fn; if (!ucc_coll_args_is_predefined_dt(&coll_args->args, team->params.rank)) { tl_error(team->context->lib, @@ -147,11 +148,13 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args, tl_error(team->context->lib, "failed to get task from mpool"); return UCC_ERR_NO_MEMORY; } + progress_fn = task->super.progress; ucc_coll_task_init(&task->super, coll_args, team); UCC_TL_NCCL_PROFILE_REQUEST_NEW(task, "tl_nccl_task", 0); task->super.finalize = ucc_tl_nccl_coll_finalize; task->super.triggered_post = ucc_tl_nccl_triggered_post; + task->super.progress = progress_fn; task->completed = NULL; if (nccl_ctx->cfg.sync_type == UCC_TL_NCCL_COMPLETION_SYNC_TYPE_EVENT) { status = ucc_ec_create_event(&task->completed, UCC_EE_CUDA_STREAM); diff --git a/src/core/ucc_coll.c b/src/core/ucc_coll.c index 783b60abce..8cf3658570 100644 --- a/src/core/ucc_coll.c +++ b/src/core/ucc_coll.c @@ -197,10 +197,9 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init, memcpy(&op_args.args, coll_args, sizeof(ucc_coll_args_t)); op_args.team = team; op_args.args.flags = 0; - UCC_COPY_PARAM_BY_FIELD(&op_args.args, coll_args, UCC_COLL_ARGS_FIELD_FLAGS, - flags); + UCC_COPY_PARAM_BY_FIELD(&op_args.args, coll_args, + UCC_COLL_ARGS_FIELD_FLAGS, flags); ucc_coll_task_init(task, &op_args, NULL); - *request = &task->super; goto print_trace; } } @@ -278,9 +277,9 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init, task->seq_num = team->seq_num++; ucc_assert(task->super.status == UCC_OPERATION_INITIALIZED); - *request = &task->super; print_trace: + *request = &task->super; if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DIAG) { char coll_str[256]; ucc_coll_str(task, coll_str, sizeof(coll_str), @@ -329,6 +328,7 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_post, (request), ucc_status_t status; if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DEBUG) { + /* team is NULL if task is a dummy task, e.g. collective of zero size */ if (task->team) { ucc_rank_t rank = task->team->params.team->rank; if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) { diff --git a/src/schedule/ucc_schedule.c b/src/schedule/ucc_schedule.c index bc4702a999..777b9c6b39 100644 --- a/src/schedule/ucc_schedule.c +++ b/src/schedule/ucc_schedule.c @@ -98,7 +98,7 @@ ucc_status_t ucc_dummy_finalize(ucc_coll_task_t *task) void ucc_dummy_progress(ucc_coll_task_t *task) { /* this function should never be called */ - ucc_assert(0); + ucc_assert_always(0); } ucc_status_t ucc_coll_task_init(ucc_coll_task_t *task,