Skip to content

Commit

Permalink
Don't memcpy if src == dst
Browse files: browse the repository at this point in the history
  • Loading branch information
nsarkauskas committed Dec 5, 2024
1 parent 8672c6d commit e87d7c8
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions src/components/cl/hier/allgatherv/unpack.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_finalize(ucc_coll_task_t *task)
void ucc_cl_hier_allgatherv_unpack_progress(ucc_coll_task_t *task)
{
ucc_schedule_t *schedule = ucc_derived_of(task, ucc_schedule_t);
ucc_cl_hier_team_t *cl_team = ucc_derived_of(task->team,
ucc_cl_hier_team_t);
ucc_rank_t team_size = UCC_CL_TEAM_SIZE(cl_team);
ucc_cl_hier_schedule_t *cl_schedule = ucc_derived_of(schedule,
ucc_cl_hier_schedule_t);
ucc_ee_executor_task_t **tasks = cl_schedule->scratch->addr;
ucc_rank_t *n_tasks = cl_schedule->scratch->addr;
ucc_ee_executor_task_t **tasks = PTR_OFFSET(
cl_schedule->scratch->addr,
sizeof(ucc_rank_t));
ucc_status_t st = UCC_OK;
ucc_rank_t i;

for (i = 0; i < team_size; i++) {
for (i = 0; i < *n_tasks; i++) {
ucc_ee_executor_task_t *etask = tasks[i];
if (etask != NULL) {
st = ucc_ee_executor_task_test(etask);
Expand Down Expand Up @@ -61,8 +61,10 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_start(ucc_coll_task_t *task)
ucc_ee_executor_task_args_t eargs = {0};
ucc_cl_hier_schedule_t *cl_schedule = ucc_derived_of(schedule,
ucc_cl_hier_schedule_t);
ucc_ee_executor_task_t **tasks = cl_schedule->scratch->addr;
ucc_rank_t n_tasks = 0;
ucc_rank_t *n_tasks = cl_schedule->scratch->addr;
ucc_ee_executor_task_t **tasks = PTR_OFFSET(
cl_schedule->scratch->addr,
sizeof(ucc_rank_t));
size_t dt_size = ucc_dt_size(
args->dst.info_v.datatype);
ucc_ee_executor_t *exec;
Expand All @@ -79,6 +81,7 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_start(ucc_coll_task_t *task)
disp_counter = ucc_coll_args_get_total_count(args,
args->dst.info_v.counts,
team_size);
*n_tasks = 0;

for (i = team_size - 1; i >= 0; i--) {
this_rank_count = ucc_coll_args_get_count(args, args->dst.info_v.counts,
Expand All @@ -92,10 +95,12 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_start(ucc_coll_task_t *task)
args, args->dst.info_v.displacements, i) *
dt_size);
eargs.copy.len = this_rank_count * dt_size;
UCC_CHECK_GOTO(
ucc_ee_executor_task_post(exec, &eargs, &tasks[n_tasks]),
out, status);
n_tasks++;
if (eargs.copy.src != eargs.copy.dst) {
UCC_CHECK_GOTO(
ucc_ee_executor_task_post(exec, &eargs, &tasks[*n_tasks]),
out, status);
(*n_tasks)++;
}
}

schedule->super.status = UCC_INPROGRESS;
Expand Down Expand Up @@ -129,7 +134,8 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_init(ucc_base_coll_args_t *coll_args,
UCC_CHECK_GOTO(
ucc_schedule_init(schedule, coll_args, team), free_schedule, status);

scratch_size = team_size * sizeof(ucc_ee_executor_task_t*);
/* Holds the n_tasks counter followed by room for up to team_size ucc_ee_executor_task_t pointers */
scratch_size = sizeof(ucc_rank_t) + team_size * sizeof(ucc_ee_executor_task_t*);
UCC_CHECK_GOTO(
ucc_mc_alloc(&cl_schedule->scratch, scratch_size, UCC_MEMORY_TYPE_HOST),
free_schedule, status);
Expand All @@ -146,4 +152,4 @@ ucc_status_t ucc_cl_hier_allgatherv_unpack_init(ucc_base_coll_args_t *coll_args,
free_schedule:
ucc_cl_hier_put_schedule(schedule);
return status;
}
}

0 comments on commit e87d7c8

Please sign in to comment.