Skip to content

Commit

Permalink
TOOLS: add persistent colls to perftest
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Nov 6, 2023
1 parent 8a7b494 commit e604395
Show file tree
Hide file tree
Showing 18 changed files with 221 additions and 84 deletions.
6 changes: 6 additions & 0 deletions src/utils/ucc_coll_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,12 @@ void ucc_coll_args_str(const ucc_coll_args_t *args, ucc_rank_t trank,
strncat(hdr, tmp, left);
}

if (UCC_IS_PERSISTENT(*args)) {
ucc_snprintf_safe(tmp, sizeof(tmp), " persistent");
left = COLL_ARGS_HEADER_STR_MAX_SIZE - strlen(hdr);
strncat(hdr, tmp, left);
}

if (ucc_coll_args_is_rooted(ct)) {
ucc_snprintf_safe(tmp, sizeof(tmp), " root %u", root);
left = COLL_ARGS_HEADER_STR_MAX_SIZE - strlen(hdr);
Expand Down
63 changes: 44 additions & 19 deletions tools/perf/ucc_pt_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,54 +18,61 @@ ucc_pt_benchmark::ucc_pt_benchmark(ucc_pt_benchmark_config cfg,
{
switch (cfg.op_type) {
case UCC_PT_OP_TYPE_ALLGATHER:
coll = new ucc_pt_coll_allgather(cfg.dt, cfg.mt, cfg.inplace, comm);
coll = new ucc_pt_coll_allgather(cfg.dt, cfg.mt, cfg.inplace,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_ALLGATHERV:
coll = new ucc_pt_coll_allgatherv(cfg.dt, cfg.mt, cfg.inplace, comm);
coll = new ucc_pt_coll_allgatherv(cfg.dt, cfg.mt, cfg.inplace,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_ALLREDUCE:
coll = new ucc_pt_coll_allreduce(cfg.dt, cfg.mt, cfg.op, cfg.inplace,
comm);
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_ALLTOALL:
coll = new ucc_pt_coll_alltoall(cfg.dt, cfg.mt, cfg.inplace, comm);
coll = new ucc_pt_coll_alltoall(cfg.dt, cfg.mt, cfg.inplace,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_ALLTOALLV:
coll = new ucc_pt_coll_alltoallv(cfg.dt, cfg.mt, cfg.inplace, comm);
coll = new ucc_pt_coll_alltoallv(cfg.dt, cfg.mt, cfg.inplace,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_BARRIER:
coll = new ucc_pt_coll_barrier(comm);
break;
case UCC_PT_OP_TYPE_BCAST:
coll = new ucc_pt_coll_bcast(cfg.dt, cfg.mt, cfg.root_shift, comm);
coll = new ucc_pt_coll_bcast(cfg.dt, cfg.mt, cfg.root_shift,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_GATHER:
coll = new ucc_pt_coll_gather(cfg.dt, cfg.mt, cfg.inplace,
cfg.root_shift, comm);
cfg.persistent, cfg.root_shift, comm);
break;
case UCC_PT_OP_TYPE_GATHERV:
coll = new ucc_pt_coll_gatherv(cfg.dt, cfg.mt, cfg.inplace,
cfg.root_shift, comm);
cfg.persistent, cfg.root_shift, comm);
break;
case UCC_PT_OP_TYPE_REDUCE:
coll = new ucc_pt_coll_reduce(cfg.dt, cfg.mt, cfg.op, cfg.inplace,
cfg.root_shift, comm);
cfg.persistent, cfg.root_shift, comm);
break;
case UCC_PT_OP_TYPE_REDUCE_SCATTER:
coll = new ucc_pt_coll_reduce_scatter(cfg.dt, cfg.mt, cfg.op,
cfg.inplace, comm);
cfg.inplace,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_REDUCE_SCATTERV:
coll = new ucc_pt_coll_reduce_scatterv(cfg.dt, cfg.mt, cfg.op,
cfg.inplace, comm);
cfg.inplace, cfg.persistent,
comm);
break;
case UCC_PT_OP_TYPE_SCATTER:
coll = new ucc_pt_coll_scatter(cfg.dt, cfg.mt, cfg.inplace,
cfg.root_shift, comm);
cfg.persistent, cfg.root_shift, comm);
break;
case UCC_PT_OP_TYPE_SCATTERV:
coll = new ucc_pt_coll_scatterv(cfg.dt, cfg.mt, cfg.inplace,
cfg.root_shift, comm);
cfg.persistent, cfg.root_shift, comm);
break;
case UCC_PT_OP_TYPE_MEMCPY:
coll = new ucc_pt_op_memcpy(cfg.dt, cfg.mt, cfg.n_bufs, comm);
Expand Down Expand Up @@ -137,10 +144,11 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args,
double &time)
noexcept
{
const bool triggered = config.triggered;
ucc_team_h team = comm->get_team();
ucc_context_h ctx = comm->get_context();
ucc_status_t st = UCC_OK;
const bool triggered = config.triggered;
const bool persistent = config.persistent;
ucc_team_h team = comm->get_team();
ucc_context_h ctx = comm->get_context();
ucc_status_t st = UCC_OK;
ucc_coll_req_h req;
ucc_ee_h ee;
ucc_ev_t comp_ev, *post_ev;
Expand All @@ -161,10 +169,18 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args,
comp_ev.ev_context_size = 0;
}

if (persistent) {
UCCCHECK_GOTO(ucc_collective_init(&args, &req, team), exit_err, st);
}

args.root = config.root % comm->get_size();
for (int i = 0; i < nwarmup + niter; i++) {
double s = get_time_us();
UCCCHECK_GOTO(ucc_collective_init(&args, &req, team), exit_err, st);

if (!persistent) {
UCCCHECK_GOTO(ucc_collective_init(&args, &req, team), exit_err, st);
}

if (triggered) {
comp_ev.req = req;
UCCCHECK_GOTO(ucc_collective_triggered_post(ee, &comp_ev),
Expand All @@ -175,12 +191,16 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args,
} else {
UCCCHECK_GOTO(ucc_collective_post(req), free_req, st);
}

st = ucc_collective_test(req);
while (st > 0) {
UCCCHECK_GOTO(ucc_context_progress(ctx), free_req, st);
st = ucc_collective_test(req);
}
ucc_collective_finalize(req);

if (!persistent) {
ucc_collective_finalize(req);
}
double f = get_time_us();
if (st != UCC_OK) {
goto exit_err;
Expand All @@ -191,6 +211,11 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args,
args.root = (args.root + config.root_shift) % comm->get_size();
UCCCHECK_GOTO(comm->barrier(), exit_err, st);
}

if (persistent) {
ucc_collective_finalize(req);
}

if (niter != 0) {
time /= niter;
}
Expand Down
32 changes: 18 additions & 14 deletions tools/perf/ucc_pt_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class ucc_pt_coll {
class ucc_pt_coll_allgather: public ucc_pt_coll {
public:
ucc_pt_coll_allgather(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, ucc_pt_comm *communicator);
bool is_inplace, bool is_persistent,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -67,7 +68,8 @@ class ucc_pt_coll_allgather: public ucc_pt_coll {
class ucc_pt_coll_allgatherv: public ucc_pt_coll {
public:
ucc_pt_coll_allgatherv(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, ucc_pt_comm *communicator);
bool is_inplace, bool is_persistent,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
};
Expand All @@ -76,7 +78,7 @@ class ucc_pt_coll_allreduce: public ucc_pt_coll {
public:
ucc_pt_coll_allreduce(ucc_datatype_t dt, ucc_memory_type mt,
ucc_reduction_op_t op, bool is_inplace,
ucc_pt_comm *communicator);
bool is_persistent, ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -85,7 +87,8 @@ class ucc_pt_coll_allreduce: public ucc_pt_coll {
class ucc_pt_coll_alltoall: public ucc_pt_coll {
public:
ucc_pt_coll_alltoall(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, ucc_pt_comm *communicator);
bool is_inplace, bool is_persistent,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -94,7 +97,8 @@ class ucc_pt_coll_alltoall: public ucc_pt_coll {
class ucc_pt_coll_alltoallv: public ucc_pt_coll {
public:
ucc_pt_coll_alltoallv(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, ucc_pt_comm *communicator);
bool is_inplace, bool is_persistent,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
};
Expand All @@ -109,7 +113,7 @@ class ucc_pt_coll_barrier: public ucc_pt_coll {
class ucc_pt_coll_bcast: public ucc_pt_coll {
public:
ucc_pt_coll_bcast(ucc_datatype_t dt, ucc_memory_type mt, int root_shift,
ucc_pt_comm *communicator);
bool is_persistent, ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -118,7 +122,7 @@ class ucc_pt_coll_bcast: public ucc_pt_coll {
class ucc_pt_coll_gather: public ucc_pt_coll {
public:
ucc_pt_coll_gather(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, int root_shift,
bool is_inplace, bool is_persistent, int root_shift,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
Expand All @@ -128,7 +132,7 @@ class ucc_pt_coll_gather: public ucc_pt_coll {
class ucc_pt_coll_gatherv: public ucc_pt_coll {
public:
ucc_pt_coll_gatherv(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, int root_shift,
bool is_inplace, bool is_persistent, int root_shift,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
Expand All @@ -137,8 +141,8 @@ class ucc_pt_coll_gatherv: public ucc_pt_coll {
class ucc_pt_coll_reduce: public ucc_pt_coll {
public:
ucc_pt_coll_reduce(ucc_datatype_t dt, ucc_memory_type mt,
ucc_reduction_op_t op, bool is_inplace, int root_shift,
ucc_pt_comm *communicator);
ucc_reduction_op_t op, bool is_inplace, bool is_persistent,
int root_shift, ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -148,7 +152,7 @@ class ucc_pt_coll_reduce_scatter: public ucc_pt_coll {
public:
ucc_pt_coll_reduce_scatter(ucc_datatype_t dt, ucc_memory_type mt,
ucc_reduction_op_t op, bool is_inplace,
ucc_pt_comm *communicator);
bool is_persistent, ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -158,15 +162,15 @@ class ucc_pt_coll_reduce_scatterv: public ucc_pt_coll {
public:
ucc_pt_coll_reduce_scatterv(ucc_datatype_t dt, ucc_memory_type mt,
ucc_reduction_op_t op, bool is_inplace,
ucc_pt_comm *communicator);
bool is_persistent, ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
};

class ucc_pt_coll_scatter: public ucc_pt_coll {
public:
ucc_pt_coll_scatter(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, int root_shift,
bool is_inplace, bool is_persistent, int root_shift,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
Expand All @@ -176,7 +180,7 @@ class ucc_pt_coll_scatter: public ucc_pt_coll {
class ucc_pt_coll_scatterv: public ucc_pt_coll {
public:
ucc_pt_coll_scatterv(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, int root_shift,
bool is_inplace, bool is_persistent, int root_shift,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
Expand Down
14 changes: 11 additions & 3 deletions tools/perf/ucc_pt_coll_allgather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

ucc_pt_coll_allgather::ucc_pt_coll_allgather(ucc_datatype_t dt,
ucc_memory_type mt, bool is_inplace,
bool is_persistent,
ucc_pt_comm *communicator) : ucc_pt_coll(communicator)

{
Expand All @@ -21,16 +22,23 @@ ucc_pt_coll_allgather::ucc_pt_coll_allgather(ucc_datatype_t dt,
has_bw_ = true;
root_shift_ = 0;

coll_args.mask = 0;
coll_args.coll_type = UCC_COLL_TYPE_ALLGATHER;
coll_args.mask = 0;
coll_args.flags = 0;
coll_args.coll_type = UCC_COLL_TYPE_ALLGATHER;
coll_args.src.info.datatype = dt;
coll_args.src.info.mem_type = mt;
coll_args.dst.info.datatype = dt;
coll_args.dst.info.mem_type = mt;

if (is_inplace) {
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}

if (is_persistent) {
coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;
}
}

ucc_status_t ucc_pt_coll_allgather::init_args(size_t single_rank_count,
Expand Down
18 changes: 13 additions & 5 deletions tools/perf/ucc_pt_coll_allgatherv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

ucc_pt_coll_allgatherv::ucc_pt_coll_allgatherv(ucc_datatype_t dt,
ucc_memory_type mt, bool is_inplace,
bool is_persistent,
ucc_pt_comm *communicator) : ucc_pt_coll(communicator)
{
has_inplace_ = true;
Expand All @@ -20,16 +21,23 @@ ucc_pt_coll_allgatherv::ucc_pt_coll_allgatherv(ucc_datatype_t dt,
has_bw_ = false;
root_shift_ = 0;

coll_args.mask = 0;
coll_args.coll_type = UCC_COLL_TYPE_ALLGATHERV;
coll_args.src.info.datatype = dt;
coll_args.src.info.mem_type = mt;
coll_args.mask = 0;
coll_args.flags = 0;
coll_args.coll_type = UCC_COLL_TYPE_ALLGATHERV;
coll_args.src.info.datatype = dt;
coll_args.src.info.mem_type = mt;
coll_args.dst.info_v.datatype = dt;
coll_args.dst.info_v.mem_type = mt;

if (is_inplace) {
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}

if (is_persistent) {
coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;
}
}

ucc_status_t ucc_pt_coll_allgatherv::init_args(size_t count,
Expand Down
22 changes: 15 additions & 7 deletions tools/perf/ucc_pt_coll_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

ucc_pt_coll_allreduce::ucc_pt_coll_allreduce(ucc_datatype_t dt,
ucc_memory_type mt, ucc_reduction_op_t op,
bool is_inplace,
bool is_inplace, bool is_persistent,
ucc_pt_comm *communicator) : ucc_pt_coll(communicator)
{
has_inplace_ = true;
Expand All @@ -21,17 +21,25 @@ ucc_pt_coll_allreduce::ucc_pt_coll_allreduce(ucc_datatype_t dt,
has_bw_ = true;
root_shift_ = 0;

coll_args.coll_type = UCC_COLL_TYPE_ALLREDUCE;
coll_args.mask = 0;
if (is_inplace) {
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}
coll_args.mask = 0;
coll_args.flags = 0;
coll_args.coll_type = UCC_COLL_TYPE_ALLREDUCE;
coll_args.op = op;
coll_args.src.info.datatype = dt;
coll_args.dst.info.datatype = dt;
coll_args.src.info.mem_type = mt;
coll_args.dst.info.mem_type = mt;

if (is_inplace) {
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}

if (is_persistent) {
coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;

}
}

ucc_status_t ucc_pt_coll_allreduce::init_args(size_t count,
Expand Down
Loading

0 comments on commit e604395

Please sign in to comment.