Skip to content

Commit

Permalink
CORE: Asymmetric copy-in for scatter(v)
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarka committed Aug 22, 2024
1 parent f201563 commit cbf46d8
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 43 deletions.
18 changes: 18 additions & 0 deletions src/core/ucc_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,17 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_post, (request),
}
}

if (task->bargs.asymmetric_save_info.scratch != NULL &&
(task->bargs.args.coll_type == UCC_COLL_TYPE_SCATTER ||
task->bargs.args.coll_type == UCC_COLL_TYPE_SCATTERV)) {
status = ucc_copy_asymmetric_buffer(task);
if (status != UCC_OK) {
ucc_error("failure copying in asymmetric buffer: %s",
ucc_status_string(status));
return status;
}
}

COLL_POST_STATUS_CHECK(task);
if (UCC_COLL_TIMEOUT_REQUIRED(task)) {
task->start_time = ucc_get_time();
Expand Down Expand Up @@ -416,6 +427,13 @@ ucc_status_t ucc_collective_finalize_internal(ucc_coll_task_t *task)
return UCC_ERR_INVALID_PARAM;
}

if (task->bargs.asymmetric_save_info.scratch) {
st = ucc_coll_args_free_asymmetric_buffer(task);
if (ucc_unlikely(st != UCC_OK)) {
ucc_error("error freeing asymmetric buf: %s", ucc_status_string(st));
}
}

