Skip to content

Commit

Permalink
MC/CUDA: add uint16, uint32, and uint64 support in reduce (#565)
Browse files Browse the repository at this point in the history
* MC/CUDA: add uint16, uint32, and uint64 support in reduce

* TEST: add CUDA reduce gtest with uint16, uint32, and uint64 datatypes

* TEST: add reduce MPI tests with uint16, uint32, and uint64 datatypes

Co-authored-by: valentin petrov <[email protected]>
  • Loading branch information
samnordmann and valentin petrov authored Jul 20, 2022
1 parent 786275c commit 5cf5815
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 4 deletions.
9 changes: 9 additions & 0 deletions src/components/mc/cuda/kernel/mc_cuda_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ ucc_status_t ucc_mc_cuda_reduce(const void *src1, const void *src2, void *dst,
case UCC_DT_INT64:
DT_REDUCE_INT(int64_t, op, src1, src2, dst, count, stream, bk, th);
break;
case UCC_DT_UINT16:
DT_REDUCE_INT(uint16_t, op, src1, src2, dst, count, stream, bk, th);
break;
case UCC_DT_UINT32:
DT_REDUCE_INT(uint32_t, op, src1, src2, dst, count, stream, bk, th);
break;
case UCC_DT_UINT64:
DT_REDUCE_INT(uint64_t, op, src1, src2, dst, count, stream, bk, th);
break;
case UCC_DT_FLOAT16:
DT_REDUCE_FLOAT(__half, op, src1, src2, dst, count, stream, bk, th);
break;
Expand Down
12 changes: 12 additions & 0 deletions src/components/mc/cuda/kernel/mc_cuda_reduce_multi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ ucc_status_t ucc_mc_cuda_reduce_multi(const void *src1, const void *src2,
DT_REDUCE_INT(int64_t, op, src1, src2, dst, n_vectors, count, ld,
stream, bk, th);
break;
case UCC_DT_UINT16:
DT_REDUCE_INT(uint16_t, op, src1, src2, dst, n_vectors, count, ld,
stream, bk, th);
break;
case UCC_DT_UINT32:
DT_REDUCE_INT(uint32_t, op, src1, src2, dst, n_vectors, count, ld,
stream, bk, th);
break;
case UCC_DT_UINT64:
DT_REDUCE_INT(uint64_t, op, src1, src2, dst, n_vectors, count, ld,
stream, bk, th);
break;
case UCC_DT_FLOAT16:
DT_REDUCE_FLOAT(__half, op, src1, src2, dst, n_vectors, count, ld,
stream, bk, th);
Expand Down
2 changes: 2 additions & 0 deletions test/gtest/core/test_mc_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ DECLARE_OP_(avg, AVG, SUM);
#define CUDA_OP_PAIRS \
TypeOpPair<UCC_DT_INT16, lor>, TypeOpPair<UCC_DT_INT16, sum>, \
TypeOpPair<UCC_DT_INT32, prod>, TypeOpPair<UCC_DT_INT64, bxor>, \
TypeOpPair<UCC_DT_UINT16, lor>, TypeOpPair<UCC_DT_UINT16, sum>, \
TypeOpPair<UCC_DT_UINT32, prod>, TypeOpPair<UCC_DT_UINT64, bxor>, \
TypeOpPair<UCC_DT_FLOAT32, avg>, TypeOpPair<UCC_DT_FLOAT64, avg>, \
ARITHMETIC_OP_PAIRS(INT32), ARITHMETIC_OP_PAIRS(FLOAT32), \
ARITHMETIC_OP_PAIRS(FLOAT64), ARITHMETIC_OP_PAIRS(BFLOAT16), \
Expand Down
7 changes: 4 additions & 3 deletions test/mpi/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ static std::vector<ucc_coll_type_t> colls = {
UCC_COLL_TYPE_SCATTER, UCC_COLL_TYPE_SCATTERV};
static std::vector<ucc_coll_type_t> onesided_colls = {UCC_COLL_TYPE_ALLTOALL};
static std::vector<ucc_memory_type_t> mtypes = {UCC_MEMORY_TYPE_HOST};
static std::vector<ucc_datatype_t> dtypes = {UCC_DT_INT32, UCC_DT_INT64,
UCC_DT_FLOAT32, UCC_DT_FLOAT64,
UCC_DT_FLOAT64_COMPLEX};
static std::vector<ucc_datatype_t> dtypes = {
UCC_DT_INT16, UCC_DT_INT32, UCC_DT_INT64,
UCC_DT_UINT16, UCC_DT_UINT32, UCC_DT_UINT64,
UCC_DT_FLOAT32, UCC_DT_FLOAT64, UCC_DT_FLOAT64_COMPLEX};
static std::vector<ucc_reduction_op_t> ops = {UCC_OP_SUM, UCC_OP_MAX,
UCC_OP_AVG};
static std::vector<ucc_test_mpi_team_t> teams = {TEAM_WORLD, TEAM_REVERSE,
Expand Down
4 changes: 3 additions & 1 deletion test/mpi/test_mpi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ UccTestMpi::UccTestMpi(int argc, char *argv[], ucc_thread_mode_t _tm,
onesided_ctx = nullptr;
}
set_msgsizes(8, ((1ULL) << 21), 8);
dtypes = {UCC_DT_INT32, UCC_DT_INT64,
dtypes = {UCC_DT_INT16, UCC_DT_INT32,
UCC_DT_INT64, UCC_DT_UINT16,
UCC_DT_UINT32, UCC_DT_UINT64,
UCC_DT_FLOAT32, UCC_DT_FLOAT64,
UCC_DT_FLOAT128, UCC_DT_FLOAT32_COMPLEX,
UCC_DT_FLOAT64_COMPLEX, UCC_DT_FLOAT128_COMPLEX};
Expand Down

0 comments on commit 5cf5815

Please sign in to comment.