[blas/OpenCL] Added multiply OpenCL kernel and unit test
Added an sscal-equivalent OpenCL kernel and a multiply function.
Added a unit test setup for testing standalone kernels.

Signed-off-by: Debadri Samaddar <[email protected]>
s-debadri committed Jun 11, 2024
1 parent a1f2f8a commit 9b1e33b
Showing 9 changed files with 391 additions and 0 deletions.
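The new unittest_blas_kernels_cl.cpp is not reproduced on this page, but the shape of a standalone-kernel test follows from the interface added below. A minimal sketch (not the committed test), assuming t is an FP32 nntrainer::Tensor and rc is a RunLayerContext carrying a live OpenCL context and command queue, both supplied by the test fixture:

#include <cassert>
#include <cmath>
#include <vector>

#include <blas_kernel_interface.h>
#include <layer_context.h>
#include <tensor.h>

// Hypothetical helper: scale t on the GPU via multiplyCl() and compare
// against a host-side reference computed from a copy of the original data.
void scale_and_check(nntrainer::Tensor &t, nntrainer::RunLayerContext &rc,
                     float alpha) {
  std::vector<float> ref(t.getData<float>(), t.getData<float>() + t.size());

  nntrainer::multiplyCl(t, alpha, rc); // in-place scaling via the sscal_cl kernel

  const float *out = t.getData<float>();
  for (unsigned int i = 0; i < t.size(); ++i)
    assert(std::fabs(out[i] - alpha * ref[i]) < 1e-6f);
}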
4 changes: 4 additions & 0 deletions nntrainer/layers/layer_context.cpp
@@ -662,6 +662,10 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) {
return "dot_cl_fp16";
case LayerKernel::SGEMM_FP16:
return "sgemm_cl_fp16";
case LayerKernel::SSCAL:
return "sscal_cl";
case LayerKernel::SSCAL_FP16:
return "sscal_cl_fp16";
default:
return "";
}
2 changes: 2 additions & 0 deletions nntrainer/layers/layer_context.h
@@ -836,6 +836,8 @@ class RunLayerContext {
SGEMV_FP16 = 1 << 3, /**< placeholder for kernel name */
DOT_FP16 = 1 << 4, /**< placeholder for kernel name */
SGEMM_FP16 = 1 << 5, /**< placeholder for kernel name */
SSCAL = 1 << 6, /**< placeholder for kernel name */
SSCAL_FP16 = 1 << 7, /**< placeholder for kernel name */
};

/**
25 changes: 25 additions & 0 deletions nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
@@ -30,6 +30,14 @@ void dotBatchedCl(Tensor const &input, Tensor const &m, Tensor &result,
}
}

Tensor dotCl(Tensor const &input, Tensor const &m, RunLayerContext &context,
bool trans, bool trans_m) {
Tensor output("", input.getFormat(), input.getDataType());
dotCl(input, m, output, context, trans, trans_m);

return output;
}

void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
RunLayerContext &context, bool trans, bool trans_m) {
unsigned int dim1, dim2, mdim1, mdim2;
@@ -186,4 +194,21 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
}
}

void multiplyCl(Tensor &input, float const &value, RunLayerContext &context) {
if (input.getDataType() == ml::train::TensorDim::DataType::FP32) {
float *data = input.getData<float>();
unsigned int len = input.size();

sscal_cl(data, len, value, context);
} else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
_FP16 *data = input.getData<_FP16>();
unsigned int len = input.size();
sscal_cl(data, len, value, context);
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
}
}

} // namespace nntrainer
20 changes: 20 additions & 0 deletions nntrainer/tensor/cl_operations/blas_kernel_interface.h
@@ -19,6 +19,18 @@

namespace nntrainer {

/**
* @brief Process data and dimensions for OpenCL dot operation
* @param[in] input Tensor
* @param[in] m Tensor
* @return Tensor output of the dot operation
* @param[in] RunLayerContext reference
* @param[in] trans bool
* @param[in] trans_m bool
*/
Tensor dotCl(Tensor const &input, Tensor const &m, RunLayerContext &context,
bool trans = false, bool trans_m = false);

/**
* @brief Process data and dimensions for OpenCL dot operation
* @param[in] input Tensor
@@ -44,5 +56,13 @@ void dotBatchedCl(Tensor const &input, Tensor const &m, Tensor &result,
RunLayerContext &context, bool trans = false,
bool trans_m = false);

/**
* @brief Multiply each element of the input tensor by value, in place
* @param[in] input Tensor
* @param[in] value multiplier
* @param[in] RunLayerContext reference
*/
void multiplyCl(Tensor &input, float const &value, RunLayerContext &context);

} // namespace nntrainer
#endif /* __BLAS_KERNEL_INTERFACE_H__ */
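A minimal sketch of how the Tensor-level wrappers declared in this header compose (not part of the commit); it assumes A and B are shape-compatible FP32 tensors and rc is a RunLayerContext backed by an OpenCL context and command queue:

#include <blas_kernel_interface.h>

// Hypothetical helper: C = A x B on the GPU, then C *= alpha in place.
nntrainer::Tensor scaled_matmul(nntrainer::Tensor const &A,
                                nntrainer::Tensor const &B,
                                nntrainer::RunLayerContext &rc, float alpha) {
  nntrainer::Tensor C = nntrainer::dotCl(A, B, rc); // no transpose on either operand
  nntrainer::multiplyCl(C, alpha, rc);
  return C;
}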
55 changes: 55 additions & 0 deletions nntrainer/tensor/cl_operations/blas_kernels.cpp
@@ -51,12 +51,20 @@ std::string sgemm_cl_kernel_ =
C[m * ldc + n] = c;
})";

std::string sscal_cl_kernel_ =
R"(__kernel void sscal_cl(__global float* X, const float alpha) {
unsigned int i = get_global_id(0);
X[i] *= alpha;
})";

/**
* @brief defining global kernel objects
*/
opencl::Kernel kernel_sgemv;
opencl::Kernel kernel_sgemm;
opencl::Kernel kernel_dot;
opencl::Kernel kernel_sscal;

void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata,
unsigned int dim1, unsigned int dim2, unsigned int lda,
@@ -299,4 +307,51 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
} while (false);
}

void sscal_cl(float *X, const unsigned int N, const float alpha,
RunLayerContext &context) {
bool result = false;

do {
result = context.clCreateKernel(sscal_cl_kernel_,
context.LayerKernel::SSCAL, kernel_sscal);
if (!result) {
break;
}

size_t x_size = N * sizeof(float);

opencl::Buffer inputX(context.context_inst_, x_size, false, nullptr);

result = inputX.WriteData(context.command_queue_inst_, X);
if (!result) {
break;
}

result = kernel_sscal.SetKernelArguments(0, &inputX, sizeof(cl_mem));
if (!result) {
break;
}

result = kernel_sscal.SetKernelArguments(1, &alpha, sizeof(float));
if (!result) {
break;
}

const int work_groups_count[3] = {(int)N, 1, 1};
const int work_group_size[3] = {32, 32, 1}; // test-value

result = context.command_queue_inst_.DispatchCommand(
kernel_sscal, work_groups_count, work_group_size);
if (!result) {
break;
}

result = inputX.ReadData(context.command_queue_inst_, X);
if (!result) {
break;
}

} while (false);
}

} // namespace nntrainer
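The wrapper above copies X to a device buffer, scales every element by alpha with the sscal_cl kernel, and reads the result back into X. A minimal sketch of calling it directly on a raw host buffer, bypassing the Tensor interface (assumes rc holds a live OpenCL context and command queue):

#include <vector>

#include <blas_kernels.h>

// Hypothetical helper: scale a host-side float buffer on the GPU in place.
void scale_buffer(std::vector<float> &x, float alpha,
                  nntrainer::RunLayerContext &rc) {
  nntrainer::sscal_cl(x.data(), static_cast<unsigned int>(x.size()), alpha, rc);
}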
22 changes: 22 additions & 0 deletions nntrainer/tensor/cl_operations/blas_kernels.h
@@ -27,6 +27,7 @@ namespace nntrainer {
extern opencl::Kernel kernel_sgemv;
extern opencl::Kernel kernel_sgemm;
extern opencl::Kernel kernel_dot;
extern opencl::Kernel kernel_sscal;

/**
* @brief sgemv computation : Y = A*X + Y
@@ -71,13 +72,24 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
unsigned int N, unsigned int K, unsigned int lda,
unsigned int ldb, unsigned int ldc, RunLayerContext &context);

/**
* @brief sscal computation : X = alpha * X (in place)
* @param[in] X float * input
* @param[in] N unsigned int number of elements
* @param[in] alpha float multiplier
* @param[in] context RunLayerContext reference
*/
void sscal_cl(float *X, const unsigned int N, const float alpha,
RunLayerContext &context);

#ifdef ENABLE_FP16
/**
* @brief declaring global fp16 kernel objects
*/
extern opencl::Kernel kernel_sgemv_fp16;
extern opencl::Kernel kernel_sgemm_fp16;
extern opencl::Kernel kernel_dot_fp16;
extern opencl::Kernel kernel_sscal_fp16;

/**
* @brief fp16 sgemv computation : Y = A*X + Y
@@ -121,6 +133,16 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
unsigned int N, unsigned int K, unsigned int lda,
unsigned int ldb, unsigned int ldc, RunLayerContext &context);

/**
* @brief fp16 sscal computation : X = alpha * X (in place)
* @param[in] X __fp16 * input
* @param[in] N unsigned int number of elements
* @param[in] alpha float multiplier
* @param[in] context RunLayerContext reference
*/
void sscal_cl(__fp16 *X, const unsigned int N, const float alpha,
RunLayerContext &context);
#endif

} // namespace nntrainer
60 changes: 60 additions & 0 deletions nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
@@ -60,12 +60,23 @@ std::string sgemm_cl_kernel_fp16_ =
C[m * ldc + n] = c;
})";

std::string sscal_cl_kernel_fp16_ =
R"(
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void sscal_cl_fp16(__global half* X, const float alpha) {
unsigned int i = get_global_id(0);
X[i] *= alpha;
})";

/**
* @brief defining global kernel objects
*/
opencl::Kernel kernel_sgemv_fp16;
opencl::Kernel kernel_sgemm_fp16;
opencl::Kernel kernel_dot_fp16;
opencl::Kernel kernel_sscal_fp16;

void sgemv_cl(const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata,
unsigned int dim1, unsigned int dim2, unsigned int lda,
@@ -309,4 +320,53 @@ void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,

} while (false);
}

void sscal_cl(__fp16 *X, const unsigned int N, const float alpha,
RunLayerContext &context) {
bool result = false;

do {
result = context.clCreateKernel(sscal_cl_kernel_fp16_,
context.LayerKernel::SSCAL_FP16,
kernel_sscal_fp16);
if (!result) {
break;
}

size_t x_size = N * sizeof(cl_half);

opencl::Buffer inputX(context.context_inst_, x_size, false, nullptr);

result = inputX.WriteData(context.command_queue_inst_, X);
if (!result) {
break;
}

result = kernel_sscal_fp16.SetKernelArguments(0, &inputX, sizeof(cl_mem));
if (!result) {
break;
}

result = kernel_sscal_fp16.SetKernelArguments(1, &alpha, sizeof(float));
if (!result) {
break;
}

const int work_groups_count[3] = {(int)N, 1, 1};
const int work_group_size[3] = {32, 32, 1}; // test-value

result = context.command_queue_inst_.DispatchCommand(
kernel_sscal_fp16, work_groups_count, work_group_size);
if (!result) {
break;
}

result = inputX.ReadData(context.command_queue_inst_, X);
if (!result) {
break;
}

} while (false);
}

} // namespace nntrainer
17 changes: 17 additions & 0 deletions test/jni/Android.mk
@@ -22,6 +22,7 @@ NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \
$(NNTRAINER_ROOT)/nntrainer/opencl \
$(NNTRAINER_ROOT)/nntrainer/optimizers \
$(NNTRAINER_ROOT)/nntrainer/tensor \
$(NNTRAINER_ROOT)/nntrainer/tensor/cl_operations \
$(NNTRAINER_ROOT)/nntrainer/utils \
$(NNTRAINER_ROOT)/api \
$(NNTRAINER_ROOT)/api/ccapi/include \
@@ -477,6 +478,22 @@ LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer
LOCAL_STATIC_LIBRARIES := googletest_main test_util
include $(BUILD_EXECUTABLE)

include $(CLEAR_VARS)

LOCAL_MODULE := unittest_blas_kernels_cl
LOCAL_CFLAGS := -Igoogletest/include -I../include -I../unittest/layers -I../../nntrainer/layers/loss -pthread -fexceptions -fopenmp -static-openmp -DMIN_CPP_VERSION=201703L -DNNTR_NUM_THREADS=1 -D__LOGGING__=1 -DENABLE_TEST=1 -DREDUCE_TOLERANCE=1 -march=armv8.2-a+fp16 -mfpu=neon-fp16 -mfloat-abi=softfp -O3 -frtti -DNDK_BUILD=1 -DENABLE_FP16=1 -DENABLE_OPENCL=1
LOCAL_CXXFLAGS += -std=c++17 -frtti -fexceptions
LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp

LOCAL_SRC_FILES := \
../unittest/unittest_blas_kernels_cl.cpp

LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES)

LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer
LOCAL_STATIC_LIBRARIES := googletest_main test_util
include $(BUILD_EXECUTABLE)

# unittest_ccapi
include $(CLEAR_VARS)

