Skip to content

Commit

Permalink
CORE: Implement weak asymmetric mem with gtests
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Sarkauskas authored and nsarka committed Jul 12, 2024
1 parent 7b1a71a commit df217dc
Show file tree
Hide file tree
Showing 10 changed files with 538 additions and 227 deletions.
2 changes: 2 additions & 0 deletions src/coll_score/ucc_coll_score_map.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ static ucc_status_t ucc_coll_score_map_lookup(ucc_score_map_t *map,
mt = UCC_MEMORY_TYPE_HOST;
}
if (!ucc_coll_args_is_mem_symmetric(&bargs->args, map->team_rank)) {
ucc_debug("Mem was asymmetric even though asymmetric memory should "
"be handled prior to finding coll score");
return UCC_ERR_INVALID_PARAM;
}
if (msgsize == UCC_MSG_SIZE_INVALID || msgsize == UCC_MSG_SIZE_ASYMMETRIC) {
Expand Down
4 changes: 2 additions & 2 deletions src/components/base/ucc_base_iface.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ enum {

typedef struct ucc_asymmetric_save_info {
union {
ucc_coll_buffer_info_t info; /*!< Buffer info for the collective */
ucc_coll_buffer_info_v_t info_v; /*!< Buffer info for the collective */
ucc_coll_buffer_info_t info;
ucc_coll_buffer_info_v_t info_v;
} old_asymmetric_buffer;
ucc_mc_buffer_header_t *scratch;
} ucc_asymmetric_save_info_t;
Expand Down
5 changes: 1 addition & 4 deletions src/core/ucc_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,6 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
ucc_ee_type_t coll_ee_type;
size_t coll_size;

printf("nick in ucc_collective_init: dst buffer=%p\n", coll_args->dst.info.buffer);

if (ucc_unlikely(team->state != UCC_TEAM_ACTIVE)) {
ucc_error("team %p is used before team create is completed", team);
return UCC_ERR_INVALID_PARAM;
Expand Down Expand Up @@ -250,11 +248,10 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
goto free_scratch;
}

coll_mem_type = ucc_coll_args_mem_type(&op_args.args, team->rank);

task->flags |= UCC_COLL_TASK_FLAG_TOP_LEVEL;
if (task->flags & UCC_COLL_TASK_FLAG_EXECUTOR) {
task->flags |= UCC_COLL_TASK_FLAG_EXECUTOR_STOP;
coll_mem_type = ucc_coll_args_mem_type(&op_args.args, team->rank);
switch(coll_mem_type) {
case UCC_MEMORY_TYPE_CUDA:
case UCC_MEMORY_TYPE_CUDA_MANAGED:
Expand Down
67 changes: 53 additions & 14 deletions src/schedule/ucc_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,53 @@ ucc_status_t ucc_dependency_handler(ucc_coll_task_t *parent,
ucc_status_t ucc_triggered_post(ucc_ee_h ee, ucc_ev_t *ev,
ucc_coll_task_t *task);

static inline
ucc_status_t ucc_copy_asymmetric_buffer_out(ucc_asymmetric_save_info_t *save)
{
ucc_status_t status = UCC_OK;
status = ucc_mc_memcpy(save->old_asymmetric_buffer.info.buffer,
save->scratch->addr,
ucc_dt_size(save->old_asymmetric_buffer.info.datatype) *
save->old_asymmetric_buffer.info.count,
save->old_asymmetric_buffer.info.mem_type,
save->scratch->mt);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("error copying back to old asymmetric buffer: %s",
ucc_status_string(status));
}
status = ucc_mc_free(save->scratch);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("error freeing scratch asymmetric buffer: %s",
ucc_status_string(status));
}
return status;
}

static inline
ucc_status_t ucc_copy_asymmetric_buffer_v_out(ucc_coll_task_t *task)
{
ucc_status_t status = UCC_OK;
ucc_coll_args_t *coll_args = &task->bargs.args;
ucc_asymmetric_save_info_t *save = &task->bargs.asymmetric_save_info;
ucc_rank_t size = task->team->params.size;
status = ucc_mc_memcpy(save->old_asymmetric_buffer.info_v.buffer,
save->scratch->addr,
ucc_coll_args_get_total_count(coll_args,
coll_args->dst.info_v.counts, size),
save->old_asymmetric_buffer.info_v.mem_type,
save->scratch->mt);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("error copying back to old asymmetric buffer: %s",
ucc_status_string(status));
}
status = ucc_mc_free(save->scratch);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("error freeing scratch asymmetric buffer: %s",
ucc_status_string(status));
}
return status;
}

static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)
{
ucc_status_t status = task->status;
Expand All @@ -188,20 +235,12 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)
if (ucc_likely(status == UCC_OK)) {
ucc_asymmetric_save_info_t *save = &task->bargs.asymmetric_save_info;
if (save->scratch != NULL) {
status = ucc_mc_memcpy(save->old_asymmetric_buffer.info.buffer,
save->scratch->addr,
ucc_dt_size(save->old_asymmetric_buffer.info.datatype) *
save->old_asymmetric_buffer.info.count,
save->old_asymmetric_buffer.info.mem_type,
save->scratch->mt);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("error copying back to old asymmetric buffer: %s",
ucc_status_string(status));
}
status = ucc_mc_free(save->scratch);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("error freeing scratch asymmetric buffer: %s",
ucc_status_string(status));
status = (task->bargs.args.coll_type == UCC_COLL_TYPE_GATHERV) ?
ucc_copy_asymmetric_buffer_v_out(task) :
ucc_copy_asymmetric_buffer_out(save);
if (status != UCC_OK) {
ucc_error("failure copying out asymmetric buffer: %s",
ucc_status_string(status));
}
save->scratch = NULL;
}
Expand Down
114 changes: 28 additions & 86 deletions src/utils/ucc_coll_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,119 +95,61 @@ ucc_coll_args_is_mem_symmetric(const ucc_coll_args_t *args,
}


