Skip to content

Commit

Permalink
Add noncontig allgatherv tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarkauskas committed Nov 26, 2024
1 parent cf77c01 commit 4074a5b
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 16 deletions.
6 changes: 6 additions & 0 deletions src/components/cl/hier/allgatherv/allgatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
ucc_count_t *leader_counts = NULL;
size_t dt_size = ucc_dt_size(coll_args->args.dst.info_v.datatype);
int in_place = 0;
int is_contig = 0;
ucc_schedule_t *schedule;
ucc_cl_hier_schedule_t *cl_schedule;
ucc_status_t status;
Expand All @@ -128,6 +129,7 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
memcpy(&args, coll_args, sizeof(args));
memcpy(&args_old, coll_args, sizeof(args));
in_place = UCC_IS_INPLACE(args.args);
is_contig = UCC_COLL_IS_DST_CONTIG(&args.args);
n_tasks = 0;
UCC_CHECK_GOTO(ucc_schedule_init(schedule, &args, team), out, status);

Expand Down Expand Up @@ -245,6 +247,10 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
n_tasks++;
}

if (!is_contig) {
printf("not contig, scheduling unpack operation\n");
}

UCC_CHECK_GOTO(ucc_event_manager_subscribe(
&schedule->super, UCC_EVENT_SCHEDULE_STARTED, tasks[0],
ucc_task_start_handler),
Expand Down
62 changes: 46 additions & 16 deletions test/gtest/coll/test_allgatherv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
#include "common/test_ucc.h"
#include "utils/ucc_math.h"

using Param_0 = std::tuple<int, ucc_datatype_t, ucc_memory_type_t, int, gtest_ucc_inplace_t>;
using Param_1 = std::tuple<ucc_datatype_t, ucc_memory_type_t, int, gtest_ucc_inplace_t>;
using Param_2 = std::tuple<ucc_datatype_t, ucc_memory_type_t, int, gtest_ucc_inplace_t, std::string>;
using Param_0 = std::tuple<int, ucc_datatype_t, ucc_memory_type_t, int, gtest_ucc_inplace_t, bool>;
using Param_1 = std::tuple<ucc_datatype_t, ucc_memory_type_t, int, gtest_ucc_inplace_t, bool>;
using Param_2 = std::tuple<ucc_datatype_t, ucc_memory_type_t, int, gtest_ucc_inplace_t, std::string, bool>;

size_t noncontig_padding = 1; // # elements worth of space in between each rank's contribution to the dst buf

class test_allgatherv : public UccCollArgs, public ucc::test
{
Expand All @@ -21,7 +23,8 @@ class test_allgatherv : public UccCollArgs, public ucc::test
int *counts;
int *displs;
size_t my_count = (nprocs - r) * count;
size_t all_counts = 0;
size_t disp_counter = 0;
size_t noncontig_total_padding = noncontig_padding * nprocs;
ucc_coll_args_t *coll = (ucc_coll_args_t*)calloc(1, sizeof(ucc_coll_args_t));

ctxs[r] = (gtest_ucc_coll_ctx_t*)calloc(1, sizeof(gtest_ucc_coll_ctx_t));
Expand All @@ -30,13 +33,21 @@ class test_allgatherv : public UccCollArgs, public ucc::test
counts = (int*)malloc(sizeof(int) * nprocs);
displs = (int*)malloc(sizeof(int) * nprocs);

for (int i = 0; i < nprocs; i++) {
counts[i] = (nprocs - i) * count;
displs[i] = all_counts;
all_counts += counts[i];
if (is_contig) {
for (int i = 0; i < nprocs; i++) {
counts[i] = (nprocs - i) * count;
displs[i] = disp_counter;
disp_counter += counts[i];
}
coll->flags = UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER;
} else {
for (int i = 0; i < nprocs; i++) {
counts[i] = (nprocs - i) * count;
displs[i] = disp_counter;
disp_counter += counts[i] + noncontig_padding; // Add noncontig_padding elemnts of space between the bufs
}
}
coll->mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll->flags = UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER;
coll->coll_type = UCC_COLL_TYPE_ALLGATHERV;

coll->src.info.mem_type = mem_type;
Expand All @@ -48,14 +59,14 @@ class test_allgatherv : public UccCollArgs, public ucc::test
coll->dst.info_v.displacements = (ucc_aint_t*)displs;
coll->dst.info_v.datatype = dtype;

ctxs[r]->init_buf = ucc_malloc(ucc_dt_size(dtype) * my_count, "init buf");
ctxs[r]->init_buf = ucc_malloc(ucc_dt_size(dtype) * (my_count + noncontig_total_padding), "init buf");
EXPECT_NE(ctxs[r]->init_buf, nullptr);
for (int i = 0; i < (ucc_dt_size(dtype) * my_count); i++) {
uint8_t *sbuf = (uint8_t*)ctxs[r]->init_buf;
sbuf[i] = r;
}

ctxs[r]->rbuf_size = ucc_dt_size(dtype) * all_counts;
ctxs[r]->rbuf_size = ucc_dt_size(dtype) * disp_counter;
UCC_CHECK(ucc_mc_alloc(&ctxs[r]->dst_mc_header, ctxs[r]->rbuf_size,
mem_type));
coll->dst.info_v.buffer = ctxs[r]->dst_mc_header->addr;
Expand Down Expand Up @@ -138,6 +149,7 @@ class test_allgatherv : public UccCollArgs, public ucc::test
for (int i = 0; i < ctxs.size(); i++) {
size_t rank_size = 0;
uint8_t *rbuf = dsts[i];
int is_contig = UCC_COLL_IS_DST_CONTIG(ctxs[i]->args);
for (int r = 0; r < ctxs.size(); r++) {
rbuf += rank_size;
rank_size = ucc_dt_size((ctxs[r])->args->src.info.datatype) *
Expand All @@ -148,6 +160,9 @@ class test_allgatherv : public UccCollArgs, public ucc::test
break;
}
}
if (!is_contig) {
rbuf += noncontig_padding * ucc_dt_size((ctxs[r])->args->src.info.datatype);
}
}
}
if (UCC_MEMORY_TYPE_HOST != mem_type) {
Expand All @@ -169,11 +184,13 @@ UCC_TEST_P(test_allgatherv_0, single)
const ucc_memory_type_t mem_type = std::get<2>(GetParam());
const int count = std::get<3>(GetParam());
const gtest_ucc_inplace_t inplace = std::get<4>(GetParam());
const bool contig = std::get<5>(GetParam());
UccTeam_h team = UccJob::getStaticTeams()[team_id];
int size = team->procs.size();
UccCollCtxVec ctxs;

set_inplace(inplace);
set_contig(contig);
SET_MEM_TYPE(mem_type);

data_init(size, dtype, count, ctxs, false);
Expand All @@ -191,12 +208,14 @@ UCC_TEST_P(test_allgatherv_0, single_persistent)
const ucc_memory_type_t mem_type = std::get<2>(GetParam());
const int count = std::get<3>(GetParam());
const gtest_ucc_inplace_t inplace = std::get<4>(GetParam());
const bool contig = std::get<5>(GetParam());
UccTeam_h team = UccJob::getStaticTeams()[team_id];
int size = team->procs.size();
const int n_calls = 3;
UccCollCtxVec ctxs;

set_inplace(inplace);
set_contig(contig);
SET_MEM_TYPE(mem_type);

data_init(size, dtype, count, ctxs, true);
Expand All @@ -223,7 +242,9 @@ INSTANTIATE_TEST_CASE_P(
::testing::Values(UCC_MEMORY_TYPE_HOST),
#endif
::testing::Values(1,3,8192), // count
::testing::Values(TEST_INPLACE, TEST_NO_INPLACE))); // inplace
::testing::Values(TEST_INPLACE, TEST_NO_INPLACE), // inplace
::testing::Values(false, true) // contig dst buf displacements
));

