From 9b1e33b4d9f689f014979030f8f85eb6f82b8c53 Mon Sep 17 00:00:00 2001
From: Debadri Samaddar
Date: Tue, 11 Jun 2024 12:55:46 +0530
Subject: [PATCH] [blas/OpenCL] Added multiply OpenCL kernel and unit test

Added sscal equivalent kernel and multiply function.
Added unit test setup to test standalone kernels.

Signed-off-by: Debadri Samaddar
---
 nntrainer/layers/layer_context.cpp            |   4 +
 nntrainer/layers/layer_context.h              |   2 +
 .../cl_operations/blas_kernel_interface.cpp   |  25 +++
 .../cl_operations/blas_kernel_interface.h     |  20 ++
 .../tensor/cl_operations/blas_kernels.cpp     |  55 ++++++
 nntrainer/tensor/cl_operations/blas_kernels.h |  22 +++
 .../cl_operations/blas_kernels_fp16.cpp       |  60 ++++++
 test/jni/Android.mk                           |  17 ++
 test/unittest/unittest_blas_kernels_cl.cpp    | 186 ++++++++++++++++++
 9 files changed, 391 insertions(+)
 create mode 100644 test/unittest/unittest_blas_kernels_cl.cpp

diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
index 1a66aed3cd..7902fd8017 100644
--- a/nntrainer/layers/layer_context.cpp
+++ b/nntrainer/layers/layer_context.cpp
@@ -662,6 +662,10 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) {
     return "dot_cl_fp16";
   case LayerKernel::SGEMM_FP16:
     return "sgemm_cl_fp16";
+  case LayerKernel::SSCAL:
+    return "sscal_cl";
+  case LayerKernel::SSCAL_FP16:
+    return "sscal_cl_fp16";
   default:
     return "";
   }
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index 43e9d8eaf8..79cba84860 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -836,6 +836,8 @@ class RunLayerContext {
     SGEMV_FP16 = 1 << 3, /**< placeholder for kernel name */
     DOT_FP16 = 1 << 4,   /**< placeholder for kernel name */
     SGEMM_FP16 = 1 << 5, /**< placeholder for kernel name */
+    SSCAL = 1 << 6,      /**< placeholder for kernel name */
+    SSCAL_FP16 = 1 << 7, /**< placeholder for kernel name */
   };
 
   /**
diff --git a/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
index 852b482529..c0c98019d5 100644
--- a/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
+++ b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
@@ -30,6 +30,14 @@ void dotBatchedCl(Tensor const &input, Tensor const &m, Tensor &result,
   }
 }
 
+Tensor dotCl(Tensor const &input, Tensor const &m, RunLayerContext &context,
+             bool trans, bool trans_m) {
+  Tensor output("", input.getFormat(), input.getDataType());
+  dotCl(input, m, output, context, trans, trans_m);
+
+  return output;
+}
+
 void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
            RunLayerContext &context, bool trans, bool trans_m) {
   unsigned int dim1, dim2, mdim1, mdim2;
@@ -186,4 +194,21 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
   }
 }
 
+void multiplyCl(Tensor &input, float const &value, RunLayerContext &context) {
+  if (input.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    float *data = input.getData<float>();
+    unsigned int len = input.size();
+
+    sscal_cl(data, len, value, context);
+  } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    _FP16 *data = input.getData<_FP16>();
+    unsigned int len = input.size();
+    sscal_cl(data, len, value, context);
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+  }
+}
+
 } // namespace nntrainer
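Note on usage (not part of the patch): the new `multiplyCl` entry point scales a tensor in place on the GPU, dispatching to the FP32 or FP16 `sscal_cl` overload based on the tensor's data type. Below is a minimal sketch of how a caller might drive it, assuming a valid `RunLayerContext` backed by an initialized OpenCL context; `scaleInPlace`, the tensor shape, and the fill value are illustrative only.

```cpp
#include <blas_kernel_interface.h>
#include <layer_context.h>
#include <tensor.h>

// Hypothetical helper: scale every element of an FP32 tensor by 0.5f.
void scaleInPlace(nntrainer::RunLayerContext &rc) {
  nntrainer::Tensor t(1, 1, 2, 8); // NCHW / FP32 by default
  t.setValue(2.0f);                // fill with a known value

  // multiplyCl copies the data to a device buffer, runs X[i] *= alpha
  // via the sscal_cl kernel, then reads the result back.
  nntrainer::multiplyCl(t, 0.5f, rc);
  // every element of t should now be 1.0f
}
```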
diff --git a/nntrainer/tensor/cl_operations/blas_kernel_interface.h b/nntrainer/tensor/cl_operations/blas_kernel_interface.h
index 41fdfd9242..20c7bdac07 100644
--- a/nntrainer/tensor/cl_operations/blas_kernel_interface.h
+++ b/nntrainer/tensor/cl_operations/blas_kernel_interface.h
@@ -19,6 +19,18 @@
 namespace nntrainer {
 
+/**
+ * @brief Process data and dimensions for OpenCL dot operation
+ * @param[in] input Tensor
+ * @param[in] m Tensor
+ * @param[in] RunLayerContext reference
+ * @param[in] trans bool
+ * @param[in] trans_m bool
+ * @retval Tensor containing the result of the dot operation
+ */
+Tensor dotCl(Tensor const &input, Tensor const &m, RunLayerContext &context,
+             bool trans = false, bool trans_m = false);
+
 /**
  * @brief Process data and dimensions for OpenCL dot operation
  * @param[in] input Tensor
@@ -44,5 +56,13 @@ void dotBatchedCl(Tensor const &input, Tensor const &m, Tensor &result,
                   RunLayerContext &context, bool trans = false,
                   bool trans_m = false);
 
+/**
+ * @brief Multiply each element of the input tensor by value, in place
+ * @param[in] input Tensor
+ * @param[in] value multiplier
+ * @param[in] RunLayerContext reference
+ */
+void multiplyCl(Tensor &input, float const &value, RunLayerContext &context);
+
 } // namespace nntrainer
 #endif /* __BLAS_KERNEL_INTERFACE_H__ */
diff --git a/nntrainer/tensor/cl_operations/blas_kernels.cpp b/nntrainer/tensor/cl_operations/blas_kernels.cpp
index 4c54a0b262..bb6a89cceb 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels.cpp
+++ b/nntrainer/tensor/cl_operations/blas_kernels.cpp
@@ -51,12 +51,20 @@ std::string sgemm_cl_kernel_ =
         C[m * ldc + n] = c;
     })";
 
+std::string sscal_cl_kernel_ =
+  R"(__kernel void sscal_cl(__global float* X, const float alpha) {
+
+        unsigned int i = get_global_id(0);
+        X[i] *= alpha;
+    })";
+
 /**
  * @brief defining global kernel objects
  */
 opencl::Kernel kernel_sgemv;
 opencl::Kernel kernel_sgemm;
 opencl::Kernel kernel_dot;
+opencl::Kernel kernel_sscal;
 
 void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata,
               unsigned int dim1, unsigned int dim2, unsigned int lda,
@@ -299,4 +307,51 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
 
   } while (false);
 }
 
+void sscal_cl(float *X, const unsigned int N, const float alpha,
+              RunLayerContext &context) {
+  bool result = false;
+
+  do {
+    result = context.clCreateKernel(sscal_cl_kernel_,
+                                    context.LayerKernel::SSCAL, kernel_sscal);
+    if (!result) {
+      break;
+    }
+
+    size_t x_size = N * sizeof(float);
+
+    opencl::Buffer inputX(context.context_inst_, x_size, false, nullptr);
+
+    result = inputX.WriteData(context.command_queue_inst_, X);
+    if (!result) {
+      break;
+    }
+
+    result = kernel_sscal.SetKernelArguments(0, &inputX, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = kernel_sscal.SetKernelArguments(1, &alpha, sizeof(float));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {(int)N, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    result = context.command_queue_inst_.DispatchCommand(
+      kernel_sscal, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inputX.ReadData(context.command_queue_inst_, X);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
 } // namespace nntrainer
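Note on the dispatch sizes (not part of the patch): the local size `{32, 32, 1}` is explicitly flagged as a test value, and the kernel is logically one-dimensional. One thing to keep in mind when tuning it: in OpenCL 1.x the global work size must be an integer multiple of the local work size in each dimension, so enqueueing exactly `{N, 1, 1}` work items with a fixed local size only succeeds when `N` happens to be divisible by it. A common pattern is to round the global size up and bounds-check inside the kernel; the sketch below shows what that could look like here (the extra `N` kernel argument is an assumption of the sketch, not something the patch adds):

```cpp
// Guarded kernel variant: work items past N (from the rounded-up
// global size) do nothing.
std::string sscal_cl_kernel_guarded_ =
  R"(__kernel void sscal_cl(__global float* X, const float alpha,
                            const unsigned int N) {
        unsigned int i = get_global_id(0);
        if (i < N)
          X[i] *= alpha;
    })";

// Host side: round the 1-D global size up to a multiple of the local size.
const size_t local_size = 32;
const size_t global_size = ((N + local_size - 1) / local_size) * local_size;
```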
diff --git a/nntrainer/tensor/cl_operations/blas_kernels.h b/nntrainer/tensor/cl_operations/blas_kernels.h
index d9f06490b0..b177f4a6c0 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels.h
+++ b/nntrainer/tensor/cl_operations/blas_kernels.h
@@ -27,6 +27,7 @@ namespace nntrainer {
 extern opencl::Kernel kernel_sgemv;
 extern opencl::Kernel kernel_sgemm;
 extern opencl::Kernel kernel_dot;
+extern opencl::Kernel kernel_sscal;
 
 /**
  * @brief sgemv computation : Y = A*X + Y
@@ -71,6 +72,16 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
               unsigned int N, unsigned int K, unsigned int lda,
               unsigned int ldb, unsigned int ldc, RunLayerContext &context);
 
+/**
+ * @brief sscal computation : X = alpha * X, applied in place
+ * @param[in] X float * input
+ * @param[in] N unsigned int number of elements
+ * @param[in] alpha float multiplier
+ * @param[in] context RunLayerContext reference
+ */
+void sscal_cl(float *X, const unsigned int N, const float alpha,
+              RunLayerContext &context);
+
 #ifdef ENABLE_FP16
 /**
  * @brief declaring global fp16 kernel objects
@@ -78,6 +89,7 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
  */
 extern opencl::Kernel kernel_sgemv_fp16;
 extern opencl::Kernel kernel_sgemm_fp16;
 extern opencl::Kernel kernel_dot_fp16;
+extern opencl::Kernel kernel_sscal_fp16;
 
 /**
  * @brief fp16 sgemv computation : Y = A*X + Y
@@ -121,6 +133,16 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
 void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
               unsigned int N, unsigned int K, unsigned int lda,
               unsigned int ldb, unsigned int ldc, RunLayerContext &context);
+
+/**
+ * @brief fp16 sscal computation : X = alpha * X, applied in place
+ * @param[in] X __fp16 * input
+ * @param[in] N unsigned int number of elements
+ * @param[in] alpha float multiplier
+ * @param[in] context RunLayerContext reference
+ */
+void sscal_cl(__fp16 *X, const unsigned int N, const float alpha,
+              RunLayerContext &context);
 #endif
 
 } // namespace nntrainer
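For reference (not part of the patch): both `sscal_cl` overloads declared above are expected to match plain BLAS SSCAL semantics, X = alpha * X over N contiguous elements. A scalar CPU version of the same operation, useful as a ground truth when testing the kernels in isolation:

```cpp
// CPU reference for sscal: X[i] = alpha * X[i] for i in [0, N).
void sscal_ref(float *X, unsigned int N, float alpha) {
  for (unsigned int i = 0; i < N; ++i)
    X[i] *= alpha;
}
```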
diff --git a/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
index 8948f0dc5c..288f2a0046 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
+++ b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
@@ -60,12 +60,23 @@ std::string sgemm_cl_kernel_fp16_ =
         C[m * ldc + n] = c;
     })";
 
+std::string sscal_cl_kernel_fp16_ =
+  R"(
+    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+    __kernel void sscal_cl_fp16(__global half* X, const float alpha) {
+
+        unsigned int i = get_global_id(0);
+        X[i] *= alpha;
+    })";
+
 /**
  * @brief defining global kernel objects
  */
 opencl::Kernel kernel_sgemv_fp16;
 opencl::Kernel kernel_sgemm_fp16;
 opencl::Kernel kernel_dot_fp16;
+opencl::Kernel kernel_sscal_fp16;
 
 void sgemv_cl(const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata,
               unsigned int dim1, unsigned int dim2, unsigned int lda,
@@ -309,4 +320,53 @@ void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
 
   } while (false);
 }
 
+void sscal_cl(__fp16 *X, const unsigned int N, const float alpha,
+              RunLayerContext &context) {
+  bool result = false;
+
+  do {
+    result = context.clCreateKernel(sscal_cl_kernel_fp16_,
+                                    context.LayerKernel::SSCAL_FP16,
+                                    kernel_sscal_fp16);
+    if (!result) {
+      break;
+    }
+
+    size_t x_size = N * sizeof(cl_half);
+
+    opencl::Buffer inputX(context.context_inst_, x_size, false, nullptr);
+
+    result = inputX.WriteData(context.command_queue_inst_, X);
+    if (!result) {
+      break;
+    }
+
+    result = kernel_sscal_fp16.SetKernelArguments(0, &inputX, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = kernel_sscal_fp16.SetKernelArguments(1, &alpha, sizeof(float));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {(int)N, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    result = context.command_queue_inst_.DispatchCommand(
+      kernel_sscal_fp16, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inputX.ReadData(context.command_queue_inst_, X);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
 } // namespace nntrainer
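Note on the FP16 path (not part of the patch): the host code above copies from a `__fp16` pointer into a buffer sized with `sizeof(cl_half)`, which is only safe because both types are 2 bytes wide with the same IEEE binary16 layout. A compile-time guard one could add near the top of blas_kernels_fp16.cpp to make that assumption explicit:

```cpp
static_assert(sizeof(cl_half) == sizeof(__fp16),
              "fp16 sscal_cl assumes host __fp16 and OpenCL half have "
              "the same size and layout");
```

Also note that `alpha` stays a `float` on both host and device: inside the kernel, the `half` operand is promoted to `float`, the multiply happens in single precision, and the store narrows the result back to `half`.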
diff --git a/test/jni/Android.mk b/test/jni/Android.mk
index 978e98bd67..60eee6e60d 100644
--- a/test/jni/Android.mk
+++ b/test/jni/Android.mk
@@ -22,6 +22,7 @@ NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \
 	$(NNTRAINER_ROOT)/nntrainer/opencl \
 	$(NNTRAINER_ROOT)/nntrainer/optimizers \
 	$(NNTRAINER_ROOT)/nntrainer/tensor \
+	$(NNTRAINER_ROOT)/nntrainer/tensor/cl_operations \
 	$(NNTRAINER_ROOT)/nntrainer/utils \
 	$(NNTRAINER_ROOT)/api \
 	$(NNTRAINER_ROOT)/api/ccapi/include \
@@ -477,6 +478,22 @@ LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer
 LOCAL_STATIC_LIBRARIES := googletest_main test_util
 include $(BUILD_EXECUTABLE)
 
+include $(CLEAR_VARS)
+
+LOCAL_MODULE := unittest_blas_kernels_cl
+LOCAL_CFLAGS := -Igoogletest/include -I../include -I../unittest/layers -I../../nntrainer/layers/loss -pthread -fexceptions -fopenmp -static-openmp -DMIN_CPP_VERSION=201703L -DNNTR_NUM_THREADS=1 -D__LOGGING__=1 -DENABLE_TEST=1 -DREDUCE_TOLERANCE=1 -march=armv8.2-a+fp16 -mfpu=neon-fp16 -mfloat-abi=softfp -O3 -frtti -DNDK_BUILD=1 -DENABLE_FP16=1 -DENABLE_OPENCL=1
+LOCAL_CXXFLAGS += -std=c++17 -frtti -fexceptions
+LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp
+
+LOCAL_SRC_FILES := \
+	../unittest/unittest_blas_kernels_cl.cpp
+
+LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES)
+
+LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer
+LOCAL_STATIC_LIBRARIES := googletest_main test_util
+include $(BUILD_EXECUTABLE)
+
 # unittest_ccapi
 include $(CLEAR_VARS)
diff --git a/test/unittest/unittest_blas_kernels_cl.cpp b/test/unittest/unittest_blas_kernels_cl.cpp
new file mode 100644
index 0000000000..cac5b9e964
--- /dev/null
+++ b/test/unittest/unittest_blas_kernels_cl.cpp
@@ -0,0 +1,186 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Debadri Samaddar
+ *
+ * @file   unittest_blas_kernels_cl.cpp
+ * @date   6 June 2024
+ * @brief  Test setup for blas OpenCL kernels
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Debadri Samaddar
+ * @bug    No known bugs except for NYI items
+ */
+
+#include <fstream>
+#include <gtest/gtest.h>
+#include <type_traits>
+
+#include "nntrainer_test_util.h"
+#include "util_func.h"
+#include <blas_kernel_interface.h>
+#include <cl_context.h>
+#include <layer_context.h>
+#include <tensor.h>
+
+#define EXPECT_IN_RANGE(VAL, MIN, MAX)                                         \
+  EXPECT_GE((VAL), (MIN));                                                     \
+  EXPECT_LE((VAL), (MAX))
+
+using namespace nntrainer;
+
+static RunLayerContext setUpGpuContext() {
+
+  auto &ac = nntrainer::ClContext::Global();
+  auto rc = RunLayerContext();
+
+  return rc;
+}
+
+TEST(blas_kernels, dotCL_sgemv) {
+  RunLayerContext rc = setUpGpuContext();
+
+  int batch = 1;
+  int channel = 1;
+  int height = 1;
+  int width = 768;
+
+  int height_b = 768;
+  int width_b = 96000;
+
+  bool transA = false;
+  bool transB = false;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+                           j * (batch * height) + k * (width) + l + 1) %
+                          MOD) *
+                           alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) +
+                             j * (batch * height_b) + k * (width_b) + l + 1) %
+                            MOD) *
+                             alpha);
+
+  nntrainer::Tensor C = dotCl(A_fp32, B_fp32, rc, transA, transB);
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon =
+    mse<float>(C.getData<float>(), C_fp32.getData<float>(), C.size());
+
+  double cosSimNeon = cosine_similarity<float>(
+    C.getData<float>(), C_fp32.getData<float>(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
+TEST(blas_kernels, dotCL_sgemv_n) {
+  RunLayerContext rc = setUpGpuContext();
+
+  int batch = 1;
+  int channel = 1;
+  int height = 1;
+  int width = 768;
+
+  int height_b = 768;
+  int width_b = 96000;
+
+  bool transA = true;
+  bool transB = false;
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+                           j * (batch * height) + k * (width) + l + 1) %
+                          MOD) *
+                           alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) +
+                             j * (batch * height_b) + k * (width_b) + l + 1) %
+                            MOD) *
+                             alpha);
+
+  EXPECT_THROW(dotCl(A_fp32, B_fp32, rc, transA, transB), std::runtime_error);
+}
+
+TEST(nntrainer_Tensor, multiply_i) {
+  RunLayerContext rc = setUpGpuContext();
+
+  int batch = 1;
+  int channel = 1;
+  int height = 2;
+  int width = 11;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor input(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor input_fp32(batch, channel, height, width, t_type_nchw_fp32);
+
+  const float alpha = 1e-5;
+  const float epsilon = 1e-4;
+
+  GEN_TEST_INPUT(input, i * (batch * height * channel) * alpha +
+                          j * (batch * height) * alpha + k * (width)*alpha +
+                          l + 1);
+  GEN_TEST_INPUT(input_fp32, i * (batch * height * channel) * alpha +
+                               j * (batch * height) * alpha +
+                               k * (width)*alpha + l + 1);
+
+  // fp16
+  multiplyCl(input, 0.1, rc);
+
+  // fp32
+  multiplyCl(input_fp32, 0.1, rc);
+
+  float mseErrorNeon = mse<__fp16>(input.getData<__fp16>(),
+                                   input_fp32.getData<float>(), input.size());
+
+  double cosSimNeon = cosine_similarity<__fp16>(
+    input.getData<__fp16>(), input_fp32.getData<float>(), input.size());
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE(cosSimNeon, 0.99, 1);
+}
+
+GTEST_API_ int main(int argc, char **argv) {
+  int result = -1;
+
+  try {
+    testing::InitGoogleTest(&argc, argv);
+  } catch (...) {
+    std::cerr << "Error during InitGoogleTest" << std::endl;
+    return 0;
+  }
+
+  try {
+    result = RUN_ALL_TESTS();
+  } catch (...) {
+    std::cerr << "Error during RUN_ALL_TESTS()" << std::endl;
+  }
+
+  return result;
+}
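A possible follow-up for the test suite (not part of the patch): the tests above validate `dotCl` and `multiplyCl` against `Tensor` operations, so a failure could implicate either the kernel or the tensor path. A direct check of `sscal_cl` against a scalar loop keeps the kernel test self-contained. A sketch under the same fixture assumptions; the test name is hypothetical, `<vector>` and `blas_kernels.h` (which declares `sscal_cl`) would also need to be included, and `N` is kept a multiple of 32 to stay clear of the local-size divisibility issue noted earlier:

```cpp
TEST(blas_kernels, sscal_cl_matches_scalar) {
  RunLayerContext rc = setUpGpuContext();

  const unsigned int N = 2048; // multiple of the 32-wide work-group size
  const float alpha = 0.25f;

  std::vector<float> gpu(N), ref(N);
  for (unsigned int i = 0; i < N; ++i)
    gpu[i] = ref[i] = static_cast<float>(i % 10) * 0.1f;

  sscal_cl(gpu.data(), N, alpha, rc); // device path
  for (unsigned int i = 0; i < N; ++i)
    ref[i] *= alpha;                  // scalar reference

  for (unsigned int i = 0; i < N; ++i)
    EXPECT_NEAR(gpu[i], ref[i], 1e-6f);
}
```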