/* If the src/dst buffers are asymmetric, one of them needs to have a new
allocation */
/* If this is the root and the src/dst buffers are asymmetric, the dst needs
to have a new allocation to make the mem types match. On task completion,
copy the result back into the old dst */
ucc_status_t
ucc_coll_args_update_asymmetric_buffer(ucc_coll_args_t *args,
ucc_team_h team,
ucc_asymmetric_save_info_t *save_info)
{
//ucc_rank_t root = args->root;
//ucc_rank_t rank = team->rank;
ucc_status_t status = UCC_OK;

if (UCC_IS_INPLACE(*args)) {
return UCC_ERR_INVALID_PARAM;
}
switch (args->coll_type) {
case UCC_COLL_TYPE_BARRIER:
case UCC_COLL_TYPE_BCAST:
case UCC_COLL_TYPE_FANIN:
case UCC_COLL_TYPE_FANOUT:
return UCC_ERR_INVALID_PARAM;
case UCC_COLL_TYPE_ALLTOALL:
case UCC_COLL_TYPE_ALLREDUCE:
case UCC_COLL_TYPE_ALLGATHER:
case UCC_COLL_TYPE_REDUCE_SCATTER:
/*{
if (args->dst.info.mem_type == mem_type) {
*old_buffer = args->src.info.buffer;
status = ucc_mc_alloc(args->src.info.buffer, ucc_dt_size(args->src.info.datatype) * args->src.info.count, mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement memory for asymmetric buffer");
return status;
}
} else {
*old_buffer = args->dst.info.buffer;
status = ucc_mc_alloc(args->dst.info.buffer, ucc_dt_size(args->dst.info.datatype) * args->dst.info.count, mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement memory for asymmetric buffer");
return status;
}
}
}*/
case UCC_COLL_TYPE_ALLGATHERV:
case UCC_COLL_TYPE_REDUCE_SCATTERV:
/*{
if (args->dst.info_v.mem_type == mem_type) {
*old_buffer = args->src.info.buffer;
status = ucc_mc_alloc(args->src.info.buffer, ucc_dt_size(args->src.info.datatype) * args->src.info.count, mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement memory for asymmetric buffer");
return status;
}
} else {
*old_buffer = args->dst.info_v.buffer;
status = ucc_mc_alloc(args->dst.info_v.buffer, ucc_dt_size(args->dst.info_v.datatype) * args->dst.info_v.counts[rank], mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement memory for asymmetric buffer");
return status;
}
}
}*/
case UCC_COLL_TYPE_ALLTOALLV:
/*{
ucc_count_t sum_count = 0;
ucc_rank_t i;
if (args->dst.info_v.mem_type == mem_type) {
*old_buffer = args->src.info_v.buffer;
for(i = 0; i < team->size; i++) {
sum_count += args->src.info_v.counts[i];
}
status = ucc_mc_alloc(args->src.info_v.buffer, ucc_dt_size(args->src.info_v.datatype) * sum_count, mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement memory for asymmetric buffer");
return status;
}
} else {
*old_buffer = args->dst.info_v.buffer;
for(i = 0; i < team->size; i++) {
sum_count += args->dst.info_v.counts[i];
}
status = ucc_mc_alloc(args->dst.info_v.buffer, ucc_dt_size(args->dst.info_v.datatype) * sum_count, mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement memory for asymmetric buffer");
return status;
}
}
}*/
case UCC_COLL_TYPE_REDUCE:
case UCC_COLL_TYPE_GATHER:
case UCC_COLL_TYPE_SCATTER:
case UCC_COLL_TYPE_SCATTERV:
{
ucc_memory_type_t mem_type = args->src.info.mem_type;
if (args->coll_type == UCC_COLL_TYPE_SCATTERV) {
mem_type = args->src.info_v.mem_type;
}
memcpy(&save_info->old_asymmetric_buffer.info,
&args->dst.info, sizeof(ucc_coll_buffer_info_t));
status = ucc_mc_alloc(&save_info->scratch,
ucc_dt_size(args->dst.info.datatype) *
args->dst.info.count,
args->src.info.mem_type);
mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement "
"memory for asymmetric buffer");
return status;
}
args->dst.info.buffer = save_info->scratch->addr;
args->dst.info.mem_type = args->src.info.mem_type;
args->dst.info.mem_type = mem_type;
return UCC_OK;
}
case UCC_COLL_TYPE_GATHERV:
{
memcpy(&save_info->old_asymmetric_buffer.info_v,
&args->dst.info_v, sizeof(ucc_coll_buffer_info_v_t));
status = ucc_mc_alloc(&save_info->scratch,
ucc_coll_args_get_total_count(args,
args->dst.info_v.counts, team->size),
args->src.info.mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement "
"memory for asymmetric buffer");
return status;
}
args->dst.info_v.buffer = save_info->scratch->addr;
args->dst.info_v.mem_type = args->src.info.mem_type;
return UCC_OK;
}
case UCC_COLL_TYPE_GATHERV: /*
return (root != rank ? NULL : (
args->dst.info_v.mem_type == mem_type ? args->src.info.buffer : args->dst.info_v.buffer;
)) */
case UCC_COLL_TYPE_SCATTERV: /*
return (root != rank ? NULL : (
args->dst.info.mem_type == mem_type ? args->src.info_v.buffer : args->dst.info.buffer;
))*/
default:
break;
}
Expand Down Expand Up @@ -290,8 +232,8 @@ ucc_memory_type_t ucc_coll_args_mem_type(const ucc_coll_args_t *args,
return UCC_MEMORY_TYPE_NOT_APPLY;
case UCC_COLL_TYPE_BCAST:
return args->src.info.mem_type;
case UCC_COLL_TYPE_ALLREDUCE:
case UCC_COLL_TYPE_ALLTOALL:
case UCC_COLL_TYPE_ALLREDUCE:
case UCC_COLL_TYPE_ALLGATHER:
case UCC_COLL_TYPE_REDUCE_SCATTER:
return args->dst.info.mem_type;
Expand Down
3 changes: 2 additions & 1 deletion test/gtest/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ gtest_SOURCES = \
coll_score/test_score.cc \
coll_score/test_score_str.cc \
coll_score/test_score_update.cc \
active_set/test_active_set.cc
active_set/test_active_set.cc \
asym_mem/test_asymmetric_memory.cc

if TL_MLX5_ENABLED
gtest_SOURCES += tl/mlx5/test_tl_mlx5.cc \
Expand Down
Loading

0 comments on commit df217dc

Please sign in to comment.