Skip to content

Commit

Permalink
TL/UCP: odd count fix
Browse files Browse the repository at this point in the history
  • Loading branch information
shimmybalsam committed Oct 5, 2023
1 parent fba15f5 commit af1c04f
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 52 deletions.
14 changes: 7 additions & 7 deletions src/coll_patterns/double_binary_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ enum {
};

typedef struct ucc_dbt_single_tree {
ucc_rank_t rank;
ucc_rank_t size;
ucc_rank_t root;
ucc_rank_t parent;
ucc_rank_t children[2];
int height;
int recv;
ucc_rank_t rank;
ucc_rank_t size;
ucc_rank_t root;
ucc_rank_t parent;
ucc_rank_t children[2];
int height;
int recv;
} ucc_dbt_single_tree_t;

static inline ucc_rank_t get_root(ucc_rank_t size)
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/ucp/bcast/bcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ ucc_base_coll_alg_info_t
{.id = UCC_TL_UCP_BCAST_ALG_DBT,
.name = "dbt",
.desc = "bcast over double binary tree where a leaf in one tree "
"will be intermediate in other (optimized for latency)"},
"will be intermediate in other (optimized for BW)"},
[UCC_TL_UCP_BCAST_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

Expand Down
98 changes: 54 additions & 44 deletions src/components/tl/ucp/bcast/bcast_dbt.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,33 +69,38 @@ static void recv_completion_2(void *request, ucs_status_t status,

void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_dbt_single_tree_t t1 = task->bcast_dbt.t1;
ucc_dbt_single_tree_t t2 = task->bcast_dbt.t2;
void *buffer = TASK_ARGS(task).src.info.buffer;
ucc_memory_type_t mtype = TASK_ARGS(task).src.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).src.info.datatype;
size_t count = TASK_ARGS(task).src.info.count;
size_t data_size = count * ucc_dt_size(dt) / 2;
ucc_rank_t coll_root = (ucc_rank_t)TASK_ARGS(task).root;
ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1,
recv_completion_2};
ucc_tl_ucp_task_t *task =
ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_dbt_single_tree_t t1 = task->bcast_dbt.t1;
ucc_dbt_single_tree_t t2 = task->bcast_dbt.t2;
void *buffer = args->src.info.buffer;
ucc_memory_type_t mtype = args->src.info.mem_type;
ucc_datatype_t dt = args->src.info.datatype;
size_t count = args->src.info.count;
size_t count_t1 = (count % 2) ? count / 2 + 1
: count / 2;
size_t data_size_t1 = count_t1 * ucc_dt_size(dt);
size_t data_size_t2 = count / 2 * ucc_dt_size(dt);
ucc_rank_t coll_root = (ucc_rank_t)args->root;
ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1,
recv_completion_2};
uint32_t i;

UCC_BCAST_DBT_GOTO_STATE(task->bcast_dbt.state);

if (rank != t1.root && rank != coll_root) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(buffer, data_size, mtype, t1.parent,
team, task, cb[0], (void *)task),
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype,
t1.parent, team, task, cb[0],
(void *)task),
task, out);
}

if (rank != t2.root && rank != coll_root) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size),
data_size, mtype, t2.parent, team,
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, t2.parent, team,
task, cb[1], (void *)task),
task, out);
}
Expand All @@ -105,7 +110,7 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
if ((coll_root == rank) || (task->bcast_dbt.t1.recv > 0)) {
for (i = 0; i < 2; i++) {
if (t1.children[i] != -1 && t1.children[i] != coll_root) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buffer, data_size, mtype,
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype,
t1.children[i], team, task),
task, out);
}
Expand All @@ -119,8 +124,9 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
if ((coll_root == rank) || (task->bcast_dbt.t2.recv > 0)) {
for (i = 0; i < 2; i++) {
if (t2.children[i] != -1 && t2.children[i] != coll_root) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(PTR_OFFSET(buffer, data_size),
data_size, mtype,
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(PTR_OFFSET(buffer,
data_size_t1),
data_size_t2, mtype,
t2.children[i], team, task),
task, out);
}
Expand All @@ -144,54 +150,58 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)

ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_status_t status = UCC_OK;
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
void *buffer = TASK_ARGS(task).src.info.buffer;
ucc_memory_type_t mtype = TASK_ARGS(task).src.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).src.info.datatype;
size_t count = TASK_ARGS(task).src.info.count;
size_t data_size = count * ucc_dt_size(dt) / 2;
ucc_rank_t coll_root = (ucc_rank_t)TASK_ARGS(task).root;
ucc_rank_t t1_root = task->bcast_dbt.t1.root;
ucc_rank_t t2_root = task->bcast_dbt.t2.root;
ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1,
recv_completion_2};
ucc_tl_ucp_task_t *task =
ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_status_t status = UCC_OK;
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
void *buffer = args->src.info.buffer;
ucc_memory_type_t mtype = args->src.info.mem_type;
ucc_datatype_t dt = args->src.info.datatype;
size_t count = args->src.info.count;
size_t count_t1 = (count % 2) ? count / 2 + 1
: count / 2;
size_t data_size_t1 = count_t1 * ucc_dt_size(dt);
size_t data_size_t2 = count / 2 * ucc_dt_size(dt);
ucc_rank_t coll_root = (ucc_rank_t)args->root;
ucc_rank_t t1_root = task->bcast_dbt.t1.root;
ucc_rank_t t2_root = task->bcast_dbt.t2.root;
ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1,
recv_completion_2};

task->bcast_dbt.t1.recv = 0;
task->bcast_dbt.t2.recv = 0;
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

if (rank == coll_root && coll_root != t1_root) {
status = ucc_tl_ucp_send_nb(buffer, data_size, mtype, t1_root, team,
status = ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype, t1_root, team,
task);
if (UCC_OK != status) {
return status;
}
}

if (rank == coll_root && coll_root != t2_root) {
status = ucc_tl_ucp_send_nb(PTR_OFFSET(buffer, data_size), data_size,
mtype, t2_root, team, task);
status = ucc_tl_ucp_send_nb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, t2_root, team, task);
if (UCC_OK != status) {
return status;
}
}

if (rank != coll_root && rank == t1_root) {
status = ucc_tl_ucp_recv_cb(buffer, data_size, mtype, coll_root, team,
task, cb[0], (void *)task);
status = ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype, coll_root,
team, task, cb[0], (void *)task);
if (UCC_OK != status) {
return status;
}
}

if (rank != coll_root && rank == t2_root) {
status = ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size), data_size,
mtype, coll_root, team, task, cb[1],
(void *)task);
status = ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, coll_root, team, task,
cb[1], (void *)task);
if (UCC_OK != status) {
return status;
}
Expand Down

0 comments on commit af1c04f

Please sign in to comment.