Skip to content

Commit

Permalink
Add helper test
Browse files Browse the repository at this point in the history
Signed-off-by: Jiang, Zhiwei <[email protected]>
  • Loading branch information
zhiweij1 committed Jun 27, 2024
1 parent 3c6b32a commit 0a12c0d
Show file tree
Hide file tree
Showing 2 changed files with 398 additions and 0 deletions.
199 changes: 199 additions & 0 deletions help_function/src/sparse_utils_2_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1967,6 +1967,204 @@ void test_cusparseSpSV() {
}
}

// A * B = C
//
// | 0 1 2 | | 1 0 0 0 | | 2 3 10 12 |
// | 0 0 3 | * | 2 3 0 0 | = | 0 0 15 18 |
// | 4 0 0 | | 0 0 5 6 | | 4 0 0 0 |
void test_cusparseTcsrgemm() {
dpct::device_ext &dev_ct1 = dpct::get_current_device();
sycl::queue &q_ct1 = dev_ct1.out_of_order_queue();
std::vector<float> a_val_vec = {1, 2, 3, 4};
Data<float> a_s_val(a_val_vec.data(), 4);
Data<double> a_d_val(a_val_vec.data(), 4);
Data<sycl::float2> a_c_val(a_val_vec.data(), 4);
Data<sycl::double2> a_z_val(a_val_vec.data(), 4);
std::vector<float> a_row_ptr_vec = {0, 2, 3, 4};
Data<int> a_row_ptr(a_row_ptr_vec.data(), 4);
std::vector<float> a_col_ind_vec = {1, 2, 2, 0};
Data<int> a_col_ind(a_col_ind_vec.data(), 4);

std::vector<float> b_val_vec = {1, 2, 3, 5, 6};
Data<float> b_s_val(b_val_vec.data(), 5);
Data<double> b_d_val(b_val_vec.data(), 5);
Data<sycl::float2> b_c_val(b_val_vec.data(), 5);
Data<sycl::double2> b_z_val(b_val_vec.data(), 5);
std::vector<float> b_row_ptr_vec = {0, 1, 3, 5};
Data<int> b_row_ptr(b_row_ptr_vec.data(), 4);
std::vector<float> b_col_ind_vec = {0, 0, 1, 2, 3};
Data<int> b_col_ind(b_col_ind_vec.data(), 5);

float alpha = 1;
Data<float> alpha_s(&alpha, 1);
Data<double> alpha_d(&alpha, 1);
Data<sycl::float2> alpha_c(&alpha, 1);
Data<sycl::double2> alpha_z(&alpha, 1);

float beta = 0;
Data<float> beta_s(&beta, 1);
Data<double> beta_d(&beta, 1);
Data<sycl::float2> beta_c(&beta, 1);
Data<sycl::double2> beta_z(&beta, 1);

dpct::sparse::descriptor_ptr handle;
handle = new dpct::sparse::descriptor();

/*
DPCT1026:38: The call to cusparseSetPointerMode was removed because this
functionality is redundant in SYCL.
*/

a_s_val.H2D();
a_d_val.H2D();
a_c_val.H2D();
a_z_val.H2D();
a_row_ptr.H2D();
a_col_ind.H2D();
b_s_val.H2D();
b_d_val.H2D();
b_c_val.H2D();
b_z_val.H2D();
b_row_ptr.H2D();
b_col_ind.H2D();
alpha_s.H2D();
alpha_d.H2D();
alpha_c.H2D();
alpha_z.H2D();
beta_s.H2D();
beta_d.H2D();
beta_c.H2D();
beta_z.H2D();

Data<int> c_s_row_ptr(4);
Data<int> c_d_row_ptr(4);
Data<int> c_c_row_ptr(4);
Data<int> c_z_row_ptr(4);

std::shared_ptr<dpct::sparse::matrix_info> descrA;
std::shared_ptr<dpct::sparse::matrix_info> descrB;
std::shared_ptr<dpct::sparse::matrix_info> descrC;
descrA = std::make_shared<dpct::sparse::matrix_info>();
descrB = std::make_shared<dpct::sparse::matrix_info>();
descrC = std::make_shared<dpct::sparse::matrix_info>();
descrA->set_index_base(oneapi::mkl::index_base::zero);
descrB->set_index_base(oneapi::mkl::index_base::zero);
descrC->set_index_base(oneapi::mkl::index_base::zero);
descrA->set_matrix_type(dpct::sparse::matrix_info::matrix_type::ge);
descrB->set_matrix_type(dpct::sparse::matrix_info::matrix_type::ge);
descrC->set_matrix_type(dpct::sparse::matrix_info::matrix_type::ge);

int c_nnz_s;
int c_nnz_d;
int c_nnz_c;
int c_nnz_z;
dpct::sparse::csrgemm_nnz(
handle, oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans, 3, 3, 4, descrA, 4, a_s_val.d_data,
a_row_ptr.d_data, a_col_ind.d_data, descrB, 5, b_s_val.d_data,
b_row_ptr.d_data, b_col_ind.d_data, descrC, c_s_row_ptr.d_data, &c_nnz_s);
dpct::sparse::csrgemm_nnz(
handle, oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans, 3, 3, 4, descrA, 4, a_d_val.d_data,
a_row_ptr.d_data, a_col_ind.d_data, descrB, 5, b_d_val.d_data,
b_row_ptr.d_data, b_col_ind.d_data, descrC, c_d_row_ptr.d_data, &c_nnz_d);
dpct::sparse::csrgemm_nnz(
handle, oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans, 3, 3, 4, descrA, 4, a_c_val.d_data,
a_row_ptr.d_data, a_col_ind.d_data, descrB, 5, b_c_val.d_data,
b_row_ptr.d_data, b_col_ind.d_data, descrC, c_c_row_ptr.d_data, &c_nnz_c);
dpct::sparse::csrgemm_nnz(
handle, oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans, 3, 3, 4, descrA, 4, a_z_val.d_data,
a_row_ptr.d_data, a_col_ind.d_data, descrB, 5, b_z_val.d_data,
b_row_ptr.d_data, b_col_ind.d_data, descrC, c_z_row_ptr.d_data, &c_nnz_z);

Data<float> c_s_val(c_nnz_s);
Data<double> c_d_val(c_nnz_d);
Data<sycl::float2> c_c_val(c_nnz_c);
Data<sycl::double2> c_z_val(c_nnz_z);
Data<int> c_s_col_ind(c_nnz_s);
Data<int> c_d_col_ind(c_nnz_d);
Data<int> c_c_col_ind(c_nnz_c);
Data<int> c_z_col_ind(c_nnz_z);

dpct::sparse::csrgemm(handle, oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans, 3, 3, 4, descrA,
a_s_val.d_data, a_row_ptr.d_data, a_col_ind.d_data,
descrB, b_s_val.d_data, b_row_ptr.d_data,
b_col_ind.d_data, descrC, c_s_val.d_data,
c_s_row_ptr.d_data, c_s_col_ind.d_data);
dpct::sparse::csrgemm(handle, oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans, 3, 3, 4, descrA,
a_d_val.d_data, a_row_ptr.d_data, a_col_ind.d_data,
descrB, b_d_val.d_data, b_row_ptr.d_data,
b_col_ind.d_data, descrC, c_d_val.d_data,
c_d_row_ptr.d_data, c_d_col_ind.d_data);
dpct::sparse::csrgemm(handle, oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans, 3, 3, 4, descrA,
a_c_val.d_data, a_row_ptr.d_data, a_col_ind.d_data,
descrB, b_c_val.d_data, b_row_ptr.d_data,
b_col_ind.d_data, descrC, c_c_val.d_data,
c_c_row_ptr.d_data, c_c_col_ind.d_data);
dpct::sparse::csrgemm(handle, oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans, 3, 3, 4, descrA,
a_z_val.d_data, a_row_ptr.d_data, a_col_ind.d_data,
descrB, b_z_val.d_data, b_row_ptr.d_data,
b_col_ind.d_data, descrC, c_z_val.d_data,
c_z_row_ptr.d_data, c_z_col_ind.d_data);

q_ct1.wait();

/*
DPCT1026:39: The call to cusparseDestroyMatDescr was removed because this
functionality is redundant in SYCL.
*/
/*
DPCT1026:40: The call to cusparseDestroyMatDescr was removed because this
functionality is redundant in SYCL.
*/
/*
DPCT1026:41: The call to cusparseDestroyMatDescr was removed because this
functionality is redundant in SYCL.
*/
delete (handle);

c_s_val.D2H();
c_d_val.D2H();
c_c_val.D2H();
c_z_val.D2H();
c_s_row_ptr.D2H();
c_d_row_ptr.D2H();
c_c_row_ptr.D2H();
c_z_row_ptr.D2H();
c_s_col_ind.D2H();
c_d_col_ind.D2H();
c_c_col_ind.D2H();
c_z_col_ind.D2H();

float expect_c_val[7] = {2.000000, 3.000000, 10.000000, 12.000000, 15.000000, 18.000000, 4.000000};
float expect_c_row_ptr[4] = {0.000000, 4.000000, 6.000000, 7.000000};
float expect_c_col_ind[7] = {0.000000, 1.000000, 2.000000, 3.000000, 2.000000, 3.000000, 0.000000};
if (compare_result(expect_c_val, c_s_val.h_data, 7) &&
compare_result(expect_c_val, c_d_val.h_data, 7) &&
compare_result(expect_c_val, c_c_val.h_data, 7) &&
compare_result(expect_c_val, c_z_val.h_data, 7) &&
compare_result(expect_c_row_ptr, c_s_row_ptr.h_data, 4) &&
compare_result(expect_c_row_ptr, c_d_row_ptr.h_data, 4) &&
compare_result(expect_c_row_ptr, c_c_row_ptr.h_data, 4) &&
compare_result(expect_c_row_ptr, c_z_row_ptr.h_data, 4) &&
compare_result(expect_c_col_ind, c_s_col_ind.h_data, 7) &&
compare_result(expect_c_col_ind, c_d_col_ind.h_data, 7) &&
compare_result(expect_c_col_ind, c_c_col_ind.h_data, 7) &&
compare_result(expect_c_col_ind, c_z_col_ind.h_data, 7)
)
printf("Tcsrgemm pass\n");
else {
printf("Tcsrgemm fail\n");
test_passed = false;
}
}

int main() {
test_cusparseSetGetStream();
test_cusparseTcsrmv_ge();
Expand All @@ -1979,6 +2177,7 @@ int main() {
test_cusparseCsrmvEx();
test_cusparseSpGEMM();
test_cusparseSpSV();
test_cusparseTcsrgemm();

if (test_passed)
return 0;
Expand Down
Loading

0 comments on commit 0a12c0d

Please sign in to comment.