Skip to content

Commit

Permalink
TL/UCP: add bruck alltoall (#756)
Browse files Browse the repository at this point in the history
* TL/UCP: add bruck alltoall

* REVIEW: fix review comments

* REVIEW: fix review comments
  • Loading branch information
Sergei-Lebedev authored May 8, 2023
1 parent 9a2f2c8 commit a036a5f
Show file tree
Hide file tree
Showing 11 changed files with 413 additions and 50 deletions.
1 change: 1 addition & 0 deletions src/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ noinst_HEADERS = \
components/ec/ucc_ec_log.h \
coll_patterns/recursive_knomial.h \
coll_patterns/sra_knomial.h \
coll_patterns/bruck_alltoall.h \
components/topo/ucc_topo.h \
components/topo/ucc_sbgp.h

Expand Down
41 changes: 41 additions & 0 deletions src/coll_patterns/bruck_alltoall.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#ifndef BRUCK_ALLTOALL_H_
#define BRUCK_ALLTOALL_H_

#include "utils/ucc_math.h"

#define GET_NEXT_BRUCK_NUM(_num, _radix, _pow) \
((((_num) + 1) % (_pow))?((_num) + 1):(((_num) + 1) + (_pow) * ((_radix) - 1)))

#define GET_PREV_BRUCK_NUM(_num, _radix, _pow) \
(((_num) % (_pow))?((_num) - 1):(((_num) - 1) - (_pow) * ((_radix) - 1)))

static inline ucc_rank_t get_bruck_step_start(uint32_t pow, uint32_t d)
{
return pow * d;
}

static inline ucc_rank_t get_bruck_step_finish(ucc_rank_t n, uint32_t radix,
uint32_t d, uint32_t pow)
{
return ucc_min(n + pow - 1 - (n - d * pow) % (pow * radix), n);
}

static inline ucc_rank_t get_bruck_recv_peer(ucc_rank_t trank, ucc_rank_t tsize,
ucc_rank_t step, uint32_t digit)
{
return (trank - step * digit + tsize * digit) % tsize;
}

static inline ucc_rank_t get_bruck_send_peer(ucc_rank_t trank, ucc_rank_t tsize,
ucc_rank_t step, uint32_t digit)
{
return (trank + step * digit) % tsize;
}

#endif
3 changes: 2 additions & 1 deletion src/components/tl/ucp/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ alltoall = \
alltoall/alltoall.h \
alltoall/alltoall.c \
alltoall/alltoall_onesided.c \
alltoall/alltoall_pairwise.c
alltoall/alltoall_pairwise.c \
alltoall/alltoall_bruck.c

alltoallv = \
alltoallv/alltoallv.h \
Expand Down
21 changes: 20 additions & 1 deletion src/components/tl/ucp/alltoall/alltoall.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -8,18 +8,37 @@
#include "tl_ucp.h"
#include "alltoall.h"

#define ALLTOALL_MAX_PATTERN_SIZE (sizeof(UCC_TL_UCP_ALLTOALL_DEFAULT_ALG_SELECT_STR_PATTERN) + 32)
#define ALLTOALL_DEFAULT_ALG_SWITCH 129

ucc_status_t ucc_tl_ucp_alltoall_pairwise_start(ucc_coll_task_t *task);
void ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *task);
void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *task);

char* ucc_tl_ucp_alltoall_score_str_get(ucc_tl_ucp_team_t *team)
{
int max_size = ALLTOALL_MAX_PATTERN_SIZE;
char *str;

str = ucc_malloc(max_size * sizeof(char));
ucc_snprintf_safe(str, max_size,
UCC_TL_UCP_ALLTOALL_DEFAULT_ALG_SELECT_STR_PATTERN,
ALLTOALL_DEFAULT_ALG_SWITCH * UCC_TL_TEAM_SIZE(team));
return str;
}

