diff --git a/paddle/fluid/platform/dynload/cusparse.h b/paddle/fluid/platform/dynload/cusparse.h
index 74f9b973a388c..aa6f63ba75c5e 100644
--- a/paddle/fluid/platform/dynload/cusparse.h
+++ b/paddle/fluid/platform/dynload/cusparse.h
@@ -30,26 +30,33 @@ namespace dynload {
 
 #if defined(PADDLE_WITH_CUDA)
 #if CUDA_VERSION >= 11000
-#define CUSPARSE_ROUTINE_EACH(__macro) \
-  __macro(cusparseCreate);              \
-  __macro(cusparseSetStream);           \
-  __macro(cusparseCreateMatDescr);      \
-  __macro(cusparseDestroy);             \
-  __macro(cusparseSnnz);                \
-  __macro(cusparseDnnz);                \
-  __macro(cusparseSetMatType);          \
-  __macro(cusparseSetMatIndexBase);     \
-  __macro(cusparseCreateCsr);           \
-  __macro(cusparseCreateCoo);           \
-  __macro(cusparseCreateDnMat);         \
-  __macro(cusparseCreateDnVec);         \
-  __macro(cusparseSpMM_bufferSize);     \
-  __macro(cusparseSpMM);                \
-  __macro(cusparseDestroySpMat);        \
-  __macro(cusparseDestroyDnMat);        \
-  __macro(cusparseDestroyDnVec);        \
-  __macro(cusparseSpMV_bufferSize);     \
-  __macro(cusparseSpMV);
+#define CUSPARSE_ROUTINE_EACH(__macro)     \
+  __macro(cusparseCreate);                  \
+  __macro(cusparseSetStream);               \
+  __macro(cusparseCreateMatDescr);          \
+  __macro(cusparseDestroy);                 \
+  __macro(cusparseSnnz);                    \
+  __macro(cusparseDnnz);                    \
+  __macro(cusparseSetMatType);              \
+  __macro(cusparseSetMatIndexBase);         \
+  __macro(cusparseCreateCsr);               \
+  __macro(cusparseCreateCoo);               \
+  __macro(cusparseCreateDnMat);             \
+  __macro(cusparseCreateDnVec);             \
+  __macro(cusparseSpMM_bufferSize);         \
+  __macro(cusparseSpMM);                    \
+  __macro(cusparseDestroySpMat);            \
+  __macro(cusparseDestroyDnMat);            \
+  __macro(cusparseDestroyDnVec);            \
+  __macro(cusparseSpMV_bufferSize);         \
+  __macro(cusparseSpMV);                    \
+  __macro(cusparseSpGEMM_createDescr);      \
+  __macro(cusparseSpGEMM_workEstimation);   \
+  __macro(cusparseSpGEMM_compute);          \
+  __macro(cusparseSpGEMM_destroyDescr);     \
+  __macro(cusparseSpMatGetSize);            \
+  __macro(cusparseCsrSetPointers);          \
+  __macro(cusparseSpGEMM_copy);
 
 CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif
 
diff --git a/paddle/phi/backends/dynload/cusparse.h b/paddle/phi/backends/dynload/cusparse.h
index fcbabd55b7ebb..5a52ef5e4065e 100644
--- a/paddle/phi/backends/dynload/cusparse.h
+++ b/paddle/phi/backends/dynload/cusparse.h
@@ -42,26 +42,33 @@ extern void *cusparse_dso_handle;
 
 #if defined(PADDLE_WITH_CUDA)
 #if CUDA_VERSION >= 11000
-#define CUSPARSE_ROUTINE_EACH(__macro) \
-  __macro(cusparseCreate);              \
-  __macro(cusparseSetStream);           \
-  __macro(cusparseCreateMatDescr);      \
-  __macro(cusparseDestroy);             \
-  __macro(cusparseSnnz);                \
-  __macro(cusparseDnnz);                \
-  __macro(cusparseSetMatType);          \
-  __macro(cusparseSetMatIndexBase);     \
-  __macro(cusparseCreateCsr);           \
-  __macro(cusparseCreateCoo);           \
-  __macro(cusparseCreateDnMat);         \
-  __macro(cusparseCreateDnVec);         \
-  __macro(cusparseSpMM_bufferSize);     \
-  __macro(cusparseSpMM);                \
-  __macro(cusparseDestroySpMat);        \
-  __macro(cusparseDestroyDnMat);        \
-  __macro(cusparseDestroyDnVec);        \
-  __macro(cusparseSpMV_bufferSize);     \
-  __macro(cusparseSpMV);
+#define CUSPARSE_ROUTINE_EACH(__macro)     \
+  __macro(cusparseCreate);                  \
+  __macro(cusparseSetStream);               \
+  __macro(cusparseCreateMatDescr);          \
+  __macro(cusparseDestroy);                 \
+  __macro(cusparseSnnz);                    \
+  __macro(cusparseDnnz);                    \
+  __macro(cusparseSetMatType);              \
+  __macro(cusparseSetMatIndexBase);         \
+  __macro(cusparseCreateCsr);               \
+  __macro(cusparseCreateCoo);               \
+  __macro(cusparseCreateDnMat);             \
+  __macro(cusparseCreateDnVec);             \
+  __macro(cusparseSpMM_bufferSize);         \
+  __macro(cusparseSpMM);                    \
+  __macro(cusparseDestroySpMat);            \
+  __macro(cusparseDestroyDnMat);            \
+  __macro(cusparseDestroyDnVec);            \
+  __macro(cusparseSpMV_bufferSize);         \
+  __macro(cusparseSpMV);                    \
+  __macro(cusparseSpGEMM_createDescr);      \
+  __macro(cusparseSpGEMM_workEstimation);   \
+  __macro(cusparseSpGEMM_compute);          \
+  __macro(cusparseSpGEMM_destroyDescr);     \
+  __macro(cusparseSpMatGetSize);            \
+  __macro(cusparseCsrSetPointers);          \
+  __macro(cusparseSpGEMM_copy);
 
 CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif
 
diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas.h b/paddle/phi/kernels/funcs/sparse/sparse_blas.h
index f6d67488d1f48..228fb9b0b0b4b 100644
--- a/paddle/phi/kernels/funcs/sparse/sparse_blas.h
+++ b/paddle/phi/kernels/funcs/sparse/sparse_blas.h
@@ -28,6 +28,15 @@ class SparseBlas {
  public:
   explicit SparseBlas(const DeviceContext& dev_ctx) : dev_ctx_(dev_ctx) {}
 
+  template <typename T, typename TensorType>
+  void SPMM(bool transa,
+            bool transb,
+            T alpha,
+            const TensorType& mat_a,
+            const TensorType& mat_b,
+            T beta,
+            TensorType* mat_out) const;
+
   template <typename T, typename TensorType>
   void SPMM(bool transa,
             bool transb,
diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
index 3502dbfc9ceda..8a13113e966a9 100644
--- a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
+++ b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -23,6 +23,7 @@
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/meta_tensor.h"
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"
 #include "paddle/phi/core/visit_type.h"
@@ -89,6 +90,8 @@ inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,
   int64_t batch_nnz = x.nnz() / batch_size;
 
   cudaDataType_t gpu_type = GetGpuDataType<T>();
+  cusparseIndexType_t index_type =
+      std::is_same<IntT, int>::value ? CUSPARSE_INDEX_32I : CUSPARSE_INDEX_64I;
   dev_ctx.CusparseCall([&](cusparseHandle_t handle) {
     phi::dynload::cusparseCreateCsr(descriptor,
                                     M,
@@ -97,8 +100,8 @@ inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,
                                     const_cast<IntT*>(crows_data),
                                     const_cast<IntT*>(cols_data),
                                     const_cast<T*>(values_data),
-                                    CUSPARSE_INDEX_64I,
-                                    CUSPARSE_INDEX_64I,
+                                    index_type,
+                                    index_type,
                                     CUSPARSE_INDEX_BASE_ZERO,
                                     gpu_type);
   });
@@ -309,6 +312,32 @@ class CuSparseDnVecDescriptor {
   cusparseDnVecDescr_t descriptor_;
 };
 
+/************* SpGEMM DESCRIPTOR ************/
+template <typename T>
+class CuSparseSpGEMMDescriptor {
+ public:
+  explicit CuSparseSpGEMMDescriptor(const phi::GPUContext& dev_ctx)
+      : dev_ctx_(dev_ctx) {
+    dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
+      phi::dynload::cusparseSpGEMM_createDescr(&descriptor_);
+    });
+    VLOG(6) << "Create cusparseSpGEMMDescr_t " << &descriptor_;
+  }
+
+  ~CuSparseSpGEMMDescriptor() {
+    dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
+      phi::dynload::cusparseSpGEMM_destroyDescr(descriptor_);
+    });
+    VLOG(6) << "Destroy cusparseSpGEMMDescr_t " << &descriptor_;
+  }
+
+  const cusparseSpGEMMDescr_t& descriptor() const { return descriptor_; }
+
+ private:
+  const phi::GPUContext& dev_ctx_;
+  cusparseSpGEMMDescr_t descriptor_;
+};
+
 /************* SPARSE*DENSE->DENSE MATMUL ************/
 template <>
 template <typename T, typename TensorType>
@@ -414,6 +443,142 @@ void SparseBlas<phi::GPUContext>::SPMV(bool transa,
   });
 }
 
+/************* SPARSE*SPARSE->SPARSE MATMUL ************/
+template <>
+template <typename T, typename TensorType>
+void SparseBlas<phi::GPUContext>::SPMM(bool transa,
+                                       bool transb,
+                                       T alpha,
+                                       const TensorType& mat_a,
+                                       const TensorType& mat_b,
+                                       T beta,
+                                       TensorType* mat_out) const {
+  auto dims = mat_out->dims();
+  DenseTensor *mat_out_crows = mat_out->mutable_crows(),
+              *mat_out_cols = mat_out->mutable_cols(),
+              *mat_out_values = mat_out->mutable_values();
+  MetaTensor meta_out_crows(mat_out_crows), meta_out_cols(mat_out_cols),
+      meta_out_values(mat_out_values);
+  meta_out_crows.set_dtype(mat_a.crows().dtype());
+  meta_out_cols.set_dtype(mat_a.cols().dtype());
+  meta_out_values.set_dtype(mat_a.values().dtype());
+  meta_out_crows.set_dims(common::make_ddim({dims[dims.size() - 2] + 1}));
+  int* out_crows = dev_ctx_.template Alloc<int>(mat_out_crows);
+  dev_ctx_.template Alloc<int>(mat_out_cols);
+  dev_ctx_.template Alloc<T>(mat_out_values);
+
+  auto a_descriptor = CuSparseSpMatDescriptor<T>(mat_a, dev_ctx_);
+  auto b_descriptor = CuSparseSpMatDescriptor<T>(mat_b, dev_ctx_);
+  auto out_descriptor = CuSparseSpMatDescriptor<T>(*mat_out, dev_ctx_);
+  auto spgemm_descriptor = CuSparseSpGEMMDescriptor<T>(dev_ctx_);
+
+  cudaDataType_t gpu_type = GetGpuDataType<T>();
+  size_t buffer_size1 = 0, buffer_size2 = 0;
+
+  dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
+    phi::dynload::cusparseSpGEMM_workEstimation(handle,
+                                                GetTransposeOperation(transa),
+                                                GetTransposeOperation(transb),
+                                                &alpha,
+                                                a_descriptor.descriptor(),
+                                                b_descriptor.descriptor(),
+                                                &beta,
+                                                out_descriptor.descriptor(),
+                                                gpu_type,
+                                                CUSPARSE_SPGEMM_DEFAULT,
+                                                spgemm_descriptor.descriptor(),
+                                                &buffer_size1,
+                                                nullptr);
+  });
+  phi::Allocator::AllocationPtr tmp_buffer1 = phi::memory_utils::Alloc(
+      dev_ctx_.GetPlace(),
+      buffer_size1,
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx_.stream())));
+  void* tmp_buffer_ptr1 = tmp_buffer1->ptr();
+
+  dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
+    phi::dynload::cusparseSpGEMM_workEstimation(handle,
+                                                GetTransposeOperation(transa),
+                                                GetTransposeOperation(transb),
+                                                &alpha,
+                                                a_descriptor.descriptor(),
+                                                b_descriptor.descriptor(),
+                                                &beta,
+                                                out_descriptor.descriptor(),
+                                                gpu_type,
+                                                CUSPARSE_SPGEMM_DEFAULT,
+                                                spgemm_descriptor.descriptor(),
+                                                &buffer_size1,
+                                                tmp_buffer_ptr1);
+  });
+  dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
+    phi::dynload::cusparseSpGEMM_compute(handle,
+                                         GetTransposeOperation(transa),
+                                         GetTransposeOperation(transb),
+                                         &alpha,
+                                         a_descriptor.descriptor(),
+                                         b_descriptor.descriptor(),
+                                         &beta,
+                                         out_descriptor.descriptor(),
+                                         gpu_type,
+                                         CUSPARSE_SPGEMM_DEFAULT,
+                                         spgemm_descriptor.descriptor(),
+                                         &buffer_size2,
+                                         nullptr);
+  });
+  phi::Allocator::AllocationPtr tmp_buffer2 = phi::memory_utils::Alloc(
+      dev_ctx_.GetPlace(),
+      buffer_size2,
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx_.stream())));
+  void* tmp_buffer_ptr2 = tmp_buffer2->ptr();
+
+  dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
+    phi::dynload::cusparseSpGEMM_compute(handle,
+                                         GetTransposeOperation(transa),
+                                         GetTransposeOperation(transb),
+                                         &alpha,
+                                         a_descriptor.descriptor(),
+                                         b_descriptor.descriptor(),
+                                         &beta,
+                                         out_descriptor.descriptor(),
+                                         gpu_type,
+                                         CUSPARSE_SPGEMM_DEFAULT,
+                                         spgemm_descriptor.descriptor(),
+                                         &buffer_size2,
+                                         tmp_buffer_ptr2);
+  });
+
+  int64_t C_num_rows1, C_num_cols1, C_nnz1;
+  dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
+    phi::dynload::cusparseSpMatGetSize(
+        out_descriptor.descriptor(), &C_num_rows1, &C_num_cols1, &C_nnz1);
+  });
+
+  meta_out_cols.set_dims(common::make_ddim({C_nnz1}));
+  meta_out_values.set_dims(common::make_ddim({C_nnz1}));
+  T* out_values = dev_ctx_.template Alloc<T>(mat_out_values);
+  int* out_cols = dev_ctx_.template Alloc<int>(mat_out_cols);
+
+  dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
+    phi::dynload::cusparseCsrSetPointers(
+        out_descriptor.descriptor(), out_crows, out_cols, out_values);
+  });
+
+  dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
+    phi::dynload::cusparseSpGEMM_copy(handle,
+                                      GetTransposeOperation(transa),
+                                      GetTransposeOperation(transb),
+                                      &alpha,
+                                      a_descriptor.descriptor(),
+                                      b_descriptor.descriptor(),
+                                      &beta,
+                                      out_descriptor.descriptor(),
+                                      gpu_type,
+                                      CUSPARSE_SPGEMM_DEFAULT,
+                                      spgemm_descriptor.descriptor());
+  });
+}
+
 /************* DENSE*DENSE->SPARSE MATMUL ************/
 #if CUDA_VERSION >= 11030
 template <>
diff --git a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
index 5878b6662f877..a73229b2d9229 100644
--- a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
 #include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
+#include "paddle/phi/kernels/sparse/unary_kernel.h"
 #include "paddle/phi/kernels/transpose_kernel.h"
 
 namespace phi {
@@ -139,6 +140,84 @@ void MatmulCsrDenseGradKernel(const Context& dev_ctx,
 #endif
 }
 
+template <typename T, typename Context>
+void MatmulCsrCsrGradKernel(const Context& dev_ctx,
+                            const SparseCsrTensor& x,
+                            const SparseCsrTensor& y,
+                            const SparseCsrTensor& dout,
+                            SparseCsrTensor* dx,
+                            SparseCsrTensor* dy) {
+#if CUDA_VERSION >= 11000
+  auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
+  SparseCsrTensor tmp_dout;
+  CastCsrKernel<T, Context>(
+      dev_ctx, dout, phi::DataType::INT32, dout.values().dtype(), &tmp_dout);
+  // dx{SparseCsr} = dout{SparseCsr} * y'{SparseCsr}
+  if (dx) {
+    auto dims_numel = y.dims().size();
+    SparseCsrTensor transpose_y, tmp_y;
+    if (dims_numel == 2) {
+      TransposeCsrKernel<T, Context>(dev_ctx, y, {1, 0}, &transpose_y);
+    } else {
+      TransposeCsrKernel<T, Context>(dev_ctx, y, {0, 2, 1}, &transpose_y);
+    }
+    CastCsrKernel<T, Context>(
+        dev_ctx, transpose_y, phi::DataType::INT32, y.values().dtype(), &tmp_y);
+
+    sparse_blas.SPMM(false,
+                     false,
+                     static_cast<T>(1),
+                     tmp_dout,
+                     tmp_y,
+                     static_cast<T>(0),
+                     dx);
+  }
+
+  // dy{SparseCsr} = x'{SparseCsr} * dout{SparseCsr}
+  if (dy) {
+    auto dims_numel = x.dims().size();
+    SparseCsrTensor transpose_x, tmp_x;
+    if (dims_numel == 2) {
+      TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0}, &transpose_x);
+    } else {
+      TransposeCsrKernel<T, Context>(dev_ctx, x, {0, 2, 1}, &transpose_x);
+    }
+    CastCsrKernel<T, Context>(
+        dev_ctx, transpose_x, phi::DataType::INT32, x.values().dtype(), &tmp_x);
+    sparse_blas.SPMM(false,
+                     false,
+                     static_cast<T>(1),
+                     tmp_x,
+                     tmp_dout,
+                     static_cast<T>(0),
+                     dy);
+  }
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "backward of 'sparse.matmul' uses cusparseSpGEMM, which is supported "
+      "from CUDA 11.0"));
+#endif
+}
+
+template <typename T, typename Context>
+void MatmulCooCooGradKernel(const Context& dev_ctx,
+                            const SparseCooTensor& x,
+                            const SparseCooTensor& y,
+                            const SparseCooTensor& dout,
+                            SparseCooTensor* dx,
+                            SparseCooTensor* dy) {
+  // 'cusparseSpGEMM' only supports CSR now, so go through COO->CSR->COO.
+  SparseCsrTensor x_csr = CooToCsr<T, Context>(dev_ctx, x);
+  SparseCsrTensor y_csr = CooToCsr<T, Context>(dev_ctx, y);
+  SparseCsrTensor dout_csr = CooToCsr<T, Context>(dev_ctx, dout);
+  SparseCsrTensor dx_csr, dy_csr;
+  dx_csr.set_dims(dx->dims());
+  dy_csr.set_dims(dy->dims());
+  MatmulCsrCsrGradKernel<T, Context>(
+      dev_ctx, x_csr, y_csr, dout_csr, &dx_csr, &dy_csr);
+  CsrToCooKernel<T, Context>(dev_ctx, dx_csr, dx);
+  CsrToCooKernel<T, Context>(dev_ctx, dy_csr, dy);
+}
+
 template <typename T, typename Context>
 void MaskedMatmulCsrGradKernel(const Context& dev_ctx,
                                const DenseTensor& x,
@@ -217,3 +296,23 @@ PD_REGISTER_KERNEL(masked_matmul_csr_grad,
                    phi::sparse::MaskedMatmulCsrGradKernel,
                    float,
                    double) {}
+
+PD_REGISTER_KERNEL(matmul_csr_csr_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::MatmulCsrCsrGradKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
+
+PD_REGISTER_KERNEL(matmul_coo_coo_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::MatmulCooCooGradKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
diff --git a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
index 9a808f5ddcc0b..bb09d7853e594 100644
--- a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
@@ -28,6 +28,8 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/math_function_impl.h"
 #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
+#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
+#include "paddle/phi/kernels/sparse/unary_kernel.h"
 
 namespace phi {
 namespace sparse {
@@ -201,6 +203,79 @@ void MaskedMatmulCsrKernel(const Context& dev_ctx,
 #endif
 }
 
+template <typename T, typename Context>
+void MatmulCsrCsrKernel(const Context& dev_ctx,
+                        const SparseCsrTensor& x,
+                        const SparseCsrTensor& y,
+                        SparseCsrTensor* out) {
+#if CUDA_VERSION >= 11000
+  std::vector<int64_t> xdim_vec = common::vectorize(x.dims());
+  std::vector<int64_t> ydim_vec = common::vectorize(y.dims());
+  auto x_ndims = xdim_vec.size();
+  auto y_ndims = ydim_vec.size();
+  PADDLE_ENFORCE_EQ(
+      x_ndims,
+      y_ndims,
+      phi::errors::PreconditionNotMet("The dims size of Input(x) and Input(y) "
+                                      "should be equal, but received X's "
+                                      "dimensions=%d, Y's dimensions=%d.",
+                                      x_ndims,
+                                      y_ndims));
+  PADDLE_ENFORCE_GE(
+      x_ndims,
+      2,
+      phi::errors::InvalidArgument("The dims size of Input(x) and "
+                                   "Input(y) must be greater than "
+                                   "or equal to 2."));
+
+  for (size_t i = 0; i < x_ndims - 2; ++i) {
+    PADDLE_ENFORCE_EQ(xdim_vec[i],
+                      ydim_vec[i],
+                      phi::errors::InvalidArgument(
+                          "x.dim[%d] and y.dim[%d] must be equal.", i, i));
+  }
+
+  PADDLE_ENFORCE_GE(
+      xdim_vec[x_ndims - 1],
+      ydim_vec[y_ndims - 2],
+      phi::errors::PreconditionNotMet(
+          "The shape of Input(x) and Input(y) is not suitable for matmul "
+          "operation, x_dim[-1] must be equal to y_dim[-2]."));
+
+  std::vector<int64_t> out_dim_vec(ydim_vec);
+  out_dim_vec[y_ndims - 2] = xdim_vec[x_ndims - 2];
+  out_dim_vec[y_ndims - 1] = ydim_vec[y_ndims - 1];
+
+  out->set_dims(common::make_ddim(out_dim_vec));
+  SparseCsrTensor x_tmp, y_tmp;
+  CastCsrKernel<T, Context>(
+      dev_ctx, x, phi::DataType::INT32, x.values().dtype(), &x_tmp);
+  CastCsrKernel<T, Context>(
+      dev_ctx, y, phi::DataType::INT32, y.values().dtype(), &y_tmp);
+
+  auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
+  sparse_blas.SPMM(
+      false, false, static_cast<T>(1), x_tmp, y_tmp, static_cast<T>(0), out);
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "forward of 'sparse.matmul' uses cusparseSpGEMM, "
+      "which is supported from CUDA 11.0"));
+#endif
+}
+
+template <typename T, typename Context>
+void MatmulCooCooKernel(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        const SparseCooTensor& y,
+                        SparseCooTensor* out) {
+  SparseCsrTensor x_csr = CooToCsr<T, Context>(dev_ctx, x);
+  SparseCsrTensor y_csr = CooToCsr<T, Context>(dev_ctx, y);
+  SparseCsrTensor out_csr;
+  out_csr.set_dims(out->dims());
+  MatmulCsrCsrKernel<T, Context>(dev_ctx, x_csr, y_csr, &out_csr);
+  CsrToCooKernel<T, Context>(dev_ctx, out_csr, out);
+}
+
 }  // namespace sparse
 }  // namespace phi
 
@@ -228,3 +303,23 @@ PD_REGISTER_KERNEL(masked_matmul_csr,
                    phi::sparse::MaskedMatmulCsrKernel,
                    float,
                    double) {}
+
+PD_REGISTER_KERNEL(matmul_csr_csr,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::MatmulCsrCsrKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
+
+PD_REGISTER_KERNEL(matmul_coo_coo,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::MatmulCooCooKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
diff --git a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
index f0ea90ee1f09b..131f0ccc8e7d4 100644
--- a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
+++ b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
@@ -202,6 +202,7 @@ void CastCsrKernel(const Context& dev_ctx,
     meta.set_dims(x_values.dims());
     phi::CastKernel<T, Context>(dev_ctx, x_values, value_dtype, out_values);
   }
+  out->set_dims(x.dims());
 }
 
 template <typename T, typename Context>
diff --git a/test/legacy_test/test_sparse_matmul_op.py b/test/legacy_test/test_sparse_matmul_op.py
index eb608dd379cac..87015a00dd281 100644
--- a/test/legacy_test/test_sparse_matmul_op.py
+++ b/test/legacy_test/test_sparse_matmul_op.py
@@ -170,5 +170,61 @@ def test_masked_matmul_3d(self):
         )
 
 
+class TestMatmulSparseSparse(unittest.TestCase):
+    # x: sparse, y: sparse, out: sparse
+    def check_result(self, x_shape, y_shape, sparse):
+        mask = paddle.randint(0, 2, x_shape)
+        origin_x = paddle.rand(x_shape) * mask
+        origin_y = paddle.rand(y_shape)
+
+        dense_x = origin_x.detach()
+        dense_x.stop_gradient = False
+        dense_y = origin_y.detach()
+        dense_y.stop_gradient = False
+        dense_out = paddle.matmul(dense_x, dense_y)
+        if sparse == 'csr':
+            sp_x = origin_x.detach().to_sparse_csr()
+            sp_y = origin_y.detach().to_sparse_csr()
+        else:
+            sp_x = origin_x.detach().to_sparse_coo(len(x_shape))
+            sp_y = origin_y.detach().to_sparse_coo(len(y_shape))
+
+        sp_x.stop_gradient = False
+        sp_y.stop_gradient = False
+
+        sp_out = paddle.sparse.matmul(sp_x, sp_y)
+
+        np.testing.assert_allclose(
+            sp_out.to_dense().numpy(), dense_out.numpy(), rtol=1e-05
+        )
+
+        dense_out.backward()
+        sp_out.backward()
+        np.testing.assert_allclose(
+            sp_x.grad.to_dense().numpy(),
+            dense_x.grad.numpy(),
+            rtol=1e-05,
+        )
+        np.testing.assert_allclose(
+            sp_y.grad.to_dense().numpy(), dense_y.grad.numpy(), rtol=1e-05
+        )
+
+    @unittest.skipIf(
+        not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000,
+        "only support cuda>=11.0",
+    )
+    def test_matmul_2d(self):
+        self.check_result([16, 12], [12, 10], 'csr')
+        self.check_result([16, 12], [12, 10], 'coo')
+
+    @unittest.skipIf(
+        not paddle.is_compiled_with_cuda() or get_cuda_version() < 11080,
+        "only support cuda>=11.8",
+    )
+    def test_matmul_3d(self):
+        self.check_result([2, 16, 12], [2, 12, 10], 'csr')
+        self.check_result([2, 16, 12], [2, 12, 10], 'coo')
+
+
 if __name__ == "__main__":
     unittest.main()
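
A minimal usage sketch of the sparse-sparse matmul path exercised by the tests above. This is only an illustration (it assumes a CUDA 11.0+ build of Paddle containing the kernels from this patch; the shapes are arbitrary), not part of the patch itself:

    import paddle

    # Dense operands converted to CSR; COO inputs reach the same
    # cusparseSpGEMM-backed kernels via to_sparse_coo(2) instead.
    x = paddle.rand([16, 12])
    y = paddle.rand([12, 10])
    sp_x = x.to_sparse_csr()
    sp_y = y.to_sparse_csr()
    sp_x.stop_gradient = False
    sp_y.stop_gradient = False

    # sparse * sparse -> sparse (matmul_csr_csr / matmul_coo_coo kernels).
    sp_out = paddle.sparse.matmul(sp_x, sp_y)
    sp_out.backward()
    print(sp_out.to_dense().shape)  # [16, 10]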