diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
index 143b4867b3..af4792b832 100644
--- a/nntrainer/layers/layer_context.cpp
+++ b/nntrainer/layers/layer_context.cpp
@@ -681,16 +681,28 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) {
   switch (layerKernel) {
   case LayerKernel::SGEMV:
     return "sgemv_cl";
-  case LayerKernel::DOT:
-    return "dot_cl";
-  case LayerKernel::SGEMM:
-    return "sgemm_cl";
   case LayerKernel::SGEMV_FP16:
     return "sgemv_cl_fp16";
+  case LayerKernel::DOT:
+    return "dot_cl";
   case LayerKernel::DOT_FP16:
     return "dot_cl_fp16";
-  case LayerKernel::SGEMM_FP16:
-    return "sgemm_cl_fp16";
+  case LayerKernel::SGEMM_NOTRANS:
+    return "sgemm_cl_noTrans";
+  case LayerKernel::SGEMM_NOTRANS_FP16:
+    return "sgemm_cl_noTrans_fp16";
+  case LayerKernel::SGEMM_TRANSA:
+    return "sgemm_cl_transA";
+  case LayerKernel::SGEMM_TRANSA_FP16:
+    return "sgemm_cl_transA_fp16";
+  case LayerKernel::SGEMM_TRANSB:
+    return "sgemm_cl_transB";
+  case LayerKernel::SGEMM_TRANSB_FP16:
+    return "sgemm_cl_transB_fp16";
+  case LayerKernel::SGEMM_TRANSAB:
+    return "sgemm_cl_transAB";
+  case LayerKernel::SGEMM_TRANSAB_FP16:
+    return "sgemm_cl_transAB_fp16";
   case LayerKernel::ADD:
     return "addition_cl";
   case LayerKernel::ADD_FP16:
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index 9a621dae1f..50b291f622 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -830,18 +830,24 @@ class RunLayerContext {
    * getKernelName function.
    */
   enum LayerKernel {
-    SGEMV = 1 << 0,       /**< placeholder for kernel name */
-    DOT = 1 << 1,         /**< placeholder for kernel name */
-    SGEMM = 1 << 2,       /**< placeholder for kernel name */
-    SGEMV_FP16 = 1 << 3,  /**< placeholder for kernel name */
-    DOT_FP16 = 1 << 4,    /**< placeholder for kernel name */
-    SGEMM_FP16 = 1 << 5,  /**< placeholder for kernel name */
-    ADD = 1 << 6,         /**< placeholder for kernel name */
-    ADD_FP16 = 1 << 7,    /**< placeholder for kernel name */
-    SWIGLU = 1 << 8,      /**< placeholder for kernel name */
-    SWIGLU_FP16 = 1 << 9, /**< placeholder for kernel name */
-    SSCAL = 1 << 10,      /**< placeholder for kernel name */
-    SSCAL_FP16 = 1 << 11, /**< placeholder for kernel name */
+    SGEMV = 1 << 0,               /**< placeholder for kernel name */
+    SGEMV_FP16 = 1 << 1,          /**< placeholder for kernel name */
+    DOT = 1 << 2,                 /**< placeholder for kernel name */
+    DOT_FP16 = 1 << 3,            /**< placeholder for kernel name */
+    SGEMM_NOTRANS = 1 << 4,       /**< placeholder for kernel name */
+    SGEMM_NOTRANS_FP16 = 1 << 5,  /**< placeholder for kernel name */
+    SGEMM_TRANSA = 1 << 6,        /**< placeholder for kernel name */
+    SGEMM_TRANSA_FP16 = 1 << 7,   /**< placeholder for kernel name */
+    SGEMM_TRANSB = 1 << 8,        /**< placeholder for kernel name */
+    SGEMM_TRANSB_FP16 = 1 << 9,   /**< placeholder for kernel name */
+    SGEMM_TRANSAB = 1 << 10,      /**< placeholder for kernel name */
+    SGEMM_TRANSAB_FP16 = 1 << 11, /**< placeholder for kernel name */
+    ADD = 1 << 12,                /**< placeholder for kernel name */
+    ADD_FP16 = 1 << 13,           /**< placeholder for kernel name */
+    SWIGLU = 1 << 14,             /**< placeholder for kernel name */
+    SWIGLU_FP16 = 1 << 15,        /**< placeholder for kernel name */
+    SSCAL = 1 << 16,              /**< placeholder for kernel name */
+    SSCAL_FP16 = 1 << 17,         /**< placeholder for kernel name */
   };
 
   /**
diff --git a/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
index c0c98019d5..8ee1b5e426 100644
--- a/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
+++ b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
@@ -147,9 +147,8 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     }
     /// case others: use gemm
     else {
-      // transA == false, transB == false
-      sgemm_cl(data, mdata, rdata, M, N, K, lda, ldb, ldc, context);
-      // todo: other condition implementations
+      sgemm_cl(transA, transB, data, mdata, rdata, M, N, K, lda, ldb, ldc,
+               context);
     }
   } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
@@ -184,9 +183,8 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     }
     /// case others: use sgemm
     else {
-      // transA == false, transB == false
-      sgemm_cl(data, mdata, rdata, M, N, K, lda, ldb, ldc, context);
-      // todo: other condition implementations
+      sgemm_cl(transA, transB, data, mdata, rdata, M, N, K, lda, ldb, ldc,
+               context);
     }
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
diff --git a/nntrainer/tensor/cl_operations/blas_kernels.cpp b/nntrainer/tensor/cl_operations/blas_kernels.cpp
index 3d459232dc..791cdc5e6b 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels.cpp
+++ b/nntrainer/tensor/cl_operations/blas_kernels.cpp
@@ -35,8 +35,8 @@ std::string dot_cl_kernel_ =
         }
     })";
 
-std::string sgemm_cl_kernel_ =
-  R"(__kernel void sgemm_cl(const __global float* A, const __global float* B,
+std::string sgemm_cl_noTrans_kernel_ =
+  R"(__kernel void sgemm_cl_noTrans(const __global float* A, const __global float* B,
                       __global float* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
 
         unsigned int m = get_global_id(0);
@@ -51,6 +51,58 @@ std::string sgemm_cl_kernel_ =
         C[m * ldc + n] = c;
     })";
 
+std::string sgemm_cl_transA_kernel_ =
+  R"(__kernel void sgemm_cl_transA(const __global float* A, const __global float* B,
+                      __global float* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
+
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        float c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          float a, b;
+          a = A[k * lda + m];
+          b = B[k * ldb + n];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
+std::string sgemm_cl_transB_kernel_ =
+  R"(__kernel void sgemm_cl_transB(const __global float *A, const __global float *B,
+                                   __global float *C, unsigned int K,
+                                   unsigned int lda, unsigned int ldb,
+                                   unsigned int ldc) {
+
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        float c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          float a, b;
+          a = A[m * lda + k];
+          b = B[n * ldb + k];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
+std::string sgemm_cl_transAB_kernel_ =
+  R"(__kernel void sgemm_cl_transAB(const __global float *A, const __global float *B,
+                                    __global float *C, unsigned int K,
+                                    unsigned int lda, unsigned int ldb,
+                                    unsigned int ldc) {
+
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        float c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          float a, b;
+          a = A[k * lda + m];
+          b = B[n * ldb + k];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
 std::string addition_cl_kernel_ =
   R"(__kernel void addition_cl(__global const float* input, __global float* output, const unsigned int size) {
     #pragma printf_support
@@ -71,7 +123,10 @@ std::string sscal_cl_kernel_ =
  * @brief defining global kernel objects
  */
 opencl::Kernel kernel_sgemv;
-opencl::Kernel kernel_sgemm;
+opencl::Kernel kernel_sgemm_transAB;
+opencl::Kernel kernel_sgemm_transA;
+opencl::Kernel kernel_sgemm_transB;
+opencl::Kernel kernel_sgemm_noTrans;
 opencl::Kernel kernel_dot;
 opencl::Kernel kernel_addition;
 opencl::Kernel kernel_sscal;
 
@@ -227,19 +282,43 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
   return cl_ret;
 }
 
-void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
-              unsigned int N, unsigned int K, unsigned int lda,
-              unsigned int ldb, unsigned int ldc, RunLayerContext &context) {
+void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const float *A,
+              const float *B, float *C, unsigned int M, unsigned int N,
+              unsigned int K, unsigned int lda, unsigned int ldb,
+              unsigned int ldc, RunLayerContext &context) {
+
+  opencl::Kernel *kernel_sgemm = nullptr;
+  RunLayerContext::LayerKernel layerKernel;
+  std::string sgemm_cl_kernel_;
+
+  if (TransA != CblasTrans && TransB != CblasTrans) {
+    kernel_sgemm = &kernel_sgemm_noTrans;
+    layerKernel = context.LayerKernel::SGEMM_NOTRANS;
+    sgemm_cl_kernel_ = sgemm_cl_noTrans_kernel_;
+  } else if (TransA == CblasTrans && TransB != CblasTrans) {
+    kernel_sgemm = &kernel_sgemm_transA;
+    layerKernel = context.LayerKernel::SGEMM_TRANSA;
+    sgemm_cl_kernel_ = sgemm_cl_transA_kernel_;
+  } else if (TransA != CblasTrans && TransB == CblasTrans) {
+    kernel_sgemm = &kernel_sgemm_transB;
+    layerKernel = context.LayerKernel::SGEMM_TRANSB;
+    sgemm_cl_kernel_ = sgemm_cl_transB_kernel_;
+  } else {
+    kernel_sgemm = &kernel_sgemm_transAB;
+    layerKernel = context.LayerKernel::SGEMM_TRANSAB;
+    sgemm_cl_kernel_ = sgemm_cl_transAB_kernel_;
+  }
 
   bool result = false;
 
   do {
-    result = context.clCreateKernel(sgemm_cl_kernel_,
-                                    context.LayerKernel::SGEMM, kernel_sgemm);
+    result =
+      context.clCreateKernel(sgemm_cl_kernel_, layerKernel, *kernel_sgemm);
     if (!result) {
       break;
     }
 
+    // sizes will be same for transpose
     size_t m_k_size = M * K * sizeof(float);
     size_t k_n_size = K * N * sizeof(float);
     size_t m_n_size = M * N * sizeof(float);
@@ -265,37 +344,37 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+    result = kernel_sgemm->SetKernelArguments(0, &inputA, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(1, &inputB, sizeof(cl_mem));
+    result = kernel_sgemm->SetKernelArguments(1, &inputB, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(2, &inOutC, sizeof(cl_mem));
+    result = kernel_sgemm->SetKernelArguments(2, &inOutC, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(3, &K, sizeof(int));
+    result = kernel_sgemm->SetKernelArguments(3, &K, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(4, &lda, sizeof(int));
+    result = kernel_sgemm->SetKernelArguments(4, &lda, sizeof(int));
     if (!result) {
       break;
    }
 
-    result = kernel_sgemm.SetKernelArguments(5, &ldb, sizeof(int));
+    result = kernel_sgemm->SetKernelArguments(5, &ldb, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(6, &ldc, sizeof(int));
+    result = kernel_sgemm->SetKernelArguments(6, &ldc, sizeof(int));
     if (!result) {
       break;
     }
@@ -304,7 +383,7 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
     const int work_group_size[3] = {32, 32, 1}; // test-value
 
     result = context.command_queue_inst_.DispatchCommand(
-      kernel_sgemm, work_groups_count, work_group_size);
+      *kernel_sgemm, work_groups_count, work_group_size);
     if (!result) {
       break;
     }
diff --git a/nntrainer/tensor/cl_operations/blas_kernels.h b/nntrainer/tensor/cl_operations/blas_kernels.h
index 3ae4ae97b3..6b118c68dd 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels.h
+++ b/nntrainer/tensor/cl_operations/blas_kernels.h
@@ -25,7 +25,10 @@ namespace nntrainer {
 * @brief declaring global kernel objects
 */
 extern opencl::Kernel kernel_sgemv;
-extern opencl::Kernel kernel_sgemm;
+extern opencl::Kernel kernel_sgemm_noTrans;
+extern opencl::Kernel kernel_sgemm_transAB;
+extern opencl::Kernel kernel_sgemm_transA;
+extern opencl::Kernel kernel_sgemm_transB;
 extern opencl::Kernel kernel_dot;
 extern opencl::Kernel kernel_addition;
 extern opencl::Kernel kernel_sscal;
@@ -58,6 +61,8 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
 /**
  * @brief     sgemm computation : Y = op(A)*op(B) + C,
  * where op(X) is one of X or X**T
+ * @param[in] transA CBLAS_TRANSPOSE
+ * @param[in] transB CBLAS_TRANSPOSE
  * @param[in] A float * for Matrix A
 * @param[in] B float * for Matrix B
 * @param[in] C float * for Matrix C
@@ -69,9 +74,10 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
 * @param[in] ldc number of C's columns
 * @param[in] context RunLayerContext reference
 */
-void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
-              unsigned int N, unsigned int K, unsigned int lda,
-              unsigned int ldb, unsigned int ldc, RunLayerContext &context);
+void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const float *A,
+              const float *B, float *C, unsigned int M, unsigned int N,
+              unsigned int K, unsigned int lda, unsigned int ldb,
+              unsigned int ldc, RunLayerContext &context);
 
 /**
  * @brief     addition : sum of all input vectors
@@ -98,7 +104,10 @@ void sscal_cl(float *X, const unsigned int N, const float alpha,
 * @brief declaring global fp16 kernel objects
 */
 extern opencl::Kernel kernel_sgemv_fp16;
-extern opencl::Kernel kernel_sgemm_fp16;
+extern opencl::Kernel kernel_sgemm_noTrans_fp16;
+extern opencl::Kernel kernel_sgemm_transAB_fp16;
+extern opencl::Kernel kernel_sgemm_transA_fp16;
+extern opencl::Kernel kernel_sgemm_transB_fp16;
 extern opencl::Kernel kernel_dot_fp16;
 extern opencl::Kernel kernel_addition_fp16;
 extern opencl::Kernel kernel_sscal_fp16;
@@ -131,6 +140,8 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
 /**
  * @brief     fp16 sgemm computation : Y = op(A)*op(B) + C,
  * where op(X) is one of X or X**T
+ * @param[in] transA CBLAS_TRANSPOSE
+ * @param[in] transB CBLAS_TRANSPOSE
  * @param[in] A fp16 * for Matrix A
 * @param[in] B fp16 * for Matrix B
 * @param[in] C fp16 * for Matrix C
@@ -142,9 +153,10 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
 * @param[in] ldc number of C's columns
 * @param[in] context RunLayerContext reference
 */
-void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
-              unsigned int N, unsigned int K, unsigned int lda,
-              unsigned int ldb, unsigned int ldc, RunLayerContext &context);
+void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const __fp16 *A,
+              const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
+              unsigned int K, unsigned int lda, unsigned int ldb,
+              unsigned int ldc, RunLayerContext &context);
 
 /**
  * @brief     fp16 addition : sum of all input vectors
diff --git a/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
index 83f0d2136b..96c7ce9c90 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
+++ b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
@@ -41,11 +41,11 @@ std::string dot_cl_kernel_fp16_ =
         }
     })";
 
-std::string sgemm_cl_kernel_fp16_ =
+std::string sgemm_cl_noTrans_kernel_fp16_ =
   R"(
     #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
-    __kernel void sgemm_cl_fp16(const __global half* A, const __global half* B,
+    __kernel void sgemm_cl_noTrans_fp16(const __global half* A, const __global half* B,
                       __global half* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
 
         unsigned int m = get_global_id(0);
@@ -60,6 +60,63 @@ std::string sgemm_cl_kernel_fp16_ =
         C[m * ldc + n] = c;
     })";
 
+std::string sgemm_cl_transA_kernel_fp16_ =
+  R"(
+    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+    __kernel void sgemm_cl_transA_fp16(const __global half* A, const __global half* B,
+                      __global half* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
+
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        half c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          half a, b;
+          a = A[k * lda + m];
+          b = B[k * ldb + n];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
+std::string sgemm_cl_transB_kernel_fp16_ =
+  R"(
+    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+    __kernel void sgemm_cl_transB_fp16(const __global half* A, const __global half* B,
+                      __global half* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
+
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        half c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          half a, b;
+          a = A[m * lda + k];
+          b = B[n * ldb + k];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
+std::string sgemm_cl_transAB_kernel_fp16_ =
+  R"(
+    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+    __kernel void sgemm_cl_transAB_fp16(const __global half* A, const __global half* B,
+                      __global half* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
+
+        unsigned int m = get_global_id(0);
+        unsigned int n = get_global_id(1);
+        half c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          half a, b;
+          a = A[k * lda + m];
+          b = B[n * ldb + k];
+          c += a * b;
+        }
+        C[m * ldc + n] = c;
+    })";
+
 std::string addition_cl_kernel_fp16_ =
   R"(
     #pragma OPENCL EXTENSION cl_khr_fp16 : enable
@@ -85,7 +142,10 @@ std::string sscal_cl_kernel_fp16_ =
 * @brief defining global kernel objects
 */
 opencl::Kernel kernel_sgemv_fp16;
-opencl::Kernel kernel_sgemm_fp16;
+opencl::Kernel kernel_sgemm_transAB_fp16;
+opencl::Kernel kernel_sgemm_transA_fp16;
+opencl::Kernel kernel_sgemm_transB_fp16;
+opencl::Kernel kernel_sgemm_noTrans_fp16;
 opencl::Kernel kernel_dot_fp16;
 opencl::Kernel kernel_addition_fp16;
 opencl::Kernel kernel_sscal_fp16;
@@ -242,20 +302,43 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
   return cl_ret;
 }
 
-void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
-              unsigned int N, unsigned int K, unsigned int lda,
-              unsigned int ldb, unsigned int ldc, RunLayerContext &context) {
+void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const __fp16 *A,
+              const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
+              unsigned int K, unsigned int lda, unsigned int ldb,
+              unsigned int ldc, RunLayerContext &context) {
+
+  opencl::Kernel *kernel_sgemm_fp16 = nullptr;
+  RunLayerContext::LayerKernel layerKernel;
+  std::string sgemm_cl_kernel_fp16_;
+
+  if (TransA != CblasTrans && TransB != CblasTrans) {
+    kernel_sgemm_fp16 = &kernel_sgemm_noTrans_fp16;
+    layerKernel = context.LayerKernel::SGEMM_NOTRANS_FP16;
+    sgemm_cl_kernel_fp16_ = sgemm_cl_noTrans_kernel_fp16_;
+  } else if (TransA == CblasTrans && TransB != CblasTrans) {
+    kernel_sgemm_fp16 = &kernel_sgemm_transA_fp16;
+    layerKernel = context.LayerKernel::SGEMM_TRANSA_FP16;
+    sgemm_cl_kernel_fp16_ = sgemm_cl_transA_kernel_fp16_;
+  } else if (TransA != CblasTrans && TransB == CblasTrans) {
+    kernel_sgemm_fp16 = &kernel_sgemm_transB_fp16;
+    layerKernel = context.LayerKernel::SGEMM_TRANSB_FP16;
+    sgemm_cl_kernel_fp16_ = sgemm_cl_transB_kernel_fp16_;
+  } else {
+    kernel_sgemm_fp16 = &kernel_sgemm_transAB_fp16;
+    layerKernel = context.LayerKernel::SGEMM_TRANSAB_FP16;
+    sgemm_cl_kernel_fp16_ = sgemm_cl_transAB_kernel_fp16_;
+  }
 
   bool result = false;
 
   do {
-    result = context.clCreateKernel(sgemm_cl_kernel_fp16_,
-                                    context.LayerKernel::SGEMM_FP16,
-                                    kernel_sgemm_fp16);
+    result = context.clCreateKernel(sgemm_cl_kernel_fp16_, layerKernel,
+                                    *kernel_sgemm_fp16);
     if (!result) {
       break;
     }
 
+    // sizes will be same for transpose
     size_t m_k_size = M * K * sizeof(cl_half);
     size_t k_n_size = K * N * sizeof(cl_half);
     size_t m_n_size = M * N * sizeof(cl_half);
@@ -281,37 +364,37 @@ void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+    result = kernel_sgemm_fp16->SetKernelArguments(0, &inputA, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(1, &inputB, sizeof(cl_mem));
+    result = kernel_sgemm_fp16->SetKernelArguments(1, &inputB, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(2, &inOutC, sizeof(cl_mem));
+    result = kernel_sgemm_fp16->SetKernelArguments(2, &inOutC, sizeof(cl_mem));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(3, &K, sizeof(int));
+    result = kernel_sgemm_fp16->SetKernelArguments(3, &K, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(4, &lda, sizeof(int));
+    result = kernel_sgemm_fp16->SetKernelArguments(4, &lda, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(5, &ldb, sizeof(int));
+    result = kernel_sgemm_fp16->SetKernelArguments(5, &ldb, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm_fp16.SetKernelArguments(6, &ldc, sizeof(int));
+    result = kernel_sgemm_fp16->SetKernelArguments(6, &ldc, sizeof(int));
     if (!result) {
       break;
     }
@@ -320,7 +403,7 @@ void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
     const int work_group_size[3] = {32, 32, 1}; // test-value
 
     result = context.command_queue_inst_.DispatchCommand(
-      kernel_sgemm_fp16, work_groups_count, work_group_size);
+      *kernel_sgemm_fp16, work_groups_count, work_group_size);
     if (!result) {
       break;
     }
diff --git a/test/unittest/unittest_blas_kernels_cl.cpp b/test/unittest/unittest_blas_kernels_cl.cpp
index cac5b9e964..d897d69e8d 100644
--- a/test/unittest/unittest_blas_kernels_cl.cpp
+++ b/test/unittest/unittest_blas_kernels_cl.cpp
@@ -44,7 +44,7 @@ TEST(blas_kernels, dotCL_sgemv) {
   int width = 768;
 
   int height_b = 768;
-  int width_b = 96000;
+  int width_b = 2048;
 
   bool transA = false;
   bool transB = false;
@@ -94,7 +94,7 @@ TEST(blas_kernels, dotCL_sgemv_n) {
   int width = 768;
 
   int height_b = 768;
-  int width_b = 96000;
+  int width_b = 2048;
 
   bool transA = true;
   bool transB = false;
@@ -166,6 +166,254 @@ TEST(nntrainer_Tensor, multiply_i) {
   EXPECT_IN_RANGE(cosSimNeon, 0.99, 1);
 }
 
+TEST(nntrainer_Tensor, dot_gemm_50_768_1024_noTrans) {
+  /// @note GEMM : A X B = C
+  RunLayerContext rc = setUpGpuContext();
+
+  int batch = 1;
+  int channel = 1;
+  int height = 50;
+  int width = 768;
+
+  int height_b = 768;
+  int width_b = 1024;
+
+  bool transA = false;
+  bool transB = false;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) + k * (width) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) + j * (batch * height_b) + k * (width_b) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + j * (batch * height) + k * (width) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + j * (batch * height_b) + k * (width_b) + l + 1) % MOD) * alpha);
+
+  nntrainer::Tensor C = dotCl(A_fp32, B_fp32, rc, transA, transB);
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon = mse(C.getData(), C_fp32.getData(), C.size());
+
+  double cosSimNeon = cosine_similarity(C.getData(), C_fp32.getData(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
+TEST(nntrainer_Tensor, dot_gemm_50_768_2048_transB) {
+  /// @note GEMM : A X B = C
+  RunLayerContext rc = setUpGpuContext();
+
+  int batch = 1;
+  int channel = 1;
+  int height = 50;
+  int width = 768;
+
+  int height_b = 2048;
+  int width_b = 768;
+
+  bool transA = false;
+  bool transB = true;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) + k * (width) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) + j * (batch * height_b) + k * (width_b) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + j * (batch * height) + k * (width) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + j * (batch * height_b) + k * (width_b) + l + 1) % MOD) * alpha);
+
+  nntrainer::Tensor C = dotCl(A_fp32, B_fp32, rc, transA, transB);
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon = mse(C.getData(), C_fp32.getData(), C.size());
+
+  double cosSimNeon = cosine_similarity(C.getData(), C_fp32.getData(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
+TEST(nntrainer_Tensor, dot_gemm_50_768_1024_transA) {
+  /// @note GEMM : A X B = C
+  RunLayerContext rc = setUpGpuContext();
+
+  int batch = 1;
+  int channel = 1;
+  int height = 768;
+  int width = 50;
+
+  int height_b = 768;
+  int width_b = 1024;
+
+  bool transA = true;
+  bool transB = false;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) + k * (width) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) + j * (batch * height_b) + k * (width_b) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + j * (batch * height) + k * (width) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + j * (batch * height_b) + k * (width_b) + l + 1) % MOD) * alpha);
+
+  nntrainer::Tensor C = dotCl(A_fp32, B_fp32, rc, transA, transB);
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon = mse(C.getData(), C_fp32.getData(), C.size());
+
+  double cosSimNeon = cosine_similarity(C.getData(), C_fp32.getData(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
+TEST(nntrainer_Tensor, dot_gemm_50_768_2048_transAB) {
+  /// @note GEMM : A X B = C
+  RunLayerContext rc = setUpGpuContext();
+
+  int batch = 1;
+  int channel = 1;
+  int height = 768;
+  int width = 50;
+
+  int height_b = 2048;
+  int width_b = 768;
+
+  bool transA = true;
+  bool transB = true;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) + k * (width) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) + j * (batch * height_b) + k * (width_b) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + j * (batch * height) + k * (width) + l + 1) % MOD) * alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + j * (batch * height_b) + k * (width_b) + l + 1) % MOD) * alpha);
+
+  nntrainer::Tensor C = dotCl(A_fp32, B_fp32, rc, transA, transB);
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon = mse(C.getData(), C_fp32.getData(), C.size());
+
+  double cosSimNeon = cosine_similarity(C.getData(), C_fp32.getData(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE(cosSimNeon, 0.99, 1);
+}
+
 GTEST_API_ int main(int argc, char **argv) {
   int result = -1;