ucc_base_coll_alg_info_t
ucc_tl_ucp_alltoall_algs[UCC_TL_UCP_ALLTOALL_ALG_LAST + 1] = {
[UCC_TL_UCP_ALLTOALL_ALG_PAIRWISE] =
{.id = UCC_TL_UCP_ALLTOALL_ALG_PAIRWISE,
.name = "pairwise",
.desc = "pairwise two-sided implementation"},
[UCC_TL_UCP_ALLTOALL_ALG_BRUCK] =
{.id = UCC_TL_UCP_ALLTOALL_ALG_BRUCK,
.name = "bruck",
.desc = "Bruck alltoall"},
[UCC_TL_UCP_ALLTOALL_ALG_ONESIDED] =
{.id = UCC_TL_UCP_ALLTOALL_ALG_ONESIDED,
.name = "onesided",
Expand Down
23 changes: 16 additions & 7 deletions src/components/tl/ucp/alltoall/alltoall.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -12,26 +12,35 @@

enum {
UCC_TL_UCP_ALLTOALL_ALG_PAIRWISE,
UCC_TL_UCP_ALLTOALL_ALG_BRUCK,
UCC_TL_UCP_ALLTOALL_ALG_ONESIDED,
UCC_TL_UCP_ALLTOALL_ALG_LAST
};

extern ucc_base_coll_alg_info_t
ucc_tl_ucp_alltoall_algs[UCC_TL_UCP_ALLTOALL_ALG_LAST + 1];

#define UCC_TL_UCP_ALLTOALL_DEFAULT_ALG_SELECT_STR "alltoall:0-inf:@0"
#define UCC_TL_UCP_ALLTOALL_DEFAULT_ALG_SELECT_STR_PATTERN \
"alltoall:host:0-%d:@bruck"

char* ucc_tl_ucp_alltoall_score_str_get(ucc_tl_ucp_team_t *team);

ucc_status_t ucc_tl_ucp_alltoall_init(ucc_tl_ucp_task_t *task);

ucc_status_t ucc_tl_ucp_alltoall_pairwise_init_common(ucc_tl_ucp_task_t *task);

ucc_status_t ucc_tl_ucp_alltoall_pairwise_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_alltoall_bruck_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_alltoall_pairwise_init_common(ucc_tl_ucp_task_t *task);

ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

#define ALLTOALL_CHECK_INPLACE(_args, _team) \
do { \
Expand Down
230 changes: 230 additions & 0 deletions src/components/tl/ucp/alltoall/alltoall_bruck.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/**
* Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "config.h"
#include "tl_ucp.h"
#include "alltoall.h"
#include "tl_ucp_sendrecv.h"
#include "components/mc/ucc_mc.h"
#include "coll_patterns/bruck_alltoall.h"

#define RADIX 2

enum {
PHASE_MERGE,
PHASE_SENDRECV
};

static inline int msb_pos_for_level(unsigned int nthbit, ucc_rank_t number)
{

int msb_set = -1;
unsigned int i;

for (i = 0; i < nthbit - 1; i++) {
if (1 & number >> i) {
msb_set = i;
}
}

return msb_set;
}

static inline int find_seg_index(ucc_rank_t seg_index, int level, int nsegs_per_rblock)
{
int block, blockseg;

if (0 == seg_index) {
return -1;
}

block = msb_pos_for_level(level, seg_index);

if (block < 0) {
return -1;
}

/* remove block bit from seg_index */
blockseg = ((seg_index >> (block + 1)) << block) |
(seg_index & UCC_MASK(block));
return block * nsegs_per_rblock + blockseg;
}

ucc_status_t ucc_tl_ucp_alltoall_bruck_backward_rotation(void *dst, void *src,
ucc_rank_t trank,
ucc_rank_t tsize,
size_t seg_size)
{
ucc_status_t st;
ucc_rank_t index, level, nsegs_per_rblock;
size_t snd_offset;
int send_buffer_index;

level = lognum(tsize);
nsegs_per_rblock = tsize / 2;
for (index = 1; index < tsize; index++) {
send_buffer_index = find_seg_index(index, level + 1, nsegs_per_rblock);
ucc_assert(send_buffer_index >= 0);
snd_offset = send_buffer_index * seg_size;
st = ucc_mc_memcpy(PTR_OFFSET(dst, seg_size *
((trank - index + tsize) % tsize)),
PTR_OFFSET(src, snd_offset), seg_size,
UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(st != UCC_OK)) {
return st;
}
}

return UCC_OK;
}

ucc_status_t ucc_tl_ucp_alltoall_bruck_finalize(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_status_t st, global_st;

global_st = ucc_mc_free(task->alltoall_bruck.scratch_mc_header);
if (ucc_unlikely(global_st != UCC_OK)) {
tl_error(UCC_TASK_LIB(task), "failed to free scratch buffer");
}

st = ucc_tl_ucp_coll_finalize(&task->super);
if (ucc_unlikely(st != UCC_OK)) {
tl_error(UCC_TASK_LIB(task), "failed finalize collective");
global_st = st;
}
return global_st;
}

void ucc_tl_ucp_alltoall_bruck_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
ucc_coll_args_t *args = &TASK_ARGS(task);
void *scratch = task->alltoall_bruck.scratch_mc_header->addr;
void *mergebuf = args->dst.info.buffer;
const ucc_rank_t nrecv_segs = tsize / 2;
const size_t seg_size = ucc_dt_size(args->src.info.datatype) *
args->src.info.count / tsize;
void *data;
ucc_rank_t sendto, recvfrom, step, index;
ucc_rank_t level, snd_count;
int send_buffer_index;
ucc_status_t st;

if (task->alltoall_bruck.phase == PHASE_SENDRECV) {
goto ALLTOALL_BRUCK_PHASE_SENDRECV;
}

step = 1 << (task->alltoall_bruck.iteration - 1);
while (step < tsize) {
level = task->alltoall_bruck.iteration - 1;
sendto = get_bruck_send_peer(trank, tsize, step, 1);
recvfrom = get_bruck_recv_peer(trank, tsize, step, 1);

snd_count = 0;
for (index = get_bruck_step_start(step, 1);
index <= get_bruck_step_finish(tsize - 1, RADIX, 1, step);
index = GET_NEXT_BRUCK_NUM(index, RADIX, step)) {
send_buffer_index = find_seg_index(index, level + 1, nrecv_segs);
if (send_buffer_index == -1) {
data = PTR_OFFSET(args->src.info.buffer,
((index + trank) % tsize) * seg_size);
} else {
data = PTR_OFFSET(scratch, send_buffer_index * seg_size);
}
ucc_mc_memcpy(PTR_OFFSET(mergebuf, seg_size * snd_count),
data, seg_size, UCC_MEMORY_TYPE_HOST,
UCC_MEMORY_TYPE_HOST);
snd_count++;
}
data = PTR_OFFSET(scratch, level * nrecv_segs * seg_size);
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(mergebuf, snd_count * seg_size,
UCC_MEMORY_TYPE_HOST, sendto, team,
task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(data, snd_count * seg_size,
UCC_MEMORY_TYPE_HOST, recvfrom, team,
task),
task, out);
ALLTOALL_BRUCK_PHASE_SENDRECV:
if (ucc_tl_ucp_test(task) == UCC_INPROGRESS) {
task->alltoall_bruck.phase = PHASE_SENDRECV;
return;
}
task->alltoall_bruck.iteration++;
step = 1 << (task->alltoall_bruck.iteration - 1);
}

st = ucc_mc_memcpy(PTR_OFFSET(args->dst.info.buffer, trank * seg_size),
PTR_OFFSET(args->src.info.buffer, trank * seg_size),
seg_size, UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(st != UCC_OK)) {
task->super.status = st;
return;
}
task->super.status =
ucc_tl_ucp_alltoall_bruck_backward_rotation(args->dst.info.buffer,
scratch, trank, tsize,
seg_size);
out:
return;
}

ucc_status_t ucc_tl_ucp_alltoall_bruck_start(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);

task->alltoall_bruck.iteration = 1;
task->alltoall_bruck.phase = PHASE_MERGE;
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

ucc_status_t ucc_tl_ucp_alltoall_bruck_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(tl_team);
ucc_coll_args_t *args = &coll_args->args;
size_t seg_size = ucc_dt_size(args->src.info.datatype) *
args->src.info.count / tsize;
size_t scratch_size;
ucc_tl_ucp_task_t *task;
ucc_status_t status;

if ((coll_args->args.src.info.mem_type != UCC_MEMORY_TYPE_HOST) ||
(coll_args->args.dst.info.mem_type != UCC_MEMORY_TYPE_HOST)) {
status = UCC_ERR_NOT_SUPPORTED;
goto out;
}
ALLTOALL_TASK_CHECK(coll_args->args, tl_team);

task = ucc_tl_ucp_init_task(coll_args, team);
task->super.post = ucc_tl_ucp_alltoall_bruck_start;
task->super.progress = ucc_tl_ucp_alltoall_bruck_progress;
task->super.finalize = ucc_tl_ucp_alltoall_bruck_finalize;

scratch_size = lognum(tsize) * ucc_div_round_up(tsize, 2) * seg_size;
status = ucc_mc_alloc(&task->alltoall_bruck.scratch_mc_header,
scratch_size, UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer");
ucc_tl_ucp_coll_finalize(&task->super);
return status;
}

*task_h = &task->super;
return UCC_OK;

out:
return status;
}
Loading

0 comments on commit a036a5f

Please sign in to comment.