class test_allgatherv_1 : public test_allgatherv,
public ::testing::WithParamInterface<Param_1> {};
Expand All @@ -234,6 +255,7 @@ UCC_TEST_P(test_allgatherv_1, multiple)
const ucc_memory_type_t mem_type = std::get<1>(GetParam());
const int count = std::get<2>(GetParam());
const gtest_ucc_inplace_t inplace = std::get<3>(GetParam());
const bool contig = std::get<4>(GetParam());
std::vector<UccReq> reqs;
std::vector<UccCollCtxVec> ctxs;

Expand All @@ -243,6 +265,7 @@ UCC_TEST_P(test_allgatherv_1, multiple)
UccCollCtxVec ctx;

this->set_inplace(inplace);
this->set_contig(contig);
SET_MEM_TYPE(mem_type);

data_init(size, dtype, count, ctx, false);
Expand All @@ -269,7 +292,9 @@ INSTANTIATE_TEST_CASE_P(
::testing::Values(UCC_MEMORY_TYPE_HOST),
#endif
::testing::Values(1,3,8192), // count
::testing::Values(TEST_INPLACE, TEST_NO_INPLACE)));
::testing::Values(TEST_INPLACE, TEST_NO_INPLACE),
::testing::Values(false, true)) // dst buf contig
);

