Skip to content

Commit

Permalink
TEST: do local check in mpi tests (#789)
Browse files Browse the repository at this point in the history
* TEST: do local check in mpi tests

* REVIEW: fix review comments
  • Loading branch information
Sergei-Lebedev authored Nov 1, 2023
1 parent 483b91b commit 7aefb67
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 34 deletions.
6 changes: 4 additions & 2 deletions test/mpi/buffer.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -25,7 +25,7 @@ void init_buffer_host(void *buf, size_t count, int _value)
}

void init_buffer(void *_buf, size_t count, ucc_datatype_t dt,
ucc_memory_type_t mt, int value)
ucc_memory_type_t mt, int value, int offset)
{
void *buf = NULL;
if (mt == UCC_MEMORY_TYPE_CUDA || mt == UCC_MEMORY_TYPE_ROCM) {
Expand All @@ -37,6 +37,8 @@ void init_buffer(void *_buf, size_t count, ucc_datatype_t dt,
std::cerr << "Unsupported mt\n";
MPI_Abort(MPI_COMM_WORLD, -1);
}

value += offset;
switch(dt) {
case UCC_DT_INT8:
init_buffer_host<int8_t>(buf, count, value);
Expand Down
21 changes: 10 additions & 11 deletions test/mpi/test_allgather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ ucc_status_t TestAllgather::set_input(int iter_persistent)
int rank;
void *buf, *check;

this->iter_persistent = iter_persistent;
MPI_Comm_rank(team.comm, &rank);
if (inplace) {
buf = PTR_OFFSET(rbuf, rank * single_rank_size);
Expand All @@ -70,18 +71,16 @@ ucc_status_t TestAllgather::set_input(int iter_persistent)

ucc_status_t TestAllgather::check()
{
int size, completed;
MPI_Comm_size(team.comm, &size);
size_t single_rank_count = args.dst.info.count / size;
MPI_Datatype mpi_dt = ucc_dt_to_mpi(dt);
MPI_Request req;
size_t dt_size, single_rank_count;
int size, i;

MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, check_buf,
single_rank_count, mpi_dt, team.comm, &req);
do {
MPI_Test(&req, &completed, MPI_STATUS_IGNORE);
ucc_context_progress(team.ctx);
} while(!completed);
MPI_Comm_size(team.comm, &size);
single_rank_count = args.dst.info.count / size;
dt_size = ucc_dt_size(dt);
for (i = 0; i < size; i++) {
init_buffer(PTR_OFFSET(check_buf, i * single_rank_count * dt_size),
single_rank_count, dt, mem_type, i * (iter_persistent + 1));
}

return compare_buffers(rbuf, check_buf, single_rank_count * size, dt,
mem_type);
Expand Down
18 changes: 9 additions & 9 deletions test/mpi/test_alltoall.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ ucc_status_t TestAlltoall::set_input(int iter_persistent)
void * buf;
int rank, nprocs, completed;

this->iter_persistent = iter_persistent;
MPI_Comm_rank(team.comm, &rank);
MPI_Comm_size(team.comm, &nprocs);
if (inplace) {
Expand All @@ -99,19 +100,18 @@ ucc_status_t TestAlltoall::set_input(int iter_persistent)

ucc_status_t TestAlltoall::check()
{
int size, completed;
size_t single_rank_count;
MPI_Request req;
int size, rank, i;
size_t single_rank_count;

MPI_Comm_rank(team.comm, &rank);
MPI_Comm_size(team.comm, &size);
single_rank_count = args.src.info.count / size;

MPI_Ialltoall(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, check_buf,
single_rank_count, ucc_dt_to_mpi(dt), team.comm, &req);
do {
MPI_Test(&req, &completed, MPI_STATUS_IGNORE);
ucc_context_progress(team.ctx);
} while(!completed);
for ( i = 0; i < size; i++) {
init_buffer(PTR_OFFSET(check_buf, i * single_rank_count * ucc_dt_size(dt)),
single_rank_count, dt, mem_type, i * (iter_persistent + 1),
single_rank_count * rank);
}

return compare_buffers(rbuf, check_buf, single_rank_count * size, dt,
mem_type);
Expand Down
15 changes: 5 additions & 10 deletions test/mpi/test_bcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ ucc_status_t TestBcast::set_input(int iter_persistent)
size_t count = msgsize / dt_size;
int rank;

this->iter_persistent = iter_persistent;
MPI_Comm_rank(team.comm, &rank);
if (rank == root) {
init_buffer(sbuf, count, dt, mem_type, rank * (iter_persistent + 1));
Expand All @@ -56,18 +57,12 @@ ucc_status_t TestBcast::set_input(int iter_persistent)

ucc_status_t TestBcast::check()
{
size_t count = args.src.info.count;
MPI_Datatype mpi_dt = ucc_dt_to_mpi(dt);
int rank, completed;
MPI_Request req;
size_t count = args.src.info.count;
int rank;

MPI_Comm_rank(team.comm, &rank);
MPI_Ibcast(check_buf, count, mpi_dt, root, team.comm, &req);
do {
MPI_Test(&req, &completed, MPI_STATUS_IGNORE);
ucc_context_progress(team.ctx);
} while(!completed);

init_buffer(check_buf, count, dt, UCC_MEMORY_TYPE_HOST,
root * (iter_persistent + 1));
return (rank == root)
? UCC_OK
: compare_buffers(sbuf, check_buf, count, dt, mem_type);
Expand Down
4 changes: 2 additions & 2 deletions test/mpi/test_mpi.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ class TestCase {
uint8_t progress_buf[1];
size_t test_max_size;
ucc_datatype_t dt;

int iter_persistent;
public:
void mpi_progress(void);
test_skip_cause_t test_skip;
Expand Down Expand Up @@ -523,7 +523,7 @@ class TestScatterv : public TestCase {
};

void init_buffer(void *buf, size_t count, ucc_datatype_t dt,
ucc_memory_type_t mt, int value);
ucc_memory_type_t mt, int value, int offset = 0);

ucc_status_t compare_buffers(void *rst, void *expected, size_t count,
ucc_datatype_t dt, ucc_memory_type_t mt);
Expand Down

0 comments on commit 7aefb67

Please sign in to comment.