[HIPIFY][tests][fix] CUDA < 9.0 fix
`cublasHgemmBatched` appeared in CUDA 9.0, so its test calls need `#if CUDA_VERSION >= 9000` guards.
emankov committed Nov 11, 2023
1 parent 5f9b2f9 commit 777273a
Showing 4 changed files with 22 additions and 22 deletions.
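
As a quick illustration of the guard these tests now rely on, here is a minimal sketch: the wrapper function `hgemm_batched_if_available` and its fallback return value are hypothetical and not part of this commit; only the `#if CUDA_VERSION >= 9000` pattern and the `cublasHgemmBatched` signature come from the diff below.

```cuda
// Minimal sketch of the guard pattern applied in the synthetic tests.
// cublasHgemmBatched is only declared from CUDA 9.0 onward, so both the call
// and its hipify CHECK expectations must sit behind a CUDA_VERSION guard.
#include <cuda.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>

cublasStatus_t hgemm_batched_if_available(cublasHandle_t handle,
                                          cublasOperation_t transa, cublasOperation_t transb,
                                          int m, int n, int k,
                                          const __half* alpha,
                                          const __half* const Aarray[], int lda,
                                          const __half* const Barray[], int ldb,
                                          const __half* beta,
                                          __half* const Carray[], int ldc,
                                          int batchCount) {
#if CUDA_VERSION >= 9000
  // Present since CUDA 9.0; hipify translates it to hipblasHgemmBatched
  // (or rocblas_hgemm_batched in the rocBLAS test variants).
  return cublasHgemmBatched(handle, transa, transb, m, n, k,
                            alpha, Aarray, lda, Barray, ldb,
                            beta, Carray, ldc, batchCount);
#else
  // Older toolkits have no half-precision batched GEMM entry point.
  return CUBLAS_STATUS_NOT_SUPPORTED;
#endif
}
```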
10 changes: 5 additions & 5 deletions tests/unit_tests/synthetic/libraries/cublas2hipblas.cu
@@ -1112,11 +1112,6 @@ int main() {
// CHECK: blasStatus = hipblasDgemmBatched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);
blasStatus = cublasDgemmBatched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* const Aarray[], int lda, const __half* const Barray[], int ldb, const __half* beta, __half* const Carray[], int ldc, int batchCount);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasHgemmBatched(hipblasHandle_t handle, hipblasOperation_t transA, hipblasOperation_t transB, int m, int n, int k, const hipblasHalf* alpha, const hipblasHalf* const AP[], int lda, const hipblasHalf* const BP[], int ldb, const hipblasHalf* beta, hipblasHalf* const CP[], int ldc, int batchCount);
// CHECK: blasStatus = hipblasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
blasStatus = cublasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuComplex* alpha, const cuComplex* const Aarray[], int lda, const cuComplex* const Barray[], int ldb, const cuComplex* beta, cuComplex* const Carray[], int ldc, int batchCount);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasCgemmBatched_v2(hipblasHandle_t handle, hipblasOperation_t transA, hipblasOperation_t transB, int m, int n, int k, const hipComplex* alpha, const hipComplex* const AP[], int lda, const hipComplex* const BP[], int ldb, const hipComplex* beta, hipComplex* const CP[], int ldc, int batchCount);
// CHECK: blasStatus = hipblasCgemmBatched_v2(blasHandle, transa, transb, m, n, k, &complexa, complexAarray_const, lda, complexBarray_const, ldb, &complexb, complexCarray, ldc, batchCount);
@@ -1614,6 +1609,11 @@ int main() {
#if CUDA_VERSION >= 9000
// CHECK: hipblasGemmAlgo_t BLAS_GEMM_DEFAULT = HIPBLAS_GEMM_DEFAULT;
cublasGemmAlgo_t BLAS_GEMM_DEFAULT = CUBLAS_GEMM_DEFAULT;

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* const Aarray[], int lda, const __half* const Barray[], int ldb, const __half* beta, __half* const Carray[], int ldc, int batchCount);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasHgemmBatched(hipblasHandle_t handle, hipblasOperation_t transA, hipblasOperation_t transB, int m, int n, int k, const hipblasHalf* alpha, const hipblasHalf* const AP[], int lda, const hipblasHalf* const BP[], int ldb, const hipblasHalf* beta, hipblasHalf* const CP[], int ldc, int batchCount);
// CHECK: blasStatus = hipblasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
blasStatus = cublasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
#endif

#if CUDA_VERSION >= 9010 && CUDA_VERSION < 11000
10 changes: 5 additions & 5 deletions tests/unit_tests/synthetic/libraries/cublas2hipblas_v2.cu
@@ -1243,11 +1243,6 @@ int main() {
// CHECK: blasStatus = hipblasDgemmBatched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);
blasStatus = cublasDgemmBatched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* const Aarray[], int lda, const __half* const Barray[], int ldb, const __half* beta, __half* const Carray[], int ldc, int batchCount);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasHgemmBatched(hipblasHandle_t handle, hipblasOperation_t transA, hipblasOperation_t transB, int m, int n, int k, const hipblasHalf* alpha, const hipblasHalf* const AP[], int lda, const hipblasHalf* const BP[], int ldb, const hipblasHalf* beta, hipblasHalf* const CP[], int ldc, int batchCount);
// CHECK: blasStatus = hipblasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
blasStatus = cublasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuComplex* alpha, const cuComplex* const Aarray[], int lda, const cuComplex* const Barray[], int ldb, const cuComplex* beta, cuComplex* const Carray[], int ldc, int batchCount);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasCgemmBatched_v2(hipblasHandle_t handle, hipblasOperation_t transA, hipblasOperation_t transB, int m, int n, int k, const hipComplex* alpha, const hipComplex* const AP[], int lda, const hipComplex* const BP[], int ldb, const hipComplex* beta, hipComplex* const CP[], int ldc, int batchCount);
// CHECK: blasStatus = hipblasCgemmBatched_v2(blasHandle, transa, transb, m, n, k, &complexa, complexAarray_const, lda, complexBarray_const, ldb, &complexb, complexCarray, ldc, batchCount);
@@ -1769,6 +1764,11 @@ int main() {
#if CUDA_VERSION >= 9000
// CHECK: hipblasGemmAlgo_t BLAS_GEMM_DEFAULT = HIPBLAS_GEMM_DEFAULT;
cublasGemmAlgo_t BLAS_GEMM_DEFAULT = CUBLAS_GEMM_DEFAULT;

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* const Aarray[], int lda, const __half* const Barray[], int ldb, const __half* beta, __half* const Carray[], int ldc, int batchCount);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasHgemmBatched(hipblasHandle_t handle, hipblasOperation_t transA, hipblasOperation_t transB, int m, int n, int k, const hipblasHalf* alpha, const hipblasHalf* const AP[], int lda, const hipblasHalf* const BP[], int ldb, const hipblasHalf* beta, hipblasHalf* const CP[], int ldc, int batchCount);
// CHECK: blasStatus = hipblasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
blasStatus = cublasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
#endif

#if CUDA_VERSION >= 9010 && CUDA_VERSION < 11000
12 changes: 6 additions & 6 deletions tests/unit_tests/synthetic/libraries/cublas2rocblas.cu
@@ -1271,12 +1271,6 @@ int main() {
// CHECK: blasStatus = rocblas_dgemm_batched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);
blasStatus = cublasDgemmBatched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* const Aarray[], int lda, const __half* const Barray[], int ldb, const __half* beta, __half* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_half* alpha, const rocblas_half* const A[], rocblas_int lda, const rocblas_half* const B[], rocblas_int ldb, const rocblas_half* beta, rocblas_half* const C[], rocblas_int ldc, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hgemm_batched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
blasStatus = cublasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuComplex* alpha, const cuComplex* const Aarray[], int lda, const cuComplex* const Barray[], int ldb, const cuComplex* beta, cuComplex* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_cgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_float_complex* alpha, const rocblas_float_complex* const A[], rocblas_int lda, const rocblas_float_complex* const B[], rocblas_int ldb, const rocblas_float_complex* beta, rocblas_float_complex* const C[], rocblas_int ldc, rocblas_int batch_count);
@@ -1739,6 +1733,12 @@ int main() {
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_set_math_mode(rocblas_handle handle, rocblas_math_mode math_mode);
// CHECK: blasStatus = rocblas_set_math_mode(blasHandle, blasMath);
blasStatus = cublasSetMathMode(blasHandle, blasMath);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* const Aarray[], int lda, const __half* const Barray[], int ldb, const __half* beta, __half* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_half* alpha, const rocblas_half* const A[], rocblas_int lda, const rocblas_half* const B[], rocblas_int ldb, const rocblas_half* beta, rocblas_half* const C[], rocblas_int ldc, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hgemm_batched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
blasStatus = cublasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
#endif

#if CUDA_VERSION >= 9010 && CUDA_VERSION < 11000
12 changes: 6 additions & 6 deletions tests/unit_tests/synthetic/libraries/cublas2rocblas_v2.cu
@@ -1394,12 +1394,6 @@ int main() {
// CHECK: blasStatus = rocblas_dgemm_batched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);
blasStatus = cublasDgemmBatched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* const Aarray[], int lda, const __half* const Barray[], int ldb, const __half* beta, __half* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_half* alpha, const rocblas_half* const A[], rocblas_int lda, const rocblas_half* const B[], rocblas_int ldb, const rocblas_half* beta, rocblas_half* const C[], rocblas_int ldc, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hgemm_batched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
blasStatus = cublasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuComplex* alpha, const cuComplex* const Aarray[], int lda, const cuComplex* const Barray[], int ldb, const cuComplex* beta, cuComplex* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_cgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_float_complex* alpha, const rocblas_float_complex* const A[], rocblas_int lda, const rocblas_float_complex* const B[], rocblas_int ldb, const rocblas_float_complex* beta, rocblas_float_complex* const C[], rocblas_int ldc, rocblas_int batch_count);
@@ -1871,6 +1865,12 @@ int main() {
#if CUDA_VERSION >= 9000
// CHECK: rocblas_gemm_algo BLAS_GEMM_DEFAULT = rocblas_gemm_algo_standard;
cublasGemmAlgo_t BLAS_GEMM_DEFAULT = CUBLAS_GEMM_DEFAULT;

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* const Aarray[], int lda, const __half* const Barray[], int ldb, const __half* beta, __half* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_half* alpha, const rocblas_half* const A[], rocblas_int lda, const rocblas_half* const B[], rocblas_int ldb, const rocblas_half* beta, rocblas_half* const C[], rocblas_int ldc, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hgemm_batched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
blasStatus = cublasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
#endif

#if CUDA_VERSION >= 9010 && CUDA_VERSION < 11000