class test_allgatherv_alg : public test_allgatherv,
public ::testing::WithParamInterface<Param_2> {};
Expand All @@ -280,6 +305,7 @@ UCC_TEST_P(test_allgatherv_alg, alg)
const ucc_memory_type_t mem_type = std::get<1>(GetParam());
const int count = std::get<2>(GetParam());
const gtest_ucc_inplace_t inplace = std::get<3>(GetParam());
const bool contig = std::get<5>(GetParam());
int n_procs = 5;
char tune[32];

Expand All @@ -291,13 +317,14 @@ UCC_TEST_P(test_allgatherv_alg, alg)
UccCollCtxVec ctxs;

set_inplace(inplace);
set_contig(contig);
SET_MEM_TYPE(mem_type);

data_init(n_procs, dtype, count, ctxs, false);
UccReq req(team, ctxs);
req.start();
req.wait();
EXPECT_EQ(true, data_validate(ctxs));;
EXPECT_EQ(true, data_validate(ctxs));
data_fini(ctxs);
}

Expand All @@ -313,13 +340,16 @@ INSTANTIATE_TEST_CASE_P(
#endif
::testing::Values(1,3,8192), // count
::testing::Values(TEST_INPLACE, TEST_NO_INPLACE),
::testing::Values("knomial", "ring")),
::testing::Values("knomial", "ring"),
::testing::Values(false, true)), // dst buf contig
[](const testing::TestParamInfo<test_allgatherv_alg::ParamType>& info) {
std::string name;
name += ucc_datatype_str(std::get<0>(info.param));
name += std::string("_") + std::string(ucc_mem_type_str(std::get<1>(info.param)));
name += std::string("_count_")+std::to_string(std::get<2>(info.param));
name += std::string("_inplace_")+std::to_string(std::get<3>(info.param));
name += std::string("_contig_")+std::to_string(std::get<5>(info.param));
name += std::string("_")+std::get<4>(info.param);
return name;
});
}
);
5 changes: 5 additions & 0 deletions test/gtest/common/test_ucc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,11 @@ void UccCollArgs::set_inplace(gtest_ucc_inplace_t _inplace)
inplace = _inplace;
}

void UccCollArgs::set_contig(bool _is_contig)
{
is_contig = _is_contig;
}

void clear_buffer(void *_buf, size_t size, ucc_memory_type_t mt, uint8_t value)
{
void *buf = _buf;
Expand Down
2 changes: 2 additions & 0 deletions test/gtest/common/test_ucc.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class UccCollArgs {
protected:
ucc_memory_type_t mem_type;
gtest_ucc_inplace_t inplace;
bool is_contig;
void alltoallx_init_buf(int src_rank, int dst_rank, uint8_t *buf, size_t len)
{
for (int i = 0; i < len; i++) {
Expand Down Expand Up @@ -74,6 +75,7 @@ class UccCollArgs {
virtual bool data_validate(UccCollCtxVec args) = 0;
void set_mem_type(ucc_memory_type_t _mt);
void set_inplace(gtest_ucc_inplace_t _inplace);
void set_contig(bool _contig);
};

#define SET_MEM_TYPE(_mt) do { \
Expand Down

0 comments on commit 4074a5b

Please sign in to comment.