diff --git a/src/components/ec/cpu/ec_cpu_reduce.c b/src/components/ec/cpu/ec_cpu_reduce.c
index cb77cfbe8d..8aa0c3cbe5 100644
--- a/src/components/ec/cpu/ec_cpu_reduce.c
+++ b/src/components/ec/cpu/ec_cpu_reduce.c
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
@@ -8,9 +8,10 @@
 #include "ec_cpu.h"
 #include <complex.h>
 
-#define DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, OP)                        \
+#define DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, OP)                  \
     do {                                                                       \
         size_t _i, _j;                                                         \
+        type   _tmp;                                                           \
         switch (_n_srcs) {                                                     \
         case 2:                                                                \
             for (_i = 0; _i < _count; _i++) {                                  \
@@ -53,13 +54,12 @@
             break;                                                             \
         default:                                                               \
             for (_i = 0; _i < _count; _i++) {                                  \
-                d[_i] = OP##_8(s[0][_i], s[1][_i], s[2][_i], s[3][_i],         \
-                               s[4][_i], s[5][_i], s[6][_i], s[7][_i]);        \
-            }                                                                  \
-            for (_j = 8; _j < _n_srcs; _j++) {                                 \
-                for (_i = 0; _i < _count; _i++) {                              \
-                    d[_i] = OP##_2(d[_i], s[_j][_i]);                          \
+                _tmp = OP##_8(s[0][_i], s[1][_i], s[2][_i], s[3][_i],          \
+                              s[4][_i], s[5][_i], s[6][_i], s[7][_i]);         \
+                for (_j = 8; _j < _n_srcs; _j++) {                             \
+                    _tmp = OP##_2(_tmp, s[_j][_i]);                            \
                 }                                                              \
+                d[_i] = _tmp;                                                  \
             }                                                                  \
             break;                                                             \
         }                                                                      \
@@ -80,37 +80,37 @@
         switch (_op) {                                                         \
         case UCC_OP_AVG:                                                       \
         case UCC_OP_SUM:                                                       \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_SUM);            \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_SUM);      \
             if (flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA) {                 \
                 VEC_OP(d, _count, task->alpha);                                \
             }                                                                  \
             break;                                                             \
         case UCC_OP_MIN:                                                       \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_MIN);            \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_MIN);      \
             break;                                                             \
         case UCC_OP_MAX:                                                       \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_MAX);            \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_MAX);      \
             break;                                                             \
         case UCC_OP_PROD:                                                      \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_PROD);           \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_PROD);     \
             break;                                                             \
         case UCC_OP_LAND:                                                      \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_LAND);           \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_LAND);     \
            break;                                                             \
         case UCC_OP_BAND:                                                      \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_BAND);           \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_BAND);     \
             break;                                                             \
         case UCC_OP_LOR:                                                       \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_LOR);            \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_LOR);      \
             break;                                                             \
         case UCC_OP_BOR:                                                       \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_BOR);            \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_BOR);      \
             break;                                                             \
        case UCC_OP_LXOR:                                                      \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_LXOR);           \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_LXOR);     \
             break;                                                             \
         case UCC_OP_BXOR:                                                      \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_BXOR);           \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_BXOR);     \
             break;                                                             \
         default:                                                               \
             ec_error(&ucc_ec_cpu.super,                                        \
@@ -176,16 +176,16 @@
         switch (_op) {                                                         \
         case UCC_OP_AVG:                                                       \
         case UCC_OP_SUM:                                                       \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_SUM);            \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_SUM);      \
             break;                                                             \
         case UCC_OP_PROD:                                                      \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_PROD);           \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_PROD);     \
             break;                                                             \
         case UCC_OP_MIN:                                                       \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_MIN);            \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_MIN);      \
             break;                                                             \
         case UCC_OP_MAX:                                                       \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_MAX);            \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_MAX);      \
             break;                                                             \
         default:                                                               \
             ec_error(&ucc_ec_cpu.super,                                        \
@@ -206,10 +206,10 @@
         switch (_op) {                                                         \
         case UCC_OP_AVG:                                                       \
         case UCC_OP_SUM:                                                       \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_SUM);            \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_SUM);      \
             break;                                                             \
         case UCC_OP_PROD:                                                      \
-            DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_PROD);           \
+            DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_PROD);     \
             break;                                                             \
         default:                                                               \
             ec_error(&ucc_ec_cpu.super,                                        \
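
Note: the ec_cpu_reduce.c change threads the element type into
DO_DT_REDUCE_WITH_OP so the n_srcs > 8 fallback can accumulate into a typed
temporary and store each d[_i] exactly once, instead of reducing the first
eight sources into d and then re-walking all of d once per remaining source.
A minimal standalone sketch of the new loop shape (simplified to a plain
float/SUM function in place of the OP##_8/OP##_2 token-pasted operators;
reduce_sum_f and the sample data are hypothetical):

    #include <stdio.h>
    #include <stddef.h>

    #define DO_OP_SUM_2(a, b) ((a) + (b))

    /* One pass over d: fold every source into a typed temporary per
     * element, mirroring the "type _tmp" added to the macro above. */
    static void reduce_sum_f(const float **s, float *d, size_t count,
                             size_t n_srcs)
    {
        size_t i, j;
        float  tmp;

        for (i = 0; i < count; i++) {
            tmp = s[0][i];
            for (j = 1; j < n_srcs; j++) {
                tmp = DO_OP_SUM_2(tmp, s[j][i]);
            }
            d[i] = tmp; /* single store per destination element */
        }
    }

    int main(void)
    {
        float        a[4] = {1, 2, 3, 4}, b[4] = {10, 20, 30, 40};
        float        c[4] = {100, 200, 300, 400}, d[4];
        const float *srcs[3] = {a, b, c};

        reduce_sum_f(srcs, d, 4, 3);
        printf("%g %g %g %g\n", d[0], d[1], d[2], d[3]); /* 111 222 333 444 */
        return 0;
    }

Accumulating per element can keep the temporary in a register and makes a
single sweep over the destination regardless of the number of sources.
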
diff --git a/src/components/tl/ucp/reduce_scatter/reduce_scatter_knomial.c b/src/components/tl/ucp/reduce_scatter/reduce_scatter_knomial.c
index 3595db681b..d484ba6892 100644
--- a/src/components/tl/ucp/reduce_scatter/reduce_scatter_knomial.c
+++ b/src/components/tl/ucp/reduce_scatter/reduce_scatter_knomial.c
@@ -15,20 +15,40 @@
         task->reduce_scatter_kn.phase = _phase;                                \
     } while (0)
 
-static inline void get_sbuf_rbuf(ucc_knomial_pattern_t *p, ucc_coll_args_t *args,
-                                 void *scratch, size_t block_count,
+static inline void get_sbuf_rbuf(ucc_tl_ucp_task_t *task, size_t block_count,
                                  void **sbuf, void **rbuf)
 {
-    uint8_t node_type = p->node_type;
-    size_t  dt_size   = ucc_dt_size(args->dst.info.datatype);
+    ucc_coll_args_t       *args    = &TASK_ARGS(task);
+    size_t                 dt_size = ucc_dt_size(args->dst.info.datatype);
+    void                  *scratch = task->reduce_scatter_kn.scratch;
+    ucc_knomial_pattern_t *p       = &task->reduce_scatter_kn.p;
+    size_t                 offset, local_seg_offset, local_seg_count;
 
     if (ucc_knomial_pattern_loop_first_iteration(p)) {
-        *sbuf = (KN_NODE_PROXY == node_type || UCC_IS_INPLACE(*args))
+        *sbuf = ((KN_NODE_PROXY == p->node_type) || UCC_IS_INPLACE(*args))
                     ? args->dst.info.buffer: args->src.info.buffer;
         *rbuf = scratch;
     } else {
         *sbuf = scratch;
-        *rbuf = PTR_OFFSET(*sbuf, block_count * dt_size);
+        if (!ucc_knomial_pattern_loop_last_iteration(p) ||
+            (task->reduce_scatter_kn.scratch_mc_header != NULL)) {
+            *rbuf = PTR_OFFSET(*sbuf, block_count * dt_size);
+        } else {
+            ucc_sra_kn_get_offset_and_seglen(args->dst.info.count, dt_size,
+                                             p->rank, p->size,
+                                             p->radix, &local_seg_offset,
+                                             &local_seg_count);
+            local_seg_offset = local_seg_offset / dt_size;
+            if ((local_seg_offset <= block_count) || (local_seg_count == 0)) {
+                *rbuf = PTR_OFFSET(*sbuf, block_count * dt_size);
+            } else {
+                offset = (local_seg_offset - block_count) % local_seg_count;
+                /* check we have enough space to store segments */
+                ucc_assert(args->dst.info.count - (block_count + offset) >=
+                           local_seg_count * (ucc_kn_compute_step_radix(p) - 1));
+                *rbuf = PTR_OFFSET(*sbuf, (block_count + offset) * dt_size);
+            }
+        }
     }
 }
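
Note: in the new get_sbuf_rbuf, the receive buffer is special-cased only on
the last loop iteration when no dedicated scratch allocation exists
(scratch_mc_header == NULL, which, as I read the init path, means the
destination buffer itself serves as scratch). rbuf is then placed
(block_count + offset) elements into sbuf, where offset makes the placement
congruent to the rank's final destination segment modulo the segment length,
and the ucc_assert verifies the tail of the destination buffer can hold the
step_radix - 1 incoming segments. A toy version of that index arithmetic
(the concrete numbers are made up for illustration):

    #include <assert.h>
    #include <stddef.h>
    #include <stdio.h>

    int main(void)
    {
        size_t block_count      = 4;  /* elements already staged in sbuf */
        size_t local_seg_offset = 11; /* final segment offset, in elements */
        size_t local_seg_count  = 3;  /* final segment length, in elements */
        size_t offset           = (local_seg_offset - block_count) %
                                  local_seg_count; /* (11 - 4) % 3 == 1 */

        /* the chosen placement is congruent to the final segment offset
         * modulo the segment length */
        assert((block_count + offset) % local_seg_count ==
               local_seg_offset % local_seg_count);
        printf("rbuf element offset: %zu\n", block_count + offset); /* 5 */
        return 0;
    }
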
@@ -39,11 +59,8 @@ void ucc_tl_ucp_reduce_scatter_knomial_progress(ucc_coll_task_t *coll_task)
     ucc_coll_args_t       *args      = &TASK_ARGS(task);
     ucc_tl_ucp_team_t     *team      = TASK_TEAM(task);
     ucc_kn_radix_t         radix     = task->reduce_scatter_kn.p.radix;
-    int                    avg_pre_op =
-        UCC_TL_UCP_TEAM_LIB(team)->cfg.reduce_avg_pre_op;
-    uint8_t                node_type =
-        task->reduce_scatter_kn.p.node_type;
     ucc_knomial_pattern_t *p         = &task->reduce_scatter_kn.p;
+    uint8_t                node_type = p->node_type;
     void                  *scratch   = task->reduce_scatter_kn.scratch;
     void                  *rbuf      = args->dst.info.buffer;
     ucc_memory_type_t      mem_type  = args->dst.info.mem_type;
@@ -56,7 +73,7 @@
     ucc_rank_t             rank      = UCC_TL_TEAM_RANK(team);
     ucc_rank_t             size      = UCC_TL_TEAM_SIZE(team);
     ptrdiff_t              peer_seg_offset, local_seg_offset, offset;
-    ucc_rank_t             peer, step_radix, local_seg_index;
+    ucc_rank_t             peer, step_radix;
     ucc_status_t           status;
     ucc_kn_radix_t         loop_step;
     size_t                 block_count, peer_seg_count, local_seg_count;
@@ -105,10 +122,9 @@
     }
     while (!ucc_knomial_pattern_loop_done(p)) {
         block_count = ucc_sra_kn_compute_block_count(count, rank, p);
-        get_sbuf_rbuf(p, args, task->reduce_scatter_kn.scratch, block_count,
-                      &sbuf, &rbuf);
         ucc_kn_rs_pattern_peer_seg(rank, p, &local_seg_count,
                                    &local_seg_offset);
+        get_sbuf_rbuf(task, block_count, &sbuf, &rbuf);
         for (loop_step = radix - 1; loop_step > 0; loop_step--) {
             peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step);
             if (peer == UCC_KN_PEER_NULL) {
@@ -136,39 +152,30 @@
 
         step_radix = ucc_kn_compute_step_radix(p);
         ucc_kn_rs_pattern_peer_seg(rank, p, &local_seg_count,
                                    &local_seg_offset);
-        get_sbuf_rbuf(p, args, task->reduce_scatter_kn.scratch, block_count,
-                      &sbuf, &rbuf);
+        get_sbuf_rbuf(task, block_count, &sbuf, &rbuf);
         local_data  = PTR_OFFSET(sbuf, local_seg_offset * dt_size);
-        reduce_data = task->reduce_scatter_kn.scratch;
-        is_avg      = args->op == UCC_OP_AVG &&
-                      (avg_pre_op ? ucc_knomial_pattern_loop_first_iteration(p)
-                                  : ucc_knomial_pattern_loop_last_iteration(p));
+        is_avg      = (args->op == UCC_OP_AVG) &&
+                      (UCC_TL_UCP_TEAM_LIB(team)->cfg.reduce_avg_pre_op ?
+                           ucc_knomial_pattern_loop_first_iteration(p) :
+                           ucc_knomial_pattern_loop_last_iteration(p));
         ucc_assert((step_radix - 1) ==
                    (task->tagged.send_posted - p->iteration * (radix - 1)));
-        if (!task->reduce_scatter_kn.scratch_mc_header &&
-            ucc_knomial_pattern_loop_last_iteration(p)) {
-            status = ucc_dt_reduce_strided(
-                rbuf, PTR_OFFSET(rbuf, local_seg_count * dt_size), rbuf,
-                step_radix - 2, local_seg_count, local_seg_count * dt_size,
-                dt, args, 0, 0, task->reduce_scatter_kn.executor,
-                &task->reduce_scatter_kn.etask);
-
+        if (ucc_knomial_pattern_loop_last_iteration(p)) {
+            ucc_sra_kn_get_offset_and_seglen(count, dt_size, rank, size,
+                                             radix, &offset,
+                                             &local_seg_count);
+            reduce_data = PTR_OFFSET(args->dst.info.buffer, offset);
         } else {
-            if (task->reduce_scatter_kn.scratch_mc_header &&
-                ucc_knomial_pattern_loop_last_iteration(p)) {
-                ucc_sra_kn_get_offset_and_seglen(count, dt_size, rank, size,
-                                                 radix, &offset,
-                                                 &local_seg_count);
-                reduce_data = PTR_OFFSET(args->dst.info.buffer, offset);
-            }
-            status = ucc_dt_reduce_strided(
-                local_data, rbuf, reduce_data, step_radix - 1,
-                local_seg_count, local_seg_count * dt_size, dt, args,
-                is_avg ? UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0,
-                AVG_ALPHA(task), task->reduce_scatter_kn.executor,
-                &task->reduce_scatter_kn.etask);
+            reduce_data = task->reduce_scatter_kn.scratch;
         }
+
+        status = ucc_dt_reduce_strided(local_data, rbuf, reduce_data,
+                                       step_radix - 1, local_seg_count,
+                                       local_seg_count * dt_size, dt, args,
+                                       is_avg ? UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0,
+                                       AVG_ALPHA(task), task->reduce_scatter_kn.executor,
+                                       &task->reduce_scatter_kn.etask);
         if (ucc_unlikely(UCC_OK != status)) {
             tl_error(UCC_TASK_LIB(task), "failed to perform dt reduction");
             task->super.status = status;
@@ -181,42 +188,7 @@
         }
         ucc_kn_rs_pattern_next_iter(p);
     }
-
-    if (!task->reduce_scatter_kn.scratch_mc_header) {
-        ucc_knomial_pattern_prev_iteration(p);
-        get_sbuf_rbuf(p, args, task->reduce_scatter_kn.scratch, block_count,
-                      &sbuf, &rbuf);
-
-        step_radix       = ucc_kn_compute_step_radix(p);
-        local_seg_index  = ucc_kn_compute_seg_index(
-            rank, p->radix_pow, p);
-        local_seg_count  = ucc_sra_kn_compute_seg_size(
-            block_count, step_radix, local_seg_index);
-        local_seg_offset = ucc_sra_kn_compute_seg_offset(
-            block_count, step_radix, local_seg_index);
-        local_data = PTR_OFFSET(sbuf, local_seg_offset * dt_size);
-        is_avg     = args->op == UCC_OP_AVG &&
-                     (avg_pre_op ? ucc_knomial_pattern_loop_first_iteration(p)
-                                 : ucc_knomial_pattern_loop_last_iteration(p));
-
-        ucc_sra_kn_get_offset_and_seglen(count, dt_size, rank, size, radix,
-                                         &offset, &local_seg_count);
-        status = ucc_dt_reduce(local_data, rbuf,
-                               PTR_OFFSET(args->dst.info.buffer, offset),
-                               local_seg_count, dt, args,
-                               is_avg ? UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0,
-                               AVG_ALPHA(task),
-                               task->reduce_scatter_kn.executor,
-                               &task->reduce_scatter_kn.etask);
-        if (ucc_unlikely(status != UCC_OK)) {
-            tl_error(UCC_TASK_LIB(task), "failed to reduce data to dst buffer");
-            task->super.status = status;
-            return;
-        }
-UCC_KN_PHASE_COMPLETE:
-        EXEC_TASK_TEST(UCC_KN_PHASE_COMPLETE, "failed to perform reduce",
-                       task->reduce_scatter_kn.etask);
-    }
+UCC_KN_PHASE_COMPLETE:
 UCC_KN_PHASE_PROXY: /* unused label */
 out:
     UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_reduce_scatter_kn_done",
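
Note: with reduce_data now pointing at the rank's final segment of
args->dst.info.buffer on the last iteration, the single strided reduction
covers both the scratch-backed and the in-place variants, and the whole
post-loop fix-up pass (pattern rewind, segment recomputation, extra
ucc_dt_reduce into the destination) disappears; UCC_KN_PHASE_COMPLETE
collapses to a bare label in front of the common completion path. A toy
model of the simplified control flow (names and values are illustrative
only, not the UCC API):

    #include <stdio.h>

    #define N_ITERS 3

    int main(void)
    {
        double scratch = 0.0, dst = 0.0;
        int    iter;

        for (iter = 0; iter < N_ITERS; iter++) {
            /* the last iteration reduces straight into the destination,
             * so no separate post-loop pass is required */
            double *reduce_data = (iter == N_ITERS - 1) ? &dst : &scratch;

            *reduce_data += iter + 1; /* stand-in for the strided reduce */
        }
        printf("dst = %g\n", dst); /* only the last iteration wrote dst */
        return 0;
    }
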
@@ -263,9 +235,11 @@ ucc_tl_ucp_reduce_scatter_knomial_finalize(ucc_coll_task_t *coll_task)
     return ucc_tl_ucp_coll_finalize(coll_task);
 }
 
-ucc_status_t ucc_tl_ucp_reduce_scatter_knomial_init_r(
-    ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
-    ucc_coll_task_t **task_h, ucc_kn_radix_t radix)
+ucc_status_t
+ucc_tl_ucp_reduce_scatter_knomial_init_r(ucc_base_coll_args_t *coll_args,
+                                         ucc_base_team_t *team,
+                                         ucc_coll_task_t **task_h,
+                                         ucc_kn_radix_t radix)
 {
     ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
     ucc_rank_t         rank    = UCC_TL_TEAM_RANK(tl_team);
@@ -318,8 +292,8 @@
 
 ucc_status_t
 ucc_tl_ucp_reduce_scatter_knomial_init(ucc_base_coll_args_t *coll_args,
-                                       ucc_base_team_t *     team,
-                                       ucc_coll_task_t **    task_h)
+                                       ucc_base_team_t *team,
+                                       ucc_coll_task_t **task_h)
 {
     ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
     ucc_rank_t         size    = UCC_TL_TEAM_SIZE(tl_team);