/*
 * Patch summary. This comment sits before the first "diff --git" header and
 * is not part of any hunk.
 *
 * NOTE(review): in this copy the patch's line breaks appear to have been
 * collapsed onto a few long lines; the original line structure has to be
 * restored before the patch can be applied.
 *
 * 1. src/coll_patterns/double_binary_tree.h
 *    The removed and added field lists of ucc_dbt_single_tree_t are
 *    token-for-token identical as shown here (same types, names and order),
 *    so this hunk presumably only re-aligns whitespace - confirm against the
 *    original patch.
 *
 * 2. src/components/tl/ucp/bcast/bcast.c
 *    The description string of the "dbt" bcast algorithm changes its trailing
 *    remark from "(optimized for latency)" to "(optimized for BW)".
 *
 * 3. src/components/tl/ucp/bcast/bcast_dbt.c
 *    ucc_tl_ucp_bcast_dbt_progress() and ucc_tl_ucp_bcast_dbt_start() no
 *    longer use a single data_size = count * ucc_dt_size(dt) / 2 for both
 *    trees. The added lines compute instead:
 *
 *      count_t1     = count / 2, plus 1 when count is odd
 *      data_size_t1 = count_t1 * ucc_dt_size(dt)
 *      data_size_t2 = (count / 2) * ucc_dt_size(dt)
 *
 *    Tree 1 sends/receives data_size_t1 bytes starting at buffer; tree 2
 *    sends/receives data_size_t2 bytes starting at
 *    PTR_OFFSET(buffer, data_size_t1). Since count_t1 + count / 2 == count
 *    for both even and odd count, the two parts cover the whole buffer and
 *    the split falls on an element boundary. With the removed formula an odd
 *    count either dropped the last byte (count * dt_size odd) or split the
 *    buffer in the middle of an element (dt_size even).
 *
 *    Both functions also gain a local args = &TASK_ARGS(task) and read
 *    src.info.* and root through it instead of repeating TASK_ARGS(task).
 */
