diff --git a/src/utils/ucc_coll_utils.c b/src/utils/ucc_coll_utils.c index 3921f1262e..fbf7b4b452 100644 --- a/src/utils/ucc_coll_utils.c +++ b/src/utils/ucc_coll_utils.c @@ -359,6 +359,12 @@ void ucc_coll_args_str(const ucc_coll_args_t *args, ucc_rank_t trank, strncat(hdr, tmp, left); } + if (UCC_IS_PERSISTENT(*args)) { + ucc_snprintf_safe(tmp, sizeof(tmp), " persistent"); + left = COLL_ARGS_HEADER_STR_MAX_SIZE - strlen(hdr); + strncat(hdr, tmp, left); + } + if (ucc_coll_args_is_rooted(ct)) { ucc_snprintf_safe(tmp, sizeof(tmp), " root %u", root); left = COLL_ARGS_HEADER_STR_MAX_SIZE - strlen(hdr); diff --git a/tools/perf/ucc_pt_benchmark.cc b/tools/perf/ucc_pt_benchmark.cc index c4ef8c6289..cbaa5d664a 100644 --- a/tools/perf/ucc_pt_benchmark.cc +++ b/tools/perf/ucc_pt_benchmark.cc @@ -18,54 +18,61 @@ ucc_pt_benchmark::ucc_pt_benchmark(ucc_pt_benchmark_config cfg, { switch (cfg.op_type) { case UCC_PT_OP_TYPE_ALLGATHER: - coll = new ucc_pt_coll_allgather(cfg.dt, cfg.mt, cfg.inplace, comm); + coll = new ucc_pt_coll_allgather(cfg.dt, cfg.mt, cfg.inplace, + cfg.persistent, comm); break; case UCC_PT_OP_TYPE_ALLGATHERV: - coll = new ucc_pt_coll_allgatherv(cfg.dt, cfg.mt, cfg.inplace, comm); + coll = new ucc_pt_coll_allgatherv(cfg.dt, cfg.mt, cfg.inplace, + cfg.persistent, comm); break; case UCC_PT_OP_TYPE_ALLREDUCE: coll = new ucc_pt_coll_allreduce(cfg.dt, cfg.mt, cfg.op, cfg.inplace, - comm); + cfg.persistent, comm); break; case UCC_PT_OP_TYPE_ALLTOALL: - coll = new ucc_pt_coll_alltoall(cfg.dt, cfg.mt, cfg.inplace, comm); + coll = new ucc_pt_coll_alltoall(cfg.dt, cfg.mt, cfg.inplace, + cfg.persistent, comm); break; case UCC_PT_OP_TYPE_ALLTOALLV: - coll = new ucc_pt_coll_alltoallv(cfg.dt, cfg.mt, cfg.inplace, comm); + coll = new ucc_pt_coll_alltoallv(cfg.dt, cfg.mt, cfg.inplace, + cfg.persistent, comm); break; case UCC_PT_OP_TYPE_BARRIER: coll = new ucc_pt_coll_barrier(comm); break; case UCC_PT_OP_TYPE_BCAST: - coll = new ucc_pt_coll_bcast(cfg.dt, cfg.mt, cfg.root_shift, comm); + coll = new ucc_pt_coll_bcast(cfg.dt, cfg.mt, cfg.root_shift, + cfg.persistent, comm); break; case UCC_PT_OP_TYPE_GATHER: coll = new ucc_pt_coll_gather(cfg.dt, cfg.mt, cfg.inplace, - cfg.root_shift, comm); + cfg.persistent, cfg.root_shift, comm); break; case UCC_PT_OP_TYPE_GATHERV: coll = new ucc_pt_coll_gatherv(cfg.dt, cfg.mt, cfg.inplace, - cfg.root_shift, comm); + cfg.persistent, cfg.root_shift, comm); break; case UCC_PT_OP_TYPE_REDUCE: coll = new ucc_pt_coll_reduce(cfg.dt, cfg.mt, cfg.op, cfg.inplace, - cfg.root_shift, comm); + cfg.persistent, cfg.root_shift, comm); break; case UCC_PT_OP_TYPE_REDUCE_SCATTER: coll = new ucc_pt_coll_reduce_scatter(cfg.dt, cfg.mt, cfg.op, - cfg.inplace, comm); + cfg.inplace, + cfg.persistent, comm); break; case UCC_PT_OP_TYPE_REDUCE_SCATTERV: coll = new ucc_pt_coll_reduce_scatterv(cfg.dt, cfg.mt, cfg.op, - cfg.inplace, comm); + cfg.inplace, cfg.persistent, + comm); break; case UCC_PT_OP_TYPE_SCATTER: coll = new ucc_pt_coll_scatter(cfg.dt, cfg.mt, cfg.inplace, - cfg.root_shift, comm); + cfg.persistent, cfg.root_shift, comm); break; case UCC_PT_OP_TYPE_SCATTERV: coll = new ucc_pt_coll_scatterv(cfg.dt, cfg.mt, cfg.inplace, - cfg.root_shift, comm); + cfg.persistent, cfg.root_shift, comm); break; case UCC_PT_OP_TYPE_MEMCPY: coll = new ucc_pt_op_memcpy(cfg.dt, cfg.mt, cfg.n_bufs, comm); @@ -137,10 +144,11 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args, double &time) noexcept { - const bool triggered = config.triggered; - ucc_team_h team = comm->get_team(); - ucc_context_h ctx = comm->get_context(); - ucc_status_t st = UCC_OK; + const bool triggered = config.triggered; + const bool persistent = config.persistent; + ucc_team_h team = comm->get_team(); + ucc_context_h ctx = comm->get_context(); + ucc_status_t st = UCC_OK; ucc_coll_req_h req; ucc_ee_h ee; ucc_ev_t comp_ev, *post_ev; @@ -161,10 +169,18 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args, comp_ev.ev_context_size = 0; } + if (persistent) { + UCCCHECK_GOTO(ucc_collective_init(&args, &req, team), exit_err, st); + } + args.root = config.root % comm->get_size(); for (int i = 0; i < nwarmup + niter; i++) { double s = get_time_us(); - UCCCHECK_GOTO(ucc_collective_init(&args, &req, team), exit_err, st); + + if (!persistent) { + UCCCHECK_GOTO(ucc_collective_init(&args, &req, team), exit_err, st); + } + if (triggered) { comp_ev.req = req; UCCCHECK_GOTO(ucc_collective_triggered_post(ee, &comp_ev), @@ -175,12 +191,16 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args, } else { UCCCHECK_GOTO(ucc_collective_post(req), free_req, st); } + st = ucc_collective_test(req); while (st > 0) { UCCCHECK_GOTO(ucc_context_progress(ctx), free_req, st); st = ucc_collective_test(req); } - ucc_collective_finalize(req); + + if (!persistent) { + ucc_collective_finalize(req); + } double f = get_time_us(); if (st != UCC_OK) { goto exit_err; @@ -191,6 +211,11 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args, args.root = (args.root + config.root_shift) % comm->get_size(); UCCCHECK_GOTO(comm->barrier(), exit_err, st); } + + if (persistent) { + ucc_collective_finalize(req); + } + if (niter != 0) { time /= niter; } diff --git a/tools/perf/ucc_pt_coll.h b/tools/perf/ucc_pt_coll.h index 63afc9bd9e..0b92039fab 100644 --- a/tools/perf/ucc_pt_coll.h +++ b/tools/perf/ucc_pt_coll.h @@ -58,7 +58,8 @@ class ucc_pt_coll { class ucc_pt_coll_allgather: public ucc_pt_coll { public: ucc_pt_coll_allgather(ucc_datatype_t dt, ucc_memory_type mt, - bool is_inplace, ucc_pt_comm *communicator); + bool is_inplace, bool is_persistent, + ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; @@ -67,7 +68,8 @@ class ucc_pt_coll_allgather: public ucc_pt_coll { class ucc_pt_coll_allgatherv: public ucc_pt_coll { public: ucc_pt_coll_allgatherv(ucc_datatype_t dt, ucc_memory_type mt, - bool is_inplace, ucc_pt_comm *communicator); + bool is_inplace, bool is_persistent, + ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; }; @@ -76,7 +78,7 @@ class ucc_pt_coll_allreduce: public ucc_pt_coll { public: ucc_pt_coll_allreduce(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, bool is_inplace, - ucc_pt_comm *communicator); + bool is_persistent, ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; @@ -85,7 +87,8 @@ class ucc_pt_coll_allreduce: public ucc_pt_coll { class ucc_pt_coll_alltoall: public ucc_pt_coll { public: ucc_pt_coll_alltoall(ucc_datatype_t dt, ucc_memory_type mt, - bool is_inplace, ucc_pt_comm *communicator); + bool is_inplace, bool is_persistent, + ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; @@ -94,7 +97,8 @@ class ucc_pt_coll_alltoall: public ucc_pt_coll { class ucc_pt_coll_alltoallv: public ucc_pt_coll { public: ucc_pt_coll_alltoallv(ucc_datatype_t dt, ucc_memory_type mt, - bool is_inplace, ucc_pt_comm *communicator); + bool is_inplace, bool is_persistent, + ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; }; @@ -109,7 +113,7 @@ class ucc_pt_coll_barrier: public ucc_pt_coll { class ucc_pt_coll_bcast: public ucc_pt_coll { public: ucc_pt_coll_bcast(ucc_datatype_t dt, ucc_memory_type mt, int root_shift, - ucc_pt_comm *communicator); + bool is_persistent, ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; @@ -118,7 +122,7 @@ class ucc_pt_coll_bcast: public ucc_pt_coll { class ucc_pt_coll_gather: public ucc_pt_coll { public: ucc_pt_coll_gather(ucc_datatype_t dt, ucc_memory_type mt, - bool is_inplace, int root_shift, + bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; @@ -128,7 +132,7 @@ class ucc_pt_coll_gather: public ucc_pt_coll { class ucc_pt_coll_gatherv: public ucc_pt_coll { public: ucc_pt_coll_gatherv(ucc_datatype_t dt, ucc_memory_type mt, - bool is_inplace, int root_shift, + bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; @@ -137,8 +141,8 @@ class ucc_pt_coll_gatherv: public ucc_pt_coll { class ucc_pt_coll_reduce: public ucc_pt_coll { public: ucc_pt_coll_reduce(ucc_datatype_t dt, ucc_memory_type mt, - ucc_reduction_op_t op, bool is_inplace, int root_shift, - ucc_pt_comm *communicator); + ucc_reduction_op_t op, bool is_inplace, bool is_persistent, + int root_shift, ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; @@ -148,7 +152,7 @@ class ucc_pt_coll_reduce_scatter: public ucc_pt_coll { public: ucc_pt_coll_reduce_scatter(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, bool is_inplace, - ucc_pt_comm *communicator); + bool is_persistent, ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; @@ -158,7 +162,7 @@ class ucc_pt_coll_reduce_scatterv: public ucc_pt_coll { public: ucc_pt_coll_reduce_scatterv(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, bool is_inplace, - ucc_pt_comm *communicator); + bool is_persistent, ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; }; @@ -166,7 +170,7 @@ class ucc_pt_coll_reduce_scatterv: public ucc_pt_coll { class ucc_pt_coll_scatter: public ucc_pt_coll { public: ucc_pt_coll_scatter(ucc_datatype_t dt, ucc_memory_type mt, - bool is_inplace, int root_shift, + bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; @@ -176,7 +180,7 @@ class ucc_pt_coll_scatter: public ucc_pt_coll { class ucc_pt_coll_scatterv: public ucc_pt_coll { public: ucc_pt_coll_scatterv(ucc_datatype_t dt, ucc_memory_type mt, - bool is_inplace, int root_shift, + bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator); ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; diff --git a/tools/perf/ucc_pt_coll_allgather.cc b/tools/perf/ucc_pt_coll_allgather.cc index 76e6084032..b8185dd9e8 100644 --- a/tools/perf/ucc_pt_coll_allgather.cc +++ b/tools/perf/ucc_pt_coll_allgather.cc @@ -12,6 +12,7 @@ ucc_pt_coll_allgather::ucc_pt_coll_allgather(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, + bool is_persistent, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { @@ -21,16 +22,23 @@ ucc_pt_coll_allgather::ucc_pt_coll_allgather(ucc_datatype_t dt, has_bw_ = true; root_shift_ = 0; - coll_args.mask = 0; - coll_args.coll_type = UCC_COLL_TYPE_ALLGATHER; + coll_args.mask = 0; + coll_args.flags = 0; + coll_args.coll_type = UCC_COLL_TYPE_ALLGATHER; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; + if (is_inplace) { - coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } } ucc_status_t ucc_pt_coll_allgather::init_args(size_t single_rank_count, diff --git a/tools/perf/ucc_pt_coll_allgatherv.cc b/tools/perf/ucc_pt_coll_allgatherv.cc index 8642322c64..c6c18a7c5a 100644 --- a/tools/perf/ucc_pt_coll_allgatherv.cc +++ b/tools/perf/ucc_pt_coll_allgatherv.cc @@ -12,6 +12,7 @@ ucc_pt_coll_allgatherv::ucc_pt_coll_allgatherv(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, + bool is_persistent, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = true; @@ -20,16 +21,23 @@ ucc_pt_coll_allgatherv::ucc_pt_coll_allgatherv(ucc_datatype_t dt, has_bw_ = false; root_shift_ = 0; - coll_args.mask = 0; - coll_args.coll_type = UCC_COLL_TYPE_ALLGATHERV; - coll_args.src.info.datatype = dt; - coll_args.src.info.mem_type = mt; + coll_args.mask = 0; + coll_args.flags = 0; + coll_args.coll_type = UCC_COLL_TYPE_ALLGATHERV; + coll_args.src.info.datatype = dt; + coll_args.src.info.mem_type = mt; coll_args.dst.info_v.datatype = dt; coll_args.dst.info_v.mem_type = mt; + if (is_inplace) { - coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } } ucc_status_t ucc_pt_coll_allgatherv::init_args(size_t count, diff --git a/tools/perf/ucc_pt_coll_allreduce.cc b/tools/perf/ucc_pt_coll_allreduce.cc index 8234f26dd7..3159dc3a9f 100644 --- a/tools/perf/ucc_pt_coll_allreduce.cc +++ b/tools/perf/ucc_pt_coll_allreduce.cc @@ -12,7 +12,7 @@ ucc_pt_coll_allreduce::ucc_pt_coll_allreduce(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, - bool is_inplace, + bool is_inplace, bool is_persistent, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = true; @@ -21,17 +21,25 @@ ucc_pt_coll_allreduce::ucc_pt_coll_allreduce(ucc_datatype_t dt, has_bw_ = true; root_shift_ = 0; - coll_args.coll_type = UCC_COLL_TYPE_ALLREDUCE; - coll_args.mask = 0; - if (is_inplace) { - coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; - coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } + coll_args.mask = 0; + coll_args.flags = 0; + coll_args.coll_type = UCC_COLL_TYPE_ALLREDUCE; coll_args.op = op; coll_args.src.info.datatype = dt; coll_args.dst.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.mem_type = mt; + + if (is_inplace) { + coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; + } + + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + + } } ucc_status_t ucc_pt_coll_allreduce::init_args(size_t count, diff --git a/tools/perf/ucc_pt_coll_alltoall.cc b/tools/perf/ucc_pt_coll_alltoall.cc index f4e9cf57b5..77a2608f7f 100644 --- a/tools/perf/ucc_pt_coll_alltoall.cc +++ b/tools/perf/ucc_pt_coll_alltoall.cc @@ -12,6 +12,7 @@ ucc_pt_coll_alltoall::ucc_pt_coll_alltoall(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, + bool is_persistent, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = true; @@ -20,16 +21,23 @@ ucc_pt_coll_alltoall::ucc_pt_coll_alltoall(ucc_datatype_t dt, has_bw_ = true; root_shift_ = 0; - coll_args.mask = 0; - coll_args.coll_type = UCC_COLL_TYPE_ALLTOALL; + coll_args.mask = 0; + coll_args.flags = 0; + coll_args.coll_type = UCC_COLL_TYPE_ALLTOALL; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; + if (is_inplace) { - coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } } ucc_status_t ucc_pt_coll_alltoall::init_args(size_t single_rank_count, diff --git a/tools/perf/ucc_pt_coll_alltoallv.cc b/tools/perf/ucc_pt_coll_alltoallv.cc index 4ba88ec123..6ce68ed032 100644 --- a/tools/perf/ucc_pt_coll_alltoallv.cc +++ b/tools/perf/ucc_pt_coll_alltoallv.cc @@ -12,6 +12,7 @@ ucc_pt_coll_alltoallv::ucc_pt_coll_alltoallv(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, + bool is_persistent, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = true; @@ -31,6 +32,11 @@ ucc_pt_coll_alltoallv::ucc_pt_coll_alltoallv(ucc_datatype_t dt, if (is_inplace) { coll_args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; } + + if (is_persistent) { + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } + } ucc_status_t ucc_pt_coll_alltoallv::init_args(size_t count, diff --git a/tools/perf/ucc_pt_coll_bcast.cc b/tools/perf/ucc_pt_coll_bcast.cc index b389228c38..b869c902c1 100644 --- a/tools/perf/ucc_pt_coll_bcast.cc +++ b/tools/perf/ucc_pt_coll_bcast.cc @@ -11,7 +11,8 @@ #include ucc_pt_coll_bcast::ucc_pt_coll_bcast(ucc_datatype_t dt, ucc_memory_type mt, - int root_shift, ucc_pt_comm *communicator) + int root_shift, bool is_persistent, + ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = false; @@ -20,10 +21,16 @@ ucc_pt_coll_bcast::ucc_pt_coll_bcast(ucc_datatype_t dt, ucc_memory_type mt, has_bw_ = true; root_shift_ = root_shift; - coll_args.mask = 0; - coll_args.coll_type = UCC_COLL_TYPE_BCAST; + coll_args.mask = 0; + coll_args.flags = 0; + coll_args.coll_type = UCC_COLL_TYPE_BCAST; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; + + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } } ucc_status_t ucc_pt_coll_bcast::init_args(size_t count, diff --git a/tools/perf/ucc_pt_coll_gather.cc b/tools/perf/ucc_pt_coll_gather.cc index e189164484..660356bee8 100644 --- a/tools/perf/ucc_pt_coll_gather.cc +++ b/tools/perf/ucc_pt_coll_gather.cc @@ -11,7 +11,8 @@ #include ucc_pt_coll_gather::ucc_pt_coll_gather(ucc_datatype_t dt, - ucc_memory_type mt, bool is_inplace, int root_shift, + ucc_memory_type mt, bool is_inplace, + bool is_persistent, int root_shift, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = true; @@ -21,15 +22,22 @@ ucc_pt_coll_gather::ucc_pt_coll_gather(ucc_datatype_t dt, root_shift_ = root_shift; coll_args.mask = 0; + coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_GATHER; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; + if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } } ucc_status_t ucc_pt_coll_gather::init_args(size_t single_rank_count, diff --git a/tools/perf/ucc_pt_coll_gatherv.cc b/tools/perf/ucc_pt_coll_gatherv.cc index 6739f241d6..ab8715b3cc 100644 --- a/tools/perf/ucc_pt_coll_gatherv.cc +++ b/tools/perf/ucc_pt_coll_gatherv.cc @@ -11,7 +11,8 @@ #include ucc_pt_coll_gatherv::ucc_pt_coll_gatherv(ucc_datatype_t dt, - ucc_memory_type mt, bool is_inplace, int root_shift, + ucc_memory_type mt, bool is_inplace, + bool is_persistent, int root_shift, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = true; @@ -21,15 +22,22 @@ ucc_pt_coll_gatherv::ucc_pt_coll_gatherv(ucc_datatype_t dt, root_shift_ = root_shift; coll_args.mask = 0; + coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_GATHERV; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info_v.datatype = dt; coll_args.dst.info_v.mem_type = mt; + if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } } ucc_status_t ucc_pt_coll_gatherv::init_args(size_t count, diff --git a/tools/perf/ucc_pt_coll_reduce.cc b/tools/perf/ucc_pt_coll_reduce.cc index ad013bab67..47610bb68c 100644 --- a/tools/perf/ucc_pt_coll_reduce.cc +++ b/tools/perf/ucc_pt_coll_reduce.cc @@ -11,7 +11,8 @@ #include ucc_pt_coll_reduce::ucc_pt_coll_reduce(ucc_datatype_t dt, ucc_memory_type mt, - ucc_reduction_op_t op, bool is_inplace, int root_shift, + ucc_reduction_op_t op, bool is_inplace, + bool is_persistent, int root_shift, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = true; @@ -20,18 +21,24 @@ ucc_pt_coll_reduce::ucc_pt_coll_reduce(ucc_datatype_t dt, ucc_memory_type mt, has_bw_ = true; root_shift_ = root_shift; - coll_args.coll_type = UCC_COLL_TYPE_REDUCE; - coll_args.mask = 0; + coll_args.mask = 0; + coll_args.flags = 0; + coll_args.coll_type = UCC_COLL_TYPE_REDUCE; + coll_args.op = op; + coll_args.src.info.datatype = dt; + coll_args.src.info.mem_type = mt; + coll_args.dst.info.datatype = dt; + coll_args.dst.info.mem_type = mt; + if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } - coll_args.op = op; - coll_args.src.info.datatype = dt; - coll_args.src.info.mem_type = mt; - coll_args.dst.info.datatype = dt; - coll_args.dst.info.mem_type = mt; + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } } ucc_status_t ucc_pt_coll_reduce::init_args(size_t count, diff --git a/tools/perf/ucc_pt_coll_reduce_scatter.cc b/tools/perf/ucc_pt_coll_reduce_scatter.cc index 8c51a5ffbd..e15bf80bcb 100644 --- a/tools/perf/ucc_pt_coll_reduce_scatter.cc +++ b/tools/perf/ucc_pt_coll_reduce_scatter.cc @@ -12,7 +12,7 @@ ucc_pt_coll_reduce_scatter::ucc_pt_coll_reduce_scatter(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, - bool is_inplace, + bool is_inplace, bool is_persistent, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = true; @@ -21,18 +21,24 @@ ucc_pt_coll_reduce_scatter::ucc_pt_coll_reduce_scatter(ucc_datatype_t dt, has_bw_ = true; root_shift_ = 0; - coll_args.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER; - coll_args.mask = 0; - if (is_inplace) { - coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; - coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } - + coll_args.mask = 0; + coll_args.flags = 0; + coll_args.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER; coll_args.op = op; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; + + if (is_inplace) { + coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; + } + + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } } ucc_status_t ucc_pt_coll_reduce_scatter::init_args(size_t count, diff --git a/tools/perf/ucc_pt_coll_reduce_scatterv.cc b/tools/perf/ucc_pt_coll_reduce_scatterv.cc index 84f55a2132..932ad600d9 100644 --- a/tools/perf/ucc_pt_coll_reduce_scatterv.cc +++ b/tools/perf/ucc_pt_coll_reduce_scatterv.cc @@ -12,7 +12,7 @@ ucc_pt_coll_reduce_scatterv::ucc_pt_coll_reduce_scatterv(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, - bool is_inplace, + bool is_inplace, bool is_persistent, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = true; @@ -21,18 +21,24 @@ ucc_pt_coll_reduce_scatterv::ucc_pt_coll_reduce_scatterv(ucc_datatype_t dt, has_bw_ = false; root_shift_ = 0; - coll_args.coll_type = UCC_COLL_TYPE_REDUCE_SCATTERV; - coll_args.mask = 0; - if (is_inplace) { - coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; - coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } - + coll_args.mask = 0; + coll_args.flags = 0; + coll_args.coll_type = UCC_COLL_TYPE_REDUCE_SCATTERV; coll_args.op = op; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info_v.datatype = dt; coll_args.dst.info_v.mem_type = mt; + + if (is_inplace) { + coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; + } + + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } } ucc_status_t ucc_pt_coll_reduce_scatterv::init_args(size_t count, diff --git a/tools/perf/ucc_pt_coll_scatter.cc b/tools/perf/ucc_pt_coll_scatter.cc index 4d66f51d99..ac414dd2ed 100644 --- a/tools/perf/ucc_pt_coll_scatter.cc +++ b/tools/perf/ucc_pt_coll_scatter.cc @@ -11,7 +11,8 @@ #include ucc_pt_coll_scatter::ucc_pt_coll_scatter(ucc_datatype_t dt, - ucc_memory_type mt, bool is_inplace, int root_shift, + ucc_memory_type mt, bool is_inplace, + bool is_persistent, int root_shift, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = true; @@ -21,15 +22,22 @@ ucc_pt_coll_scatter::ucc_pt_coll_scatter(ucc_datatype_t dt, root_shift_ = root_shift; coll_args.mask = 0; + coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_SCATTER; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; + if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } } ucc_status_t ucc_pt_coll_scatter::init_args(size_t single_rank_count, diff --git a/tools/perf/ucc_pt_coll_scatterv.cc b/tools/perf/ucc_pt_coll_scatterv.cc index 328752022c..1dc9bf7db9 100644 --- a/tools/perf/ucc_pt_coll_scatterv.cc +++ b/tools/perf/ucc_pt_coll_scatterv.cc @@ -11,7 +11,8 @@ #include ucc_pt_coll_scatterv::ucc_pt_coll_scatterv(ucc_datatype_t dt, - ucc_memory_type mt, bool is_inplace, int root_shift, + ucc_memory_type mt, bool is_inplace, + bool is_persistent, int root_shift, ucc_pt_comm *communicator) : ucc_pt_coll(communicator) { has_inplace_ = true; @@ -21,15 +22,22 @@ ucc_pt_coll_scatterv::ucc_pt_coll_scatterv(ucc_datatype_t dt, root_shift_ = root_shift; coll_args.mask = 0; + coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_SCATTERV; coll_args.src.info_v.datatype = dt; coll_args.src.info_v.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; + if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + + if (is_persistent) { + coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } } ucc_status_t ucc_pt_coll_scatterv::init_args(size_t count, diff --git a/tools/perf/ucc_pt_config.cc b/tools/perf/ucc_pt_config.cc index 3fcb2b01c9..e59b62ce26 100644 --- a/tools/perf/ucc_pt_config.cc +++ b/tools/perf/ucc_pt_config.cc @@ -18,6 +18,7 @@ ucc_pt_config::ucc_pt_config() { bench.mt = UCC_MEMORY_TYPE_HOST; bench.op = UCC_OP_SUM; bench.inplace = false; + bench.persistent = false; bench.triggered = false; bench.n_iter_small = 1000; bench.n_warmup_small = 100; @@ -89,7 +90,7 @@ ucc_status_t ucc_pt_config::process_args(int argc, char *argv[]) int c; ucc_status_t st; - while ((c = getopt(argc, argv, "c:b:e:d:m:n:w:o:N:r:S:ihFT")) != -1) { + while ((c = getopt(argc, argv, "c:b:e:d:m:n:w:o:N:r:S:iphFT")) != -1) { switch (c) { case 'c': if (ucc_pt_op_map.count(optarg) == 0) { @@ -158,6 +159,9 @@ ucc_status_t ucc_pt_config::process_args(int argc, char *argv[]) case 'i': bench.inplace = true; break; + case 'p': + bench.persistent = true; + break; case 'T': bench.triggered = true; break; @@ -180,6 +184,7 @@ void ucc_pt_config::print_help() std::cout << " -b : Min number of elements"<: Max number of elements"<: datatype"<: reduction operation type"<: root for rooted collectives"<