From e87d7c872783d665e36d326700e48ea0cc16c686 Mon Sep 17 00:00:00 2001 From: nsarkauskas Date: Wed, 4 Dec 2024 17:14:54 -0800 Subject: [PATCH] dont memcpy if src==dst --- src/components/cl/hier/allgatherv/unpack.c | 32 +++++++++++++--------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/components/cl/hier/allgatherv/unpack.c b/src/components/cl/hier/allgatherv/unpack.c index 5ef5bb2028..490cb0ba3d 100644 --- a/src/components/cl/hier/allgatherv/unpack.c +++ b/src/components/cl/hier/allgatherv/unpack.c @@ -20,16 +20,16 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_finalize(ucc_coll_task_t *task) void ucc_cl_hier_allgatherv_unpack_progress(ucc_coll_task_t *task) { ucc_schedule_t *schedule = ucc_derived_of(task, ucc_schedule_t); - ucc_cl_hier_team_t *cl_team = ucc_derived_of(task->team, - ucc_cl_hier_team_t); - ucc_rank_t team_size = UCC_CL_TEAM_SIZE(cl_team); ucc_cl_hier_schedule_t *cl_schedule = ucc_derived_of(schedule, ucc_cl_hier_schedule_t); - ucc_ee_executor_task_t **tasks = cl_schedule->scratch->addr; + ucc_rank_t *n_tasks = cl_schedule->scratch->addr; + ucc_ee_executor_task_t **tasks = PTR_OFFSET( + cl_schedule->scratch->addr, + sizeof(ucc_rank_t)); ucc_status_t st = UCC_OK; ucc_rank_t i; - for (i = 0; i < team_size; i++) { + for (i = 0; i < *n_tasks; i++) { ucc_ee_executor_task_t *etask = tasks[i]; if (etask != NULL) { st = ucc_ee_executor_task_test(etask); @@ -61,8 +61,10 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_start(ucc_coll_task_t *task) ucc_ee_executor_task_args_t eargs = {0}; ucc_cl_hier_schedule_t *cl_schedule = ucc_derived_of(schedule, ucc_cl_hier_schedule_t); - ucc_ee_executor_task_t **tasks = cl_schedule->scratch->addr; - ucc_rank_t n_tasks = 0; + ucc_rank_t *n_tasks = cl_schedule->scratch->addr; + ucc_ee_executor_task_t **tasks = PTR_OFFSET( + cl_schedule->scratch->addr, + sizeof(ucc_rank_t)); size_t dt_size = ucc_dt_size( args->dst.info_v.datatype); ucc_ee_executor_t *exec; @@ -79,6 +81,7 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_start(ucc_coll_task_t *task) disp_counter = ucc_coll_args_get_total_count(args, args->dst.info_v.counts, team_size); + *n_tasks = 0; for (i = team_size - 1; i >= 0; i--) { this_rank_count = ucc_coll_args_get_count(args, args->dst.info_v.counts, @@ -92,10 +95,12 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_start(ucc_coll_task_t *task) args, args->dst.info_v.displacements, i) * dt_size); eargs.copy.len = this_rank_count * dt_size; - UCC_CHECK_GOTO( - ucc_ee_executor_task_post(exec, &eargs, &tasks[n_tasks]), - out, status); - n_tasks++; + if (eargs.copy.src != eargs.copy.dst) { + UCC_CHECK_GOTO( + ucc_ee_executor_task_post(exec, &eargs, &tasks[*n_tasks]), + out, status); + (*n_tasks)++; + } } schedule->super.status = UCC_INPROGRESS; @@ -129,7 +134,8 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_init(ucc_base_coll_args_t *coll_args, UCC_CHECK_GOTO( ucc_schedule_init(schedule, coll_args, team), free_schedule, status); - scratch_size = team_size * sizeof(ucc_ee_executor_task_t*); + /* Holds n_tasks and n_tasks # of ucc_ee_executor_task_t pointers */ + scratch_size = sizeof(ucc_rank_t) + team_size * sizeof(ucc_ee_executor_task_t*); UCC_CHECK_GOTO( ucc_mc_alloc(&cl_schedule->scratch, scratch_size, UCC_MEMORY_TYPE_HOST), free_schedule, status); @@ -146,4 +152,4 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_init(ucc_base_coll_args_t *coll_args, free_schedule: ucc_cl_hier_put_schedule(schedule); return status; -} +} \ No newline at end of file