diff --git a/src/coll_patterns/double_binary_tree.h b/src/coll_patterns/double_binary_tree.h index 47e2aed520..0040321f6f 100644 --- a/src/coll_patterns/double_binary_tree.h +++ b/src/coll_patterns/double_binary_tree.h @@ -13,13 +13,13 @@ enum { }; typedef struct ucc_dbt_single_tree { - ucc_rank_t rank; - ucc_rank_t size; - ucc_rank_t root; - ucc_rank_t parent; - ucc_rank_t children[2]; - int height; - int recv; + ucc_rank_t rank; + ucc_rank_t size; + ucc_rank_t root; + ucc_rank_t parent; + ucc_rank_t children[2]; + int height; + int recv; } ucc_dbt_single_tree_t; static inline ucc_rank_t get_root(ucc_rank_t size) diff --git a/src/components/tl/ucp/bcast/bcast.c b/src/components/tl/ucp/bcast/bcast.c index 8ba6698b3a..b3b98e7779 100644 --- a/src/components/tl/ucp/bcast/bcast.c +++ b/src/components/tl/ucp/bcast/bcast.c @@ -23,7 +23,7 @@ ucc_base_coll_alg_info_t {.id = UCC_TL_UCP_BCAST_ALG_DBT, .name = "dbt", .desc = "bcast over double binary tree where a leaf in one tree " - "will be intermediate in other (optimized for latency)"}, + "will be intermediate in other (optimized for BW)"}, [UCC_TL_UCP_BCAST_ALG_LAST] = { .id = 0, .name = NULL, .desc = NULL}}; diff --git a/src/components/tl/ucp/bcast/bcast_dbt.c b/src/components/tl/ucp/bcast/bcast_dbt.c index 7abeb8fd3d..48b6fcca48 100644 --- a/src/components/tl/ucp/bcast/bcast_dbt.c +++ b/src/components/tl/ucp/bcast/bcast_dbt.c @@ -69,33 +69,38 @@ static void recv_completion_2(void *request, ucs_status_t status, void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, - ucc_tl_ucp_task_t); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t rank = UCC_TL_TEAM_RANK(team); - ucc_dbt_single_tree_t t1 = task->bcast_dbt.t1; - ucc_dbt_single_tree_t t2 = task->bcast_dbt.t2; - void *buffer = TASK_ARGS(task).src.info.buffer; - ucc_memory_type_t mtype = TASK_ARGS(task).src.info.mem_type; - ucc_datatype_t dt = TASK_ARGS(task).src.info.datatype; - size_t count = 
TASK_ARGS(task).src.info.count; - size_t data_size = count * ucc_dt_size(dt) / 2; - ucc_rank_t coll_root = (ucc_rank_t)TASK_ARGS(task).root; - ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1, - recv_completion_2}; + ucc_tl_ucp_task_t *task = + ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_rank_t rank = UCC_TL_TEAM_RANK(team); + ucc_dbt_single_tree_t t1 = task->bcast_dbt.t1; + ucc_dbt_single_tree_t t2 = task->bcast_dbt.t2; + void *buffer = args->src.info.buffer; + ucc_memory_type_t mtype = args->src.info.mem_type; + ucc_datatype_t dt = args->src.info.datatype; + size_t count = args->src.info.count; + size_t count_t1 = (count % 2) ? count / 2 + 1 + : count / 2; + size_t data_size_t1 = count_t1 * ucc_dt_size(dt); + size_t data_size_t2 = count / 2 * ucc_dt_size(dt); + ucc_rank_t coll_root = (ucc_rank_t)args->root; + ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1, + recv_completion_2}; uint32_t i; UCC_BCAST_DBT_GOTO_STATE(task->bcast_dbt.state); if (rank != t1.root && rank != coll_root) { - UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(buffer, data_size, mtype, t1.parent, - team, task, cb[0], (void *)task), + UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype, + t1.parent, team, task, cb[0], + (void *)task), task, out); } if (rank != t2.root && rank != coll_root) { - UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size), - data_size, mtype, t2.parent, team, + UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size_t1), + data_size_t2, mtype, t2.parent, team, task, cb[1], (void *)task), task, out); } @@ -105,7 +110,7 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task) if ((coll_root == rank) || (task->bcast_dbt.t1.recv > 0)) { for (i = 0; i < 2; i++) { if (t1.children[i] != -1 && t1.children[i] != coll_root) { - UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buffer, data_size, mtype, + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype, 
t1.children[i], team, task), task, out); } @@ -119,8 +124,9 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task) if ((coll_root == rank) || (task->bcast_dbt.t2.recv > 0)) { for (i = 0; i < 2; i++) { if (t2.children[i] != -1 && t2.children[i] != coll_root) { - UCPCHECK_GOTO(ucc_tl_ucp_send_nb(PTR_OFFSET(buffer, data_size), - data_size, mtype, + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(PTR_OFFSET(buffer, + data_size_t1), + data_size_t2, mtype, t2.children[i], team, task), task, out); } @@ -144,28 +150,32 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task) ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, - ucc_tl_ucp_task_t); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_status_t status = UCC_OK; - ucc_rank_t rank = UCC_TL_TEAM_RANK(team); - void *buffer = TASK_ARGS(task).src.info.buffer; - ucc_memory_type_t mtype = TASK_ARGS(task).src.info.mem_type; - ucc_datatype_t dt = TASK_ARGS(task).src.info.datatype; - size_t count = TASK_ARGS(task).src.info.count; - size_t data_size = count * ucc_dt_size(dt) / 2; - ucc_rank_t coll_root = (ucc_rank_t)TASK_ARGS(task).root; - ucc_rank_t t1_root = task->bcast_dbt.t1.root; - ucc_rank_t t2_root = task->bcast_dbt.t2.root; - ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1, - recv_completion_2}; + ucc_tl_ucp_task_t *task = + ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_status_t status = UCC_OK; + ucc_rank_t rank = UCC_TL_TEAM_RANK(team); + void *buffer = args->src.info.buffer; + ucc_memory_type_t mtype = args->src.info.mem_type; + ucc_datatype_t dt = args->src.info.datatype; + size_t count = args->src.info.count; + size_t count_t1 = (count % 2) ? 
count / 2 + 1 + : count / 2; + size_t data_size_t1 = count_t1 * ucc_dt_size(dt); + size_t data_size_t2 = count / 2 * ucc_dt_size(dt); + ucc_rank_t coll_root = (ucc_rank_t)args->root; + ucc_rank_t t1_root = task->bcast_dbt.t1.root; + ucc_rank_t t2_root = task->bcast_dbt.t2.root; + ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1, + recv_completion_2}; task->bcast_dbt.t1.recv = 0; task->bcast_dbt.t2.recv = 0; ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); if (rank == coll_root && coll_root != t1_root) { - status = ucc_tl_ucp_send_nb(buffer, data_size, mtype, t1_root, team, + status = ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype, t1_root, team, task); if (UCC_OK != status) { return status; @@ -173,25 +183,25 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task) } if (rank == coll_root && coll_root != t2_root) { - status = ucc_tl_ucp_send_nb(PTR_OFFSET(buffer, data_size), data_size, - mtype, t2_root, team, task); + status = ucc_tl_ucp_send_nb(PTR_OFFSET(buffer, data_size_t1), + data_size_t2, mtype, t2_root, team, task); if (UCC_OK != status) { return status; } } if (rank != coll_root && rank == t1_root) { - status = ucc_tl_ucp_recv_cb(buffer, data_size, mtype, coll_root, team, - task, cb[0], (void *)task); + status = ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype, coll_root, + team, task, cb[0], (void *)task); if (UCC_OK != status) { return status; } } if (rank != coll_root && rank == t2_root) { - status = ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size), data_size, - mtype, coll_root, team, task, cb[1], - (void *)task); + status = ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size_t1), + data_size_t2, mtype, coll_root, team, task, + cb[1], (void *)task); if (UCC_OK != status) { return status; }