Skip to content

Commit

Permalink
TL/UCP: bcast active_set size greater than 2
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarka committed Mar 11, 2024
1 parent 9e0d759 commit a7e68af
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 97 deletions.
49 changes: 33 additions & 16 deletions src/components/tl/ucp/bcast/bcast_dbt.c
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t rank = task->subset.myrank;
ucc_dbt_single_tree_t t1 = task->bcast_dbt.t1;
ucc_dbt_single_tree_t t2 = task->bcast_dbt.t2;
void *buffer = args->src.info.buffer;
Expand All @@ -93,14 +93,19 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)

if (rank != t1.root && rank != coll_root) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype,
t1.parent, team, task, cb[0],
ucc_ep_map_eval(task->subset.map,
t1.parent),
team, task, cb[0],
(void *)task),
task, out);
}

if (rank != t2.root && rank != coll_root) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, t2.parent, team,
data_size_t2, mtype,
ucc_ep_map_eval(task->subset.map,
t2.parent),
team,
task, cb[1], (void *)task),
task, out);
}
Expand All @@ -114,7 +119,10 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
if ((t1.children[i] != UCC_RANK_INVALID) &&
(t1.children[i] != coll_root)) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype,
t1.children[i], team, task),
ucc_ep_map_eval(
task->subset.map,
t1.children[i]),
team, task),
task, out);
}
}
Expand All @@ -133,7 +141,10 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(PTR_OFFSET(buffer,
data_size_t1),
data_size_t2, mtype,
t2.children[i], team, task),
ucc_ep_map_eval(
task->subset.map,
t2.children[i]),
team, task),
task, out);
}
}
Expand Down Expand Up @@ -161,7 +172,7 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_status_t status = UCC_OK;
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t rank = task->subset.myrank;
void *buffer = args->src.info.buffer;
ucc_memory_type_t mtype = args->src.info.mem_type;
ucc_datatype_t dt = args->src.info.datatype;
Expand All @@ -181,23 +192,28 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

if (rank == coll_root && coll_root != t1_root) {
status = ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype, t1_root, team,
task);
status = ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype,
ucc_ep_map_eval(task->subset.map, t1_root),
team, task);
if (UCC_OK != status) {
return status;
}
}

if (rank == coll_root && coll_root != t2_root) {
status = ucc_tl_ucp_send_nb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, t2_root, team, task);
data_size_t2, mtype,
ucc_ep_map_eval(task->subset.map, t2_root),
team, task);
if (UCC_OK != status) {
return status;
}
}

if (rank != coll_root && rank == t1_root) {
status = ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype, coll_root,
status = ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype,
ucc_ep_map_eval(task->subset.map,
coll_root),
team, task, cb[0], (void *)task);
if (UCC_OK != status) {
return status;
Expand All @@ -206,8 +222,10 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)

if (rank != coll_root && rank == t2_root) {
status = ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, coll_root, team, task,
cb[1], (void *)task);
data_size_t2, mtype,
ucc_ep_map_eval(task->subset.map,
coll_root),
team, task, cb[1], (void *)task);
if (UCC_OK != status) {
return status;
}
Expand All @@ -227,7 +245,6 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_init(
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team;
ucc_tl_ucp_task_t *task;
ucc_rank_t rank, size;

Expand All @@ -236,9 +253,9 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_init(
task->super.progress = ucc_tl_ucp_bcast_dbt_progress;
task->super.finalize = ucc_tl_ucp_bcast_dbt_finalize;
task->n_polls = ucc_max(1, task->n_polls);
tl_team = TASK_TEAM(task);
rank = UCC_TL_TEAM_RANK(tl_team);
size = UCC_TL_TEAM_SIZE(tl_team);
rank = task->subset.myrank;
size = (ucc_rank_t)task->subset.map.ep_num;

ucc_dbt_build_trees(rank, size, &task->bcast_dbt.t1,
&task->bcast_dbt.t2);

Expand Down
8 changes: 2 additions & 6 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,20 +340,16 @@ ucc_tl_ucp_init_task(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team)
ucc_coll_task_init(&task->super, coll_args, team);

if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
task->tagged.tag = (coll_args->mask & UCC_COLL_ARGS_FIELD_TAG)
task->tagged.tag = (coll_args->args.mask & UCC_COLL_ARGS_FIELD_TAG)
? coll_args->args.tag : UCC_TL_UCP_ACTIVE_SET_TAG;
task->flags |= UCC_TL_UCP_TASK_FLAG_SUBSET;
task->subset.map = ucc_active_set_to_ep_map(&coll_args->args);
task->subset.myrank =
ucc_ep_map_local_rank(task->subset.map,
UCC_TL_TEAM_RANK(tl_team));
ucc_assert(coll_args->args.coll_type == UCC_COLL_TYPE_BCAST);
/* root value in args corresponds to the original team ranks,
need to convert to subset local value */
TASK_ARGS(task).root = ucc_ep_map_local_rank(task->subset.map,
coll_args->args.root);
} else {
if (coll_args->mask & UCC_COLL_ARGS_FIELD_TAG) {
if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_TAG) {
task->tagged.tag = coll_args->args.tag;
} else {
tl_team->seq_num = (tl_team->seq_num + 1) % UCC_TL_UCP_MAX_COLL_TAG;
Expand Down
5 changes: 2 additions & 3 deletions src/core/ucc_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,8 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
}

if (UCC_COLL_ARGS_ACTIVE_SET(coll_args) &&
((UCC_COLL_TYPE_BCAST != coll_args->coll_type) ||
coll_args->active_set.size != 2)) {
ucc_warn("Active Sets are only supported for bcast and set size = 2");
(UCC_COLL_TYPE_BCAST != coll_args->coll_type)) {
ucc_warn("Active Sets are only supported for bcast");
return UCC_ERR_NOT_SUPPORTED;
}

Expand Down
Loading

0 comments on commit a7e68af

Please sign in to comment.