if (task->executor) {
st = ucc_ee_executor_finalize(task->executor);
if (ucc_unlikely(st != UCC_OK)) {
Expand Down
42 changes: 4 additions & 38 deletions src/schedule/ucc_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,36 +166,6 @@ ucc_status_t ucc_dependency_handler(ucc_coll_task_t *parent,
ucc_status_t ucc_triggered_post(ucc_ee_h ee, ucc_ev_t *ev,
ucc_coll_task_t *task);

/*
 * Copy the contents of the asymmetric scratch buffer back into the buffer
 * description that was saved in task->bargs.asymmetric_save_info.
 *
 * For GATHERV the saved info_v descriptor is the destination and the length
 * is whatever ucc_coll_args_get_total_count() yields for dst.info_v.counts
 * over the team size. For every other collective type the saved info
 * descriptor is used and the length is datatype size times count.
 *
 * NOTE(review): assumes save->scratch is non-NULL; it is dereferenced
 * unconditionally, so the caller has to check it first.
 *
 * Returns the status of ucc_mc_memcpy(); a failure is also logged here.
 */
static inline
ucc_status_t ucc_copy_asymmetric_buffer_out(ucc_coll_task_t *task)
{
    ucc_buffer_info_asymmetric_memtype_t *saved =
        &task->bargs.asymmetric_save_info;
    ucc_coll_args_t *args      = &task->bargs.args;
    ucc_rank_t       team_size = task->team->params.size;
    ucc_status_t     rc;

    if (args->coll_type != UCC_COLL_TYPE_GATHERV) {
        /* Fixed-count case: length comes from the saved descriptor itself */
        rc = ucc_mc_memcpy(
            saved->old_asymmetric_buffer.info.buffer,
            saved->scratch->addr,
            ucc_dt_size(saved->old_asymmetric_buffer.info.datatype) *
                saved->old_asymmetric_buffer.info.count,
            saved->old_asymmetric_buffer.info.mem_type,
            saved->scratch->mt);
    } else {
        /* Vector case: length is derived from the current dst counts */
        rc = ucc_mc_memcpy(
            saved->old_asymmetric_buffer.info_v.buffer,
            saved->scratch->addr,
            ucc_coll_args_get_total_count(args, args->dst.info_v.counts,
                                          team_size),
            saved->old_asymmetric_buffer.info_v.mem_type,
            saved->scratch->mt);
    }

    if (ucc_unlikely(rc != UCC_OK)) {
        ucc_error("error copying back to old asymmetric buffer: %s",
                  ucc_status_string(rc));
    }
    return rc;
}

static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)
{
ucc_status_t status = task->status;
Expand All @@ -217,18 +187,14 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)

if (ucc_likely(status == UCC_OK)) {
ucc_buffer_info_asymmetric_memtype_t *save = &task->bargs.asymmetric_save_info;
if (save->scratch != NULL) {
status = ucc_copy_asymmetric_buffer_out(task);
if (save->scratch &&
task->bargs.args.coll_type != UCC_COLL_TYPE_SCATTERV &&
task->bargs.args.coll_type != UCC_COLL_TYPE_SCATTER) {
status = ucc_copy_asymmetric_buffer(task);
if (status != UCC_OK) {
ucc_error("failure copying out asymmetric buffer: %s",
ucc_status_string(status));
}
status = ucc_mc_free(save->scratch);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("error freeing scratch asymmetric buffer: %s",
ucc_status_string(status));
}
save->scratch = NULL;
}
status = ucc_event_manager_notify(task, UCC_EVENT_COMPLETED);
} else {
Expand Down
114 changes: 109 additions & 5 deletions src/utils/ucc_coll_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ ucc_coll_args_is_mem_symmetric(const ucc_coll_args_t *args,
}


/* If this is the root and the src/dst buffers are asymmetric, the dst needs
to have a new allocation to make the mem types match. On task completion,
copy the result back into the old dst */
/* If this is the root and the src/dst buffers are asymmetric, one buffer needs
to have a new allocation to make the mem types match. If that buffer was the
dst buffer, copy the result back into the old dst on task completion */
ucc_status_t
ucc_coll_args_init_asymmetric_buffer(ucc_coll_args_t *args,
ucc_team_h team,
Expand All @@ -111,8 +111,6 @@ ucc_coll_args_init_asymmetric_buffer(ucc_coll_args_t *args,
switch (args->coll_type) {
case UCC_COLL_TYPE_REDUCE:
case UCC_COLL_TYPE_GATHER:
case UCC_COLL_TYPE_SCATTER:
case UCC_COLL_TYPE_SCATTERV:
{
ucc_memory_type_t mem_type = args->src.info.mem_type;
if (args->coll_type == UCC_COLL_TYPE_SCATTERV) {
Expand Down Expand Up @@ -150,12 +148,118 @@ ucc_coll_args_init_asymmetric_buffer(ucc_coll_args_t *args,
args->dst.info_v.mem_type = args->src.info.mem_type;
return UCC_OK;
}
case UCC_COLL_TYPE_SCATTER:
{
ucc_memory_type_t mem_type = args->dst.info.mem_type;
memcpy(&save_info->old_asymmetric_buffer.info,
&args->src.info, sizeof(ucc_coll_buffer_info_t));
status = ucc_mc_alloc(&save_info->scratch,
ucc_dt_size(args->src.info.datatype) * args->src.info.count,
mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement "
"memory for asymmetric buffer");
return status;
}
args->src.info.buffer = save_info->scratch->addr;
args->src.info.mem_type = mem_type;
return UCC_OK;
}
case UCC_COLL_TYPE_SCATTERV:
{
ucc_memory_type_t mem_type = args->dst.info.mem_type;
memcpy(&save_info->old_asymmetric_buffer.info_v,
&args->src.info_v, sizeof(ucc_coll_buffer_info_v_t));
status = ucc_mc_alloc(&save_info->scratch,
ucc_coll_args_get_total_count(args,
args->src.info_v.counts, team->size),
mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement "
"memory for asymmetric buffer");
return status;
}
args->src.info_v.buffer = save_info->scratch->addr;
args->src.info_v.mem_type = mem_type;
return UCC_OK;
}
default:
break;
}
return UCC_ERR_INVALID_PARAM;
}

/*
 * Release the scratch buffer recorded in task->bargs.asymmetric_save_info
 * and reset the saved pointer to NULL.
 *
 * Returns:
 *   UCC_ERR_INVALID_PARAM  if the collective is in-place, or if there is no
 *                          scratch buffer to free (the latter is also logged);
 *   otherwise the status returned by ucc_mc_free().
 */
ucc_status_t
ucc_coll_args_free_asymmetric_buffer(ucc_coll_task_t *task)
{
    ucc_buffer_info_asymmetric_memtype_t *save =
        &task->bargs.asymmetric_save_info;
    ucc_status_t status;

    /* NOTE(review): in-place collectives are rejected here without touching
       the scratch pointer; presumably no scratch buffer is ever set up for
       them - confirm against the init path. */
    if (UCC_IS_INPLACE(task->bargs.args)) {
        return UCC_ERR_INVALID_PARAM;
    }

    if (save->scratch == NULL) {
        ucc_error("failure trying to free NULL asymmetric buffer");
        /* Bail out instead of handing the NULL header to ucc_mc_free(),
           which is not known to tolerate it. */
        return UCC_ERR_INVALID_PARAM;
    }

    status = ucc_mc_free(save->scratch);
    if (ucc_unlikely(status != UCC_OK)) {
        ucc_error("error freeing scratch asymmetric buffer: %s",
                  ucc_status_string(status));
    }
    /* Clear the pointer even on failure so it is never freed twice */
    save->scratch = NULL;

    return status;
}

/*
 * Move data between the saved original buffer and the asymmetric scratch
 * buffer held in task->bargs.asymmetric_save_info. The direction depends on
 * the collective type:
 *
 *   SCATTERV / SCATTER : copy in  (saved original -> scratch)
 *   GATHERV / others   : copy out (scratch -> saved original)
 *
 * The *V variants take the length from ucc_coll_args_get_total_count() over
 * the current src (SCATTERV) or dst (GATHERV) counts; the non-V variants use
 * datatype size times count from the saved descriptor.
 *
 * NOTE(review): save->scratch is dereferenced unconditionally, so callers
 * must ensure it is non-NULL.
 * NOTE(review): the *V lengths are passed to ucc_mc_memcpy() without a
 * datatype-size factor; confirm the units ucc_coll_args_get_total_count()
 * returns.
 *
 * Returns the status of ucc_mc_memcpy(); a failure is also logged here.
 */
ucc_status_t ucc_copy_asymmetric_buffer(ucc_coll_task_t *task)
{
    ucc_buffer_info_asymmetric_memtype_t *saved =
        &task->bargs.asymmetric_save_info;
    ucc_coll_args_t  *args      = &task->bargs.args;
    ucc_rank_t        team_size = task->team->params.size;
    void             *to;
    void             *from;
    size_t            nbytes;
    ucc_memory_type_t to_mt;
    ucc_memory_type_t from_mt;
    ucc_status_t      rc;

    switch (args->coll_type) {
    case UCC_COLL_TYPE_SCATTERV:
        /* copy in: original vector source -> scratch */
        to      = saved->scratch->addr;
        from    = saved->old_asymmetric_buffer.info_v.buffer;
        nbytes  = ucc_coll_args_get_total_count(args, args->src.info_v.counts,
                                                team_size);
        to_mt   = saved->scratch->mt;
        from_mt = saved->old_asymmetric_buffer.info_v.mem_type;
        break;
    case UCC_COLL_TYPE_SCATTER:
        /* copy in: original source -> scratch */
        to      = saved->scratch->addr;
        from    = saved->old_asymmetric_buffer.info.buffer;
        nbytes  = ucc_dt_size(saved->old_asymmetric_buffer.info.datatype) *
                  saved->old_asymmetric_buffer.info.count;
        to_mt   = saved->scratch->mt;
        from_mt = saved->old_asymmetric_buffer.info.mem_type;
        break;
    case UCC_COLL_TYPE_GATHERV:
        /* copy out: scratch -> original vector destination */
        to      = saved->old_asymmetric_buffer.info_v.buffer;
        from    = saved->scratch->addr;
        nbytes  = ucc_coll_args_get_total_count(args, args->dst.info_v.counts,
                                                team_size);
        to_mt   = saved->old_asymmetric_buffer.info_v.mem_type;
        from_mt = saved->scratch->mt;
        break;
    default:
        /* copy out: scratch -> original destination */
        to      = saved->old_asymmetric_buffer.info.buffer;
        from    = saved->scratch->addr;
        nbytes  = ucc_dt_size(saved->old_asymmetric_buffer.info.datatype) *
                  saved->old_asymmetric_buffer.info.count;
        to_mt   = saved->old_asymmetric_buffer.info.mem_type;
        from_mt = saved->scratch->mt;
        break;
    }

    rc = ucc_mc_memcpy(to, from, nbytes, to_mt, from_mt);
    if (ucc_unlikely(rc != UCC_OK)) {
        ucc_error("error copying back to old asymmetric buffer: %s",
                  ucc_status_string(rc));
    }
    return rc;
}

int ucc_coll_args_is_predefined_dt(const ucc_coll_args_t *args, ucc_rank_t rank)
{
switch (args->coll_type) {
Expand Down
5 changes: 5 additions & 0 deletions src/utils/ucc_coll_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,4 +347,9 @@ ucc_coll_args_init_asymmetric_buffer(ucc_coll_args_t *args,
ucc_team_h team,
ucc_buffer_info_asymmetric_memtype_t *save_info);

ucc_status_t
ucc_coll_args_free_asymmetric_buffer(ucc_coll_task_t *task);

ucc_status_t ucc_copy_asymmetric_buffer(ucc_coll_task_t *task);

#endif

0 comments on commit cbf46d8

Please sign in to comment.