Skip to content

Commit

Permalink
CORE: fix memory type score update (#650)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev authored Oct 11, 2022
1 parent ca526d5 commit ad9ff51
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 50 deletions.
30 changes: 21 additions & 9 deletions src/coll_score/ucc_coll_score.c
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ static ucc_status_t str_to_tsizes(const char *str, ucc_rank_t **tsizes,

static ucc_status_t ucc_coll_score_parse_str(const char *str,
ucc_coll_score_t *score,
ucc_rank_t team_size, //NOLINT
ucc_rank_t team_size,
ucc_base_coll_init_fn_t init,
ucc_base_team_t *team,
ucc_alg_id_to_init_fn_t alg_fn)
Expand Down Expand Up @@ -720,7 +720,7 @@ static ucc_status_t ucc_coll_score_parse_str(const char *str,
}
if (tsizes) {
/* Team size qualifier was provided: check if we should apply this
str setting to the current team */
str setting to the current team */
ts_skip = 1;
for (i = 0; i < n_tsizes; i++) {
if (team_size >= tsizes[2 * i] && team_size <= tsizes[2 * i + 1]) {
Expand Down Expand Up @@ -984,16 +984,26 @@ static ucc_status_t ucc_coll_score_update_one(ucc_list_link_t *dest,
return status;
}

ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
ucc_coll_score_t *update,
ucc_score_t default_score)
ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
ucc_coll_score_t *update,
ucc_score_t default_score,
ucc_memory_type_t *mtypes,
int mt_n)
{
ucc_status_t status;
int i, j;
ucc_memory_type_t mt;

if (mt_n == 0) {
mt_n = UCC_MEMORY_TYPE_LAST;
}

for (i = 0; i < UCC_COLL_TYPE_NUM; i++) {
for (j = 0; j < UCC_MEMORY_TYPE_LAST; j++) {
for (j = 0; j < mt_n; j++) {
mt = (mtypes == NULL ? j : mtypes[j]);
status = ucc_coll_score_update_one(
&score->scores[i][j], &update->scores[i][j], default_score);
&score->scores[i][mt],
&update->scores[i][mt], default_score);
if (UCC_OK != status) {
return status;
}
Expand All @@ -1008,7 +1018,9 @@ ucc_status_t ucc_coll_score_update_from_str(const char * str,
ucc_base_coll_init_fn_t init,
ucc_base_team_t *team,
ucc_score_t def_score,
ucc_alg_id_to_init_fn_t alg_fn)
ucc_alg_id_to_init_fn_t alg_fn,
ucc_memory_type_t *mtypes,
int mt_n)
{
ucc_status_t status;
ucc_coll_score_t *score_str;
Expand All @@ -1017,7 +1029,7 @@ ucc_status_t ucc_coll_score_update_from_str(const char * str,
if (UCC_OK != status) {
return status;
}
status = ucc_coll_score_update(score, score_str, def_score);
status = ucc_coll_score_update(score, score_str, def_score, mtypes, mt_n);
ucc_coll_score_free(score_str);
return status;
}
Expand Down
18 changes: 12 additions & 6 deletions src/coll_score/ucc_coll_score.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ ucc_status_t ucc_coll_score_alloc_from_str(const char * str,
is provided in "str" and it does not have "score" qualifier then def_score
is used for it. If the new range is provided in "str" and it does not have
"alg_id" qualifier then "init" fn is used, otherwise "init" is taken from
alg_fn mapper callback.
alg_fn mapper callback. "mtypes" parameter determines which memory types
will be updated; if "mt_n" is 0, all memory types are updated.
This function has 2 usages (see tl_ucp_team.c: ucc_tl_ucp_team_get_scores
function):
Expand All @@ -117,15 +118,13 @@ ucc_status_t ucc_coll_score_update_from_str(const char * str,
ucc_base_coll_init_fn_t init,
ucc_base_team_t *team,
ucc_score_t def_score,
ucc_alg_id_to_init_fn_t alg_fn);
ucc_alg_id_to_init_fn_t alg_fn,
ucc_memory_type_t *mtypes,
int mt_n);

ucc_status_t ucc_coll_score_merge_in(ucc_coll_score_t **dst,
ucc_coll_score_t *src);

ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
ucc_coll_score_t *update,
ucc_score_t default_score);

/* Initializes the default score datastruct with a set of coll_types specified
as a bitmap, mem_types passed as array, default score value and default init fn.
The collective will have msg range 0-inf. */
Expand Down Expand Up @@ -155,4 +154,11 @@ void ucc_coll_score_set(ucc_coll_score_t *score,
ucc_score_t value);

void ucc_coll_score_map_print_info(const ucc_score_map_t *score);

ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
ucc_coll_score_t *update,
ucc_score_t default_score,
ucc_memory_type_t *mtypes,
int mt_n);

#endif
2 changes: 1 addition & 1 deletion src/components/cl/basic/cl_basic_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ ucc_status_t ucc_cl_basic_team_get_scores(ucc_base_team_t *cl_team,
if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, *score, UCC_CL_TEAM_SIZE(team), NULL, cl_team,
UCC_CL_BASIC_DEFAULT_SCORE, NULL);
UCC_CL_BASIC_DEFAULT_SCORE, NULL, NULL, 0);

/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
Expand Down
4 changes: 2 additions & 2 deletions src/components/cl/hier/cl_hier_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team,
status = ucc_coll_score_update_from_str(
ucc_cl_hier_default_alg_select_str[i], score,
UCC_TL_TEAM_SIZE(team), ucc_cl_hier_coll_init, &team->super.super,
UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init);
UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0);
if (UCC_OK != status) {
cl_error(lib, "failed to apply default coll select setting: %s",
ucc_cl_hier_default_alg_select_str[i]);
Expand All @@ -358,7 +358,7 @@ ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team,
if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_CL_TEAM_SIZE(team), NULL, cl_team,
UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init);
UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0);

/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/cuda/tl_cuda_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ ucc_status_t ucc_tl_cuda_team_get_scores(ucc_base_team_t *tl_team,
status = ucc_coll_score_update_from_str(
ucc_tl_cuda_default_alg_select_str[i], score,
UCC_TL_TEAM_SIZE(team), ucc_tl_cuda_coll_init, &team->super.super,
UCC_TL_CUDA_DEFAULT_SCORE, ucc_tl_cuda_alg_id_to_init);
UCC_TL_CUDA_DEFAULT_SCORE, ucc_tl_cuda_alg_id_to_init, &mt, 1);
if (UCC_OK != status) {
tl_error(tl_team->context->lib,
"failed to apply default coll select setting: %s",
Expand All @@ -364,7 +364,7 @@ ucc_status_t ucc_tl_cuda_team_get_scores(ucc_base_team_t *tl_team,
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
ucc_tl_cuda_coll_init, &team->super.super,
UCC_TL_CUDA_DEFAULT_SCORE, ucc_tl_cuda_alg_id_to_init);
UCC_TL_CUDA_DEFAULT_SCORE, ucc_tl_cuda_alg_id_to_init, &mt, 1);
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
goto err;
Expand Down
6 changes: 3 additions & 3 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) Facebook, Inc. and its affiliates. 2021.
*
* See file LICENSE for terms.
Expand Down Expand Up @@ -215,7 +215,7 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,
status = ucc_coll_score_update_from_str(
ucc_tl_nccl_default_alg_select_str[i], score, UCC_TL_TEAM_SIZE(team),
ucc_tl_nccl_coll_init, &team->super.super, UCC_TL_NCCL_DEFAULT_SCORE,
ucc_tl_nccl_alg_id_to_init);
ucc_tl_nccl_alg_id_to_init, &mt, 1);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(tl_team->context->lib,
"failed to apply default coll select setting: %s",
Expand All @@ -237,7 +237,7 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
ucc_tl_nccl_coll_init, &team->super.super,
UCC_TL_NCCL_DEFAULT_SCORE, ucc_tl_nccl_alg_id_to_init);
UCC_TL_NCCL_DEFAULT_SCORE, ucc_tl_nccl_alg_id_to_init, &mt, 1);
/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/rccl/tl_rccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ ucc_status_t ucc_tl_rccl_team_get_scores(ucc_base_team_t *tl_team,
status = ucc_coll_score_update_from_str(
ucc_tl_rccl_default_alg_select_str[i], score, UCC_TL_TEAM_SIZE(team),
ucc_tl_rccl_coll_init, &team->super.super, UCC_TL_RCCL_DEFAULT_SCORE,
ucc_tl_rccl_alg_id_to_init);
ucc_tl_rccl_alg_id_to_init, &mt, 1);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(tl_team->context->lib,
"failed to apply default coll select setting: %s",
Expand All @@ -238,7 +238,7 @@ ucc_status_t ucc_tl_rccl_team_get_scores(ucc_base_team_t *tl_team,
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
ucc_tl_rccl_coll_init, &team->super.super,
UCC_TL_RCCL_DEFAULT_SCORE, ucc_tl_rccl_alg_id_to_init);
UCC_TL_RCCL_DEFAULT_SCORE, ucc_tl_rccl_alg_id_to_init, &mt, 1);
/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/self/tl_self_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ ucc_status_t ucc_tl_self_team_get_scores(ucc_base_team_t *tl_team,
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
ucc_tl_self_coll_init, &team->super.super,
UCC_TL_SELF_DEFAULT_SCORE, NULL);
UCC_TL_SELF_DEFAULT_SCORE, NULL, NULL, 0);
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
goto err;
Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/sharp/tl_sharp_team.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -218,7 +218,7 @@ ucc_status_t ucc_tl_sharp_team_get_scores(ucc_base_team_t *tl_team,
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
ucc_tl_sharp_coll_init, &team->super.super,
UCC_TL_SHARP_DEFAULT_SCORE, NULL);
UCC_TL_SHARP_DEFAULT_SCORE, NULL, NULL, 0);
/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
8 changes: 5 additions & 3 deletions src/components/tl/ucp/tl_ucp_team.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -155,23 +155,25 @@ ucc_status_t ucc_tl_ucp_team_get_scores(ucc_base_team_t *tl_team,
if (UCC_OK != status) {
return status;
}

for (i = 0; i < UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR; i++) {
status = ucc_coll_score_update_from_str(
ucc_tl_ucp_default_alg_select_str[i], score, UCC_TL_TEAM_SIZE(team),
ucc_tl_ucp_coll_init, &team->super.super, UCC_TL_UCP_DEFAULT_SCORE,
ucc_tl_ucp_alg_id_to_init);
ucc_tl_ucp_alg_id_to_init, mem_types, mt_n);
if (UCC_OK != status) {
tl_error(tl_team->context->lib,
"failed to apply default coll select setting: %s",
ucc_tl_ucp_default_alg_select_str[i]);
goto err;
}
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team), NULL,
&team->super.super, UCC_TL_UCP_DEFAULT_SCORE,
ucc_tl_ucp_alg_id_to_init);
ucc_tl_ucp_alg_id_to_init, mem_types, mt_n);

/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
Expand Down
Loading

0 comments on commit ad9ff51

Please sign in to comment.