diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index a492cd9f89..705dec462a 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -108,7 +108,9 @@ enum LayerType { LAYER_UPSAMPLE2D, /**< Upsample 2D Layer type */ LAYER_RMSNORM = ML_TRAIN_LAYER_TYPE_RMSNORM, /** #include #include +#include #include #include #include @@ -56,6 +57,10 @@ static void add_default_object(ClContext &cc) { cc.registerFactory(nntrainer::createLayer, TransposeLayerCl::type, ml::train::LayerType::LAYER_TRANSPOSE); + + cc.registerFactory(nntrainer::createLayer, + FullyConnectedRMSNormLayerCl::type, + ml::train::LayerType::LAYER_FUSED_FC_RMS); } static void registerer(ClContext &cc) noexcept { @@ -143,6 +148,7 @@ void ClContext::initBlasClKernels() { registerClKernel(sgemm_cl_transAB_kernel_, "sgemm_cl_transAB"); registerClKernel(addition_cl_kernel_, "addition_cl"); registerClKernel(sscal_cl_kernel_, "sscal_cl"); + registerClKernel(rmsnorm_cl_kernel_new, "rmsnorm_cl"); #ifdef ENABLE_FP16 registerClKernel(sgemv_cl_kernel_fp16_, "sgemv_cl_fp16"); @@ -154,6 +160,8 @@ void ClContext::initBlasClKernels() { registerClKernel(sgemm_cl_transAB_kernel_fp16_, "sgemm_cl_transAB_fp16"); registerClKernel(addition_cl_kernel_fp16_, "addition_cl_fp16"); registerClKernel(sscal_cl_kernel_fp16_, "sscal_cl_fp16"); + registerClKernel(rmsnorm_cl_kernel_fp16_new, "rmsnorm_cl_fp16"); + #endif blas_kernels_initialized = true; } diff --git a/nntrainer/layers/cl_layers/fused_fc_norm_cl.cpp b/nntrainer/layers/cl_layers/fused_fc_norm_cl.cpp new file mode 100644 index 0000000000..26e9290566 --- /dev/null +++ b/nntrainer/layers/cl_layers/fused_fc_norm_cl.cpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Debadri Samaddar + * + * @file fused_fc_norm_cl.cpp + * @date 7 May 2024 + * @brief This is Fully Connected Layer Class for Neural Network with OpenCl + * implementation + * @see https://github.com/nnstreamer/nntrainer + * @author Debadri Samaddar + * @bug No known bugs except for NYI items + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nntrainer { + +static constexpr size_t SINGLE_INOUT_IDX = 0; + +enum FC_RMSParams { weight, bias, gamma }; +// enum FCParams { weight, bias }; +// enum RMSParams { gamma }; + +// , fc_rms_props(props::Unit(), props::FUSED_FC_RMS_NORM_GAMMA_INIT_GPU(), +// props::Epsilon()) + +FullyConnectedRMSNormLayerCl::FullyConnectedRMSNormLayerCl() : + LayerImpl(), + fc_rms_props(props::Unit(), props::FUSED_FC_RMS_NORM_GAMMA_INIT_GPU(), + props::Epsilon()) { + weight_idx.fill(std::numeric_limits::max()); +} + +void FullyConnectedRMSNormLayerCl::finalize(InitLayerContext &context) { + auto &weight_regularizer = + std::get(*layer_impl_props); + auto &weight_regularizer_constant = + std::get(*layer_impl_props); + auto &weight_initializer = + std::get(*layer_impl_props); + auto &weight_decay = std::get(*layer_impl_props); + auto &bias_decay = std::get(*layer_impl_props); + auto &bias_initializer = std::get(*layer_impl_props); + auto &disable_bias = std::get(*layer_impl_props); + + auto unit = std::get(fc_rms_props).get(); + + NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) + << "Fully connected layer takes only one input"; + + std::vector output_dims(1); + + /// @todo fc actaully supports multidimensions. EffDimFlag shouldn't be fixed + /// like this. 
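  /// Note (assumption, mirroring the plain fully_connected layer): the masks
  /// below appear to mark batch and width (0b1001 over b,c,h,w) as the
  /// effective input dimensions and batch (0b1000) as the dynamic one, so only
  /// those axes participate in the dot product.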
+ context.setEffDimFlagInputDimension(0, 0b1001); + context.setDynDimFlagInputDimension(0, 0b1000); + + bool is_nchw = (context.getFormat() == Tformat::NCHW); + /** set output dimensions */ + auto const &in_dim = context.getInputDimensions()[0]; + output_dims[0] = in_dim; + is_nchw ? output_dims[0].width(unit) : output_dims[0].channel(unit); + + output_dims[0].setTensorType( + {context.getFormat(), context.getActivationDataType()}); + + context.setOutputDimensions(output_dims); + + /** set weight specifications */ + // @todo : This NCHW format setting is just temporal, it needs to be set by + // global configuration + TensorDim bias_dim( + 1, is_nchw ? 1 : unit, 1, is_nchw ? unit : 1, + TensorDim::TensorType(context.getFormat(), context.getWeightDataType()), + is_nchw ? 0b0001 : 0b0100); + + TensorDim weight_dim( + 1, is_nchw ? 1 : unit, is_nchw ? in_dim.width() : 1, + is_nchw ? unit : in_dim.channel(), + TensorDim::TensorType(context.getFormat(), context.getWeightDataType()), + is_nchw ? 0b0011 : 0b0101); + + weight_idx[FC_RMSParams::weight] = context.requestWeight( + weight_dim, weight_initializer, weight_regularizer, + weight_regularizer_constant, weight_decay, "weight", true); + + if (disable_bias.empty() || disable_bias.get() == false) { + weight_idx[FC_RMSParams::bias] = + context.requestWeight(bias_dim, bias_initializer, WeightRegularizer::NONE, + 1.0f, bias_decay, "bias", true); + } + + // for RMS layer, size of output already set for fc, line 70 + auto &rmsparams_gamma = + std::get(fc_rms_props); + + TensorDim gamma_dim( + 1, 1, 1, output_dims[0].width(), + TensorDim::TensorType(context.getFormat(), context.getWeightDataType())); + weight_idx[FC_RMSParams::gamma] = + context.requestWeight(gamma_dim, rmsparams_gamma, WeightRegularizer::NONE, + 1.0f, 0.0f, "gamma", false); +} + +// TO-DO +///////////////////////////////////////////////////////////////////////// +// fc +void FullyConnectedRMSNormLayerCl::exportTo( + Exporter &exporter, const ml::train::ExportMethods &method) const { + LayerImpl::exportTo(exporter, method); + exporter.saveResult(fc_rms_props, method, this); +} + +void FullyConnectedRMSNormLayerCl::setProperty( + const std::vector &values) { + auto remain_props = loadProperties(values, fc_rms_props); + LayerImpl::setProperty(remain_props); +} + +void FullyConnectedRMSNormLayerCl::forwarding(RunLayerContext &context, + bool training) { + + // for fc layer + Tensor &weight = context.getWeight(weight_idx[FC_RMSParams::weight]); + Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); + + // for rms + Tensor &gamma = context.getWeight(weight_idx[FC_RMSParams::gamma]); + auto &epsilon = std::get(fc_rms_props).get(); + + auto disable_bias = std::get(*layer_impl_props); + bool disable_bias_value = disable_bias.empty() || disable_bias.get() == false; + const Tensor &bias = context.getWeight(weight_idx[FC_RMSParams::bias]); + // printf("\n*************************************************************************************************************************************\n"); + // printf("Bias value : %s\n", disable_bias_value ? 
"true" : "false"); + // printf("\nInput Tensor Batch: %u, Channel: %u, Height: %u, Width: %u\n", + // input_.batch(), input_.channel(), input_.height(), input_.width()); for + // (unsigned int i = 0; i < input_.size(); ++i) { + // printf("Element %u -> %f\n", i, *(input_.getData() + i)); + // } + // printf("\nWeight Tensor Batch: %u, Channel: %u, Height: %u, Width: %u\n", + // weight.batch(), weight.channel(), weight.height(), weight.width()); + // printf("\nHidden Tensor Batch: %u, Channel: %u, Height: %u, Width: %u\n", + // hidden_.batch(), hidden_.channel(), hidden_.height(), hidden_.width()); + // printf("\nGamma Tensor Batch: %u, Channel: %u, Height: %u, Width: %u\n", + // gamma.batch(), gamma.channel(), gamma.height(), gamma.width()); + // printf("\nEpsilon value : %f\n", epsilon); + // printf("\n-----------------------------------------starting with fusion + // process from layer side-----------------------------------------------\n"); + + fusedProcess(input_, weight, hidden_, bias, disable_bias_value, gamma, + epsilon); +} + +// TO-DO +////// need to implement the incremental forwarding +void FullyConnectedRMSNormLayerCl::incremental_forwarding( + RunLayerContext &context, unsigned int from, unsigned int to, bool training) { + Tensor w; + Tensor &weight = w; + context.getWeight(weight, weight_idx[FC_RMSParams::weight]); + + Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); + Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + + // rms + Tensor &gamma = context.getWeight(weight_idx[FC_RMSParams::gamma]); + + TensorDim input_dim = input_.getDim(); + TensorDim hidden_dim = hidden_.getDim(); + + TensorDim input_step_dim = input_dim; + TensorDim hidden_step_dim = hidden_dim; + + if (from) { + NNTR_THROW_IF(to - from != 1, std::invalid_argument) + << "incremental step size is not 1"; + from = 0; + to = 1; + } + + input_step_dim.height(to - from); + hidden_step_dim.height(to - from); + + // @todo: set reset stride as false. 
This implementation only works when + // batch size is 1 + Tensor input_step = input_.getSharedDataTensor(input_step_dim, 0, true); + Tensor hidden_step = hidden_.getSharedDataTensor(hidden_step_dim, 0, true); + + auto &epsilon = std::get(fc_rms_props).get(); + + auto disable_bias = std::get(*layer_impl_props); + bool disable_bias_value = disable_bias.empty() || disable_bias.get() == false; + Tensor &bias = context.getWeight(weight_idx[FC_RMSParams::bias]); + + fusedProcess(input_step, weight, hidden_step, bias, disable_bias_value, gamma, + epsilon); +} + +void FullyConnectedRMSNormLayerCl::calcDerivative(RunLayerContext &context) { + Tensor &weight = context.getWeight(weight_idx[FC_RMSParams::weight]); + + const Tensor &derivative_ = context.getIncomingDerivative(SINGLE_INOUT_IDX); + Tensor &ret_ = context.getOutgoingDerivative(SINGLE_INOUT_IDX); + + ret_.dot_deriv_wrt_1(weight, derivative_, false, false); +} + +void FullyConnectedRMSNormLayerCl::calcGradient(RunLayerContext &context) { + Tensor &djdw = context.getWeightGrad(weight_idx[FC_RMSParams::weight]); + + const Tensor &derivative_ = context.getIncomingDerivative(SINGLE_INOUT_IDX); + Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); + + if (auto &disable_bias = std::get(*layer_impl_props); + disable_bias.empty() || disable_bias.get() == false) { + Tensor &djdb = context.getWeightGrad(weight_idx[FC_RMSParams::bias]); + + if (context.isGradientFirstAccess(weight_idx[FC_RMSParams::bias])) { + derivative_.sum({0, 1, 2}, djdb); + } else { + /// @todo optimize below by adding beta to Tensor::sum + Tensor t = derivative_.sum({0, 1, 2}); + djdb.add_i(t); + } + } + + input_.dot_deriv_wrt_2( + djdw, derivative_, false, false, + !context.isGradientFirstAccess(weight_idx[FC_RMSParams::weight])); +} + +} /* namespace nntrainer */ diff --git a/nntrainer/layers/cl_layers/fused_fc_norm_cl.h b/nntrainer/layers/cl_layers/fused_fc_norm_cl.h new file mode 100644 index 0000000000..24afd1ce5d --- /dev/null +++ b/nntrainer/layers/cl_layers/fused_fc_norm_cl.h @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Debadri Samaddar + * + * @file fused_fc_norm_cl.h + * @date 7 May 2024 + * @brief This is Fully Connected Layer Class of Neural Network with OpenCl + * implementation + * @see https://github.com/nnstreamer/nntrainer + * @author Debadri Samaddar + * @bug No known bugs except for NYI items + * + */ + +#ifndef __FUSED_FC_RMS_LAYER_CL_H__ +#define __FUSED_FC_RMS_LAYER_CL_H__ +#ifdef __cplusplus + +#include +#include + +namespace nntrainer { + +namespace props { + +/** + * @brief FUSED_FC_RMS_NORM_GAMMA_INIT_GPU Initialization Enumeration + * Information + * + */ +class FUSED_FC_RMS_NORM_GAMMA_INIT_GPU final + : public ::nntrainer::EnumProperty<::nntrainer::props::InitializerInfo> { +public: + /** + * @brief Construct a RMS_NORM_GAMMA_INIT object + */ + FUSED_FC_RMS_NORM_GAMMA_INIT_GPU( + ::nntrainer::Initializer value = ::nntrainer::Initializer::ONES) { + set(value); + }; + using prop_tag = enum_class_prop_tag; + static constexpr const char *key = "gamma_initializer"; +}; +}; // namespace props + +/** + * @class Fused Fully Connected Layer with RMS Normalization Layer Class for + * @brief Fused Fully Connected Layer with RMS Normalization Layer Class for + */ +class FullyConnectedRMSNormLayerCl : public LayerImpl { +public: + /** + * @brief Constructor of Fused Fully Connected && RMS Norm Layer + */ + FullyConnectedRMSNormLayerCl(); + + /** + * @brief Destructor of Fused Fully Connected && RMS Norm Layer + */ + 
~FullyConnectedRMSNormLayerCl() = default; + + /** + * @brief Move constructor. + * @param[in] FullyConnectedRMSNorm && + */ + FullyConnectedRMSNormLayerCl(FullyConnectedRMSNormLayerCl &&rhs) noexcept = + default; + + /** + * @brief Move assignment operator. + * @parma[in] rhs FullyConnectedLayer to be moved. + */ + FullyConnectedRMSNormLayerCl & + operator=(FullyConnectedRMSNormLayerCl &&rhs) = default; + + /** + * @copydoc Layer::finalize(InitLayerContext &context) + */ + void finalize(InitLayerContext &context) override; + + /** + * @copydoc Layer::forwarding(RunLayerContext &context, bool training) + */ + void forwarding(RunLayerContext &context, bool training) override; + + /** + * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned + * int from, unsigned int to, bool training) + */ + void incremental_forwarding(RunLayerContext &context, unsigned int from, + unsigned int to, bool training) override; + + /** + * @copydoc Layer::calcDerivative(RunLayerContext &context) + */ + void calcDerivative(RunLayerContext &context) override; + + /** + * @copydoc Layer::calcGradient(RunLayerContext &context) + */ + void calcGradient(RunLayerContext &context) override; + + /** + * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods + * method) + */ + void exportTo(Exporter &exporter, + const ml::train::ExportMethods &method) const override; + + /** + * @copydoc Layer::getType() + */ + const std::string getType() const override { + return FullyConnectedRMSNormLayerCl::type; + }; + + /** + * @copydoc Layer::supportBackwarding() + */ + bool supportBackwarding() const override { return true; } + + /** + * @copydoc Layer::setProperty(const PropertyType type, const std::string + * &value) + */ + void setProperty(const std::vector &values) override; + + inline static const std::string type = "fully_connected_rmsNorm"; + +private: + std::tuple + fc_rms_props; /**< fc layer properties : unit - number of output neurons */ + std::array + weight_idx; /**< indices of the weights for FC layer */ + // std::array wt_idx; /**< indices of the weights for RMS + // layer */ std::tuple + // rmsnorm_props; /**< rmsnorm layer properties */ +}; +} // namespace nntrainer + +#endif /* __cplusplus */ +#endif /* __FUSED_FC_RMS_LAYER_CL_H__ */ diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build index 57cec3e556..31cd3ffcb1 100644 --- a/nntrainer/layers/cl_layers/meson.build +++ b/nntrainer/layers/cl_layers/meson.build @@ -6,6 +6,7 @@ cl_layer_sources = [ 'rmsnorm_layer_cl.cpp', 'concat_cl.cpp', 'transpose_cl.cpp', + 'fused_fc_norm_cl.cpp', ] foreach s : cl_layer_sources diff --git a/nntrainer/opencl/opencl_loader.cpp b/nntrainer/opencl/opencl_loader.cpp index e3cf3e73c2..9e2c98ae3a 100644 --- a/nntrainer/opencl/opencl_loader.cpp +++ b/nntrainer/opencl/opencl_loader.cpp @@ -90,6 +90,7 @@ void LoadOpenCLFunctions(void *libopencl) { LoadFunction(clRetainCommandQueue); LoadFunction(clReleaseCommandQueue); LoadFunction(clReleaseMemObject); + LoadFunction(clWaitForEvents); } PFN_clGetPlatformIDs clGetPlatformIDs; @@ -117,5 +118,6 @@ PFN_clReleaseContext clReleaseContext; PFN_clRetainCommandQueue clRetainCommandQueue; PFN_clReleaseCommandQueue clReleaseCommandQueue; PFN_clReleaseMemObject clReleaseMemObject; +PFN_clWaitForEvents clWaitForEvents; } // namespace nntrainer::opencl diff --git a/nntrainer/opencl/opencl_loader.h b/nntrainer/opencl/opencl_loader.h index 0aa2a5cfd6..57eeaf05d0 100644 --- a/nntrainer/opencl/opencl_loader.h +++ 
b/nntrainer/opencl/opencl_loader.h @@ -151,6 +151,9 @@ typedef cl_int(CL_API_CALL *PFN_clReleaseCommandQueue)( typedef cl_int(CL_API_CALL *PFN_clReleaseMemObject)(cl_mem /**< memobj */); +typedef cl_int(CL_API_CALL *PFN_clWaitForEvents)( + cl_uint /* num_events */, const cl_event * /* event_list */); + extern PFN_clGetPlatformIDs clGetPlatformIDs; extern PFN_clGetDeviceIDs clGetDeviceIDs; extern PFN_clGetDeviceInfo clGetDeviceInfo; @@ -176,6 +179,7 @@ extern PFN_clReleaseContext clReleaseContext; extern PFN_clRetainCommandQueue clRetainCommandQueue; extern PFN_clReleaseCommandQueue clReleaseCommandQueue; extern PFN_clReleaseMemObject clReleaseMemObject; +extern PFN_clWaitForEvents clWaitForEvents; } // namespace nntrainer::opencl diff --git a/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp index dc9d284c99..385b826766 100644 --- a/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp +++ b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp @@ -119,6 +119,52 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result, bool trans, const float *mdata = m.getData(); float *rdata = result.getData(); + /// shortcut handling in case of vector + /// for vector, (1 * K) == (K * 1) in current memory layout... + /// and plaese note that N, K, M is a fixed place holder after considering + /// transpose. + /// For example, there is no case like (1 * K) X (1 * K) while + /// (1 * K) X (1 * M) can be a case + /// case1: (1 * K) X (K * 1) + if (M == 1 && N == 1) { + *rdata = dot_cl(data, mdata, K) + (*rdata); + } + /// case2: (M * K) X (K * 1) + else if (N == 1) { + if (trans) { + // printf("Inside dotcl 2nd if else n == 1 trans true dotcl\n"); + sgemv_cl(data, mdata, rdata, trans, dim2, dim1, lda); + } else { + // printf("Inside dotcl 2nd if else n == 1 trans false dotcl\n"); + sgemv_cl(data, mdata, rdata, trans, dim1, dim2, lda); + } + // trans ? sgemv_cl(data, mdata, rdata, trans, dim2, dim1, lda) + // : sgemv_cl(data, mdata, rdata, trans, dim1, dim2, lda); + } + /// case3: (1 * K) X (K * N) = 1 * N = R + /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K) + /// Effectively a translation of sgemv + else if (M == 1) { + if (trans_m) { + // printf("Inside dotcl 3rd if else m == 1 trans_m true dotcl\n"); + sgemv_cl(mdata, data, rdata, !trans_m, mdim1, mdim2, ldb); + } else { + // printf("Inside dotcl 3rd if else m == 1 trans_m false dotcl\n"); + sgemv_cl(mdata, data, rdata, !trans_m, mdim2, mdim1, ldb); + } + // trans_m ? sgemv_cl(mdata, data, rdata, !trans_m, mdim1, mdim2, ldb) + // : sgemv_cl(mdata, data, rdata, !trans_m, mdim2, mdim1, ldb); + } + /// case others: use gemm + else { + sgemm_cl(trans, trans_m, data, mdata, rdata, M, N, K, lda, ldb, ldc); + } + } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) { +#ifdef ENABLE_FP16 + const _FP16 *data = input.getData<_FP16>(); + const _FP16 *mdata = m.getData<_FP16>(); + _FP16 *rdata = result.getData<_FP16>(); + /// shortcut handling in case of vector /// for vector, (1 * K) == (K * 1) in current memory layout... /// and plaese note that N, K, M is a fixed place holder after considering @@ -141,12 +187,201 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result, bool trans, trans_m ? 
sgemv_cl(mdata, data, rdata, !trans_m, mdim1, mdim2, ldb) : sgemv_cl(mdata, data, rdata, !trans_m, mdim2, mdim1, ldb); } - /// case others: use gemm + /// case others: use sgemm else { sgemm_cl(trans, trans_m, data, mdata, rdata, M, N, K, lda, ldb, ldc); } +#else + throw std::invalid_argument("Error: enable-fp16 is not enabled"); +#endif + } +} + +// Returning type +Tensor fusedProcess(Tensor const &input, Tensor const &m, Tensor const &bias, + bool disable_bias_value, Tensor const &gamma, + const float epsilon, bool trans, bool trans_m) { + // printf("Inside fusedProcess\n"); + Tensor output("", input.getFormat(), input.getDataType()); + fusedProcess(input, m, output, bias, disable_bias_value, gamma, epsilon, + trans, trans_m); + return output; +} + +// for fusion of FC and RMS +void fusedProcess(Tensor const &input, Tensor const &m, Tensor &result, + Tensor const &bias, bool disable_bias_value, + Tensor const &gamma, const float epsilon, bool trans, + bool trans_m) { + + unsigned int dim1, dim2, mdim1, mdim2; + unsigned int bias_dim1 = bias.size(); + if (input.getFormat() == Tformat::NHWC) { + // printf("Inside void fusedProcess NHWC if\n"); + dim1 = input.batch() * input.height() * input.width(); + dim2 = input.channel(); + mdim1 = m.batch() * m.height() * m.width(); + mdim2 = m.channel(); + } else { + // printf("Inside void fusedProcess NHWC else\n"); + dim1 = input.batch() * input.channel() * input.height(); + dim2 = input.width(); + mdim1 = m.batch() * m.channel() * m.height(); + mdim2 = m.width(); + } + + unsigned int M, N, K, lda, ldb, ldc; + + if (!trans && !trans_m) { + // printf("Both false trans && trans_m fused \n"); + if (dim2 != mdim1) + throw std::runtime_error( + "Error: incompatible dimensions for dot product"); + K = mdim1; /** == dim2 */ + N = mdim2; + M = dim1; + if (input.getFormat() == Tformat::NHWC) { + CREATE_IF_EMPTY_DIMS(result, input.batch(), N, input.height(), + input.width(), + input.getTensorType()); // NHWC Result Tensor + } else { + CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), + input.height(), N, input.getTensorType()); + } + } else if (!trans && trans_m) { + // printf("trans is false and trans_m is true fused\n"); + if (dim2 != mdim2) + throw std::runtime_error( + "Error: incompatible dimensions for dot product"); + K = mdim2; /** == dim2 */ + N = mdim1; + M = dim1; + if (input.getFormat() == Tformat::NHWC) { + CREATE_IF_EMPTY_DIMS(result, input.batch(), N, input.height(), + input.width(), input.getTensorType()); + } else { + CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), + input.height(), N, input.getTensorType()); + } + } else if (trans && !trans_m) { + // printf("trans is true and trans_m is false fused\n"); + if (dim1 != mdim1) + throw std::runtime_error( + "Error: incompatible dimensions for dot product"); + K = mdim1; /** == dim1 */ + N = mdim2; + M = dim2; + if (input.getFormat() == Tformat::NHWC) { + CREATE_IF_EMPTY_DIMS(result, 1, N, M, 1, input.getTensorType()); + } else { + CREATE_IF_EMPTY_DIMS(result, 1, 1, M, N, input.getTensorType()); + } + } else { + // printf("trans is true and trans_m is true fused\n"); + if (dim1 != mdim2) + throw std::runtime_error( + "Error: incompatible dimensions for dot product"); + K = mdim2; /** == dim1 */ + N = mdim1; + M = dim2; + if (input.getFormat() == Tformat::NHWC) { + CREATE_IF_EMPTY_DIMS(result, 1, N, M, 1, input.getTensorType()); + } else { + CREATE_IF_EMPTY_DIMS(result, 1, 1, M, N, input.getTensorType()); + } + } + + lda = dim2; + ldb = mdim2; + ldc = + (input.getFormat() 
== Tformat::NHWC) ? result.channel() : result.width(); + + bool isAdditionPossible = + (result.getDim() == bias.getDim()) || + (result.getDim() != bias.getDim() && bias.batch() == 1 && + result.channel() == bias.channel() && result.height() == bias.height() && + result.width() == bias.width()); + + // printf("Is Addition Possible : %s\n", isAdditionPossible ? "Yes" : "No"); + + if (input.getDataType() == ml::train::TensorDim::DataType::FP32) { + const float *data = input.getData(); + const float *mdata = m.getData(); + float *rdata = result.getData(); + int res_b = result.batch(); + int res_c = result.channel(); + int res_h = result.height(); + int res_w = result.width(); + + int bias_b = bias.batch(); + int bias_c = bias.channel(); + int bias_h = bias.height(); + int bias_w = bias.width(); + + // printf("Result Tensor Dimensions -> Batch : %d Channel : %d Height : %d + // Width : %d\n", res_b, res_c, res_h, res_w); printf("Bias Tensor + // Dimensions -> Batch : %d Channel : %d Height : %d Width : %d\n", bias_b, + // bias_c, bias_h, bias_w); + + const float *gdata = gamma.getData(); + const float *bdata = bias.getData(); + + /// shortcut handling in case of vector + /// for vector, (1 * K) == (K * 1) in current memory layout... + /// and plaese note that N, K, M is a fixed place holder after considering + /// transpose. + /// For example, there is no case like (1 * K) X (1 * K) while + /// (1 * K) X (1 * M) can be a case + /// case1: (1 * K) X (K * 1) + if (M == 1 && N == 1) { + printf("Inside fused 1st if\n"); + // TO-DO + // fused_dot_cl_rms(data, mdata, rdata, gdata, epsilon, K); + } + /// case2: (M * K) X (K * 1) + else if (N == 1) { + // printf("Inside fused 2nd if else\n"); + if (trans) { + // printf("Inside fused 2nd if else n == 1 trans true fused\n"); + fused_sgemv_cl_rms(data, mdata, rdata, gdata, bdata, isAdditionPossible, + epsilon, !trans_m, disable_bias_value, dim2, dim1, + bias_dim1, lda, res_b, res_c, res_h, res_w); + } else { + // printf("Inside fused 2nd if else n == 1 trans false fused\n"); + fused_sgemv_cl_rms(data, mdata, rdata, gdata, bdata, isAdditionPossible, + epsilon, !trans_m, disable_bias_value, dim1, dim2, + bias_dim1, lda, res_b, res_c, res_h, res_w); + } + } + /// case3: (1 * K) X (K * N) = 1 * N = R + /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K) + /// Effectively a translation of sgemv + else if (M == 1) { + // printf("Inside fused 3rd if else\n"); + if (trans_m) { + // printf("Inside fused 3rd if else m == 1 trans_m true fused\n"); + fused_sgemv_cl_rms(mdata, data, rdata, gdata, bdata, isAdditionPossible, + epsilon, !trans_m, disable_bias_value, mdim1, mdim2, + bias_dim1, ldb, res_b, res_c, res_h, res_w); + } else { + // printf("Inside fused 3rd if else m == 1 trans_m false fused\n"); + fused_sgemv_cl_rms(mdata, data, rdata, gdata, bdata, isAdditionPossible, + epsilon, !trans_m, disable_bias_value, mdim2, mdim1, + bias_dim1, ldb, res_b, res_c, res_h, res_w); + } + } + /// case others: use gemm + else { + // printf("Inside fused 4th else for sgemm\n"); + fused_sgemm_cl_rms(trans, trans_m, data, mdata, rdata, gdata, bdata, + isAdditionPossible, epsilon, disable_bias_value, M, N, + K, lda, ldb, ldc, bias_dim1, res_b, res_c, res_h, + res_w); + // sgemm_cl(trans, trans_m, data, mdata, rdata, M, N, K, lda, ldb, ldc); + } } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) { #ifdef ENABLE_FP16 + // TO-DO for fusedProcess FP16 const _FP16 *data = input.getData<_FP16>(); const _FP16 *mdata = m.getData<_FP16>(); 
_FP16 *rdata = result.getData<_FP16>(); diff --git a/nntrainer/tensor/cl_operations/blas_kernel_interface.h b/nntrainer/tensor/cl_operations/blas_kernel_interface.h index 05f2068025..721fee9b1c 100644 --- a/nntrainer/tensor/cl_operations/blas_kernel_interface.h +++ b/nntrainer/tensor/cl_operations/blas_kernel_interface.h @@ -42,6 +42,40 @@ Tensor dotCl(Tensor const &input, Tensor const &m, bool trans = false, void dotCl(Tensor const &input, Tensor const &m, Tensor &result, bool trans = false, bool trans_m = false); +/** + * @brief fused process data and dimensions for OpenCL dot operation, addition + * and RMS + * @param[in] input Tensor + * @param[in] m Tensor + * @param[in] bias Tensor + * @param[in] disable_bias_value bool + * @param[in] gamma Tensor + * @param[in] epsilon float + * @param[in] trans bool + * @param[in] trans_m bool + */ +Tensor fusedProcess(Tensor const &input, Tensor const &m, Tensor const &bias, + bool disable_bias_value, Tensor const &gamma, + const float epsilon, bool trans = false, + bool trans_m = false); + +/** + * @brief fused process data and dimensions for OpenCL dot operation, addition + * and RMS + * @param[in] input Tensor + * @param[in] m Tensor + * @param[in] result Tensor + * @param[in] bias Tensor + * @param[in] disable_bias_value bool + * @param[in] gamma Tensor + * @param[in] epsilon float + * @param[in] trans bool + * @param[in] trans_m bool + */ +void fusedProcess(Tensor const &input, Tensor const &m, Tensor &result, + Tensor const &bias, bool disable_bias_value, + Tensor const &gamma, const float epsilon, bool trans = false, + bool trans_m = false); /** * @brief Process data and dimensions for OpenCL dot operation * @param[in] input Tensor diff --git a/nntrainer/tensor/cl_operations/blas_kernel_strings.h b/nntrainer/tensor/cl_operations/blas_kernel_strings.h index 57ba7f9e1a..264aab41be 100644 --- a/nntrainer/tensor/cl_operations/blas_kernel_strings.h +++ b/nntrainer/tensor/cl_operations/blas_kernel_strings.h @@ -21,6 +21,7 @@ static const std::string sgemv_cl_kernel_ = R"(__kernel void sgemv_cl(const __global float* A, const __global float* X, __global float* Y, unsigned int N, unsigned int lda) { unsigned int i; + // printf("Inside kernel sgemv_cl_kernel\n"); i = get_global_id(0); float y0 = 0.0f; for (unsigned int j = 0; j < N; j++) @@ -33,6 +34,7 @@ static const std::string sgemv_cl_noTrans_kernel_ = R"(__kernel void sgemv_cl_noTrans(const __global float* A, const __global float* X, __global float* Y, unsigned int N, unsigned int lda) { unsigned int i; + // printf("Inside kernel sgemv_cl_noTrans_kernel\n"); i = get_global_id(0); float y0 = 0.0f; for (unsigned int j = 0; j < N; j++) @@ -52,7 +54,7 @@ static const std::string dot_cl_kernel_ = static const std::string sgemm_cl_noTrans_kernel_ = R"(__kernel void sgemm_cl_noTrans(const __global float* A, const __global float* B, __global float* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) { - + // printf("Inside kernel sgemm_cl_noTrans_kernel\n"); unsigned int m = get_global_id(0); unsigned int n = get_global_id(1); float c = 0.0f; @@ -68,7 +70,7 @@ static const std::string sgemm_cl_noTrans_kernel_ = static const std::string sgemm_cl_transA_kernel_ = R"(__kernel void sgemm_cl_transA(const __global float* A, const __global float* B, __global float* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) { - + // printf("Inside kernel sgemm_cl_TransA_kernel\n"); unsigned int m = get_global_id(0); unsigned int n = get_global_id(1); float c = 0.0f; @@ -86,7 
+88,7 @@ static const std::string sgemm_cl_transB_kernel_ = __global float *C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) { - + // printf("Inside kernel sgemm_cl_TransB_kernel\n"); unsigned int m = get_global_id(0); unsigned int n = get_global_id(1); float c = 0.0f; @@ -104,7 +106,7 @@ static const std::string sgemm_cl_transAB_kernel_ = __global float *C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) { - + // printf("Inside kernel sgemm_cl_TransAB_kernel\n"); unsigned int m = get_global_id(0); unsigned int n = get_global_id(1); float c = 0.0f; @@ -202,7 +204,63 @@ static const std::string transpose_cl_kernel_axis2 = } })"; +static const std::string rmsnorm_cl_kernel_new = + R"(__kernel void rmsnorm_cl(__global const float *input, __global float *output, __global const float *alpha, float epsilon, int B, int C, int H, int W){ + // Compute the corresponding batch, height, and channel indices + int n = get_global_id(0) / C; + int c = get_global_id(0) % C; + int h = get_global_id(1); + int index = ((n * C + c) * H + h) * W; + // Calculate RMS norm for the current channel, height, and batch + float sum_squares = 0.0f; + for (int j = 0; j < W; ++j) { + sum_squares += input[index+j] * input[index+j]; + } + sum_squares /= W; + float rms_norm = sqrt(sum_squares + epsilon); + // Each work item processes all width elements for its specific n, h, c + for (int w = 0; w < W; ++w) { + output[index+w] = (input[index+w] / rms_norm) * alpha[w]; + } +})"; + #ifdef ENABLE_FP16 + +static const std::string rmsnorm_cl_kernel_fp16_new = + R"( + #pragma OPENCL EXTENSION cl_khr_fp16 : enable + __kernel void rmsnorm_cl_fp16( + __global const half *input, // Input tensor + __global half *output, // Output tensor + __global const half *alpha, // Alpha values (one for each width) + half epsilon, + int B, // Number of batches + int C, // Number of channels + int H, // Height of feature map + int W // Width of feature map +) { + int global_id = get_global_id(0); // Get the global work item index + + // Compute the corresponding batch, height, and channel indices + int n = global_id / C; // Batch index + int c = global_id % C; // Height index + int h = get_global_id(1); // Channel index + int index = ((n * C + c) * H + h) * W; + + // Calculate RMS norm for the current channel, height, and batch + half sum_squares = 0.0f; + for (int j = 0; j < W; ++j) { + sum_squares += input[index+j] * input[index+j]; + } + sum_squares /= W; + half rms_norm = sqrt(sum_squares + epsilon); + // Each work item processes all width elements for its specific n, h, c + for (int w = 0; w < W; ++w) { + output[index+w] = (input[index+w] / rms_norm) * alpha[w]; + } +} +)"; + static const std::string sgemv_cl_kernel_fp16_ = R"( #pragma OPENCL EXTENSION cl_khr_fp16 : enable diff --git a/nntrainer/tensor/cl_operations/blas_kernels.cpp b/nntrainer/tensor/cl_operations/blas_kernels.cpp index 558111b5a8..6ee6c4fe34 100644 --- a/nntrainer/tensor/cl_operations/blas_kernels.cpp +++ b/nntrainer/tensor/cl_operations/blas_kernels.cpp @@ -13,9 +13,274 @@ #include #include +#include namespace nntrainer { +void fused_sgemv_cl_rms(const float *matAdata, const float *vecXdata, + float *vecYdata, const float *gdata, const float *bdata, + bool isAdditionPossible, float epsilon, bool TransA, + bool isbias, unsigned int dim1, unsigned int dim2, + unsigned int bias_dim1, unsigned int lda, int b, int c, + int h, int w) { + + bool result = false; + + do { + // printf("Starting with sgemv dotcl\n"); + auto tt1 = 
std::chrono::high_resolution_clock::now(); + ClContext::SharedPtrClKernel kernel_sgemv_ptr; + + if (TransA) { + kernel_sgemv_ptr = + cl_context_ref.registerClKernel(sgemv_cl_kernel_, "sgemv_cl"); + } else { + kernel_sgemv_ptr = cl_context_ref.registerClKernel( + sgemv_cl_noTrans_kernel_, "sgemv_cl_noTrans"); + } + + if (!kernel_sgemv_ptr) { + printf("Failed to register sgemv kernel\n"); + break; + } + + size_t dim1_size = sizeof(float) * dim1; + size_t dim2_size = sizeof(float) * dim2; + opencl::Buffer inputA(cl_context_ref.context_inst_, + dim1 * dim2 * sizeof(float), true, nullptr); + + opencl::Buffer inputX(cl_context_ref.context_inst_, dim2_size, true, + nullptr); + + opencl::Buffer inOutY(cl_context_ref.context_inst_, dim1_size, true, + nullptr); + + result = inputA.WriteData(cl_context_ref.command_queue_inst_, matAdata); + if (!result) { + printf("Failed to write inputA data ind dotcl sgemv\n"); + break; + } + + result = inputX.WriteData(cl_context_ref.command_queue_inst_, vecXdata); + if (!result) { + printf("Failed to write inputX data in dotcl sgemv\n"); + break; + } + + result = inOutY.WriteData(cl_context_ref.command_queue_inst_, vecYdata); + if (!result) { + printf("Failed to write inOutY data in dotcl sgemv\n"); + break; + } + + result = kernel_sgemv_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for inputA in dotcl sgemv\n"); + break; + } + + result = kernel_sgemv_ptr->SetKernelArguments(1, &inputX, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for inputX in dotcl sgemv\n"); + break; + } + + result = kernel_sgemv_ptr->SetKernelArguments(2, &inOutY, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for inOutY in dotcl sgemv\n"); + break; + } + + result = kernel_sgemv_ptr->SetKernelArguments(3, &dim2, sizeof(int)); + if (!result) { + printf("Failed to set argument for dim2 in dotcl sgemv\n"); + break; + } + + result = kernel_sgemv_ptr->SetKernelArguments(4, &lda, sizeof(int)); + if (!result) { + printf("Failed to set argument for lda in dotcl sgemv\n"); + break; + } + + const int work_groups_count[3] = {(int)dim1, 1, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_sgemv_ptr, work_groups_count, work_group_size); + + if (!result) { + printf("Failed to dispatch sgemv kernel\n"); + break; + } + + // printf("Done with sgemv dotcl\n"); + // const cl_event *ptr_sgemv = &sgemv_event; + // result = nntrainer::opencl::clWaitForEvents(1, ptr_sgemv); + // if (!result) { + // throw std::runtime_error("Failed to wait for SGEMV kernel event"); + // } + + if (isbias) { + if (isAdditionPossible) { + // cl_event add_event; + ClContext::SharedPtrClKernel kernel_addition_ptr = + cl_context_ref.registerClKernel(addition_cl_kernel_, "addition_cl"); + + if (!kernel_addition_ptr) { + printf("Failed to register addition kernel\n"); + break; + } + + size_t bias_size = sizeof(float) * bias_dim1; + // size_t dim2_size = sizeof(float) * size_res; // result size -> dim1 + + opencl::Buffer inputC(cl_context_ref.context_inst_, bias_size, true, + nullptr); + + result = inputC.WriteData(cl_context_ref.command_queue_inst_, bdata); + if (!result) { + printf("Failed to write inputC data in addition\n"); + break; + } + + result = + kernel_addition_ptr->SetKernelArguments(0, &inputC, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for inputC in addition\n"); + break; + } + + result = + 
kernel_addition_ptr->SetKernelArguments(1, &inOutY, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for inOutY in addition\n"); + break; + } + + result = + kernel_addition_ptr->SetKernelArguments(2, &bias_dim1, sizeof(int)); + if (!result) { + printf("Failed to set argument for bias_dim1 in addition\n"); + break; + } + + result = kernel_addition_ptr->SetKernelArguments(3, &dim1, sizeof(int)); + if (!result) { + printf("Failed to set argument for dim1 in addition\n"); + break; + } + + const int work_groups_count_add[3] = {(int)bias_dim1, 1, 1}; + const int work_group_size_add[3] = {32, 32, 1}; // test-value + + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_addition_ptr, work_groups_count_add, work_group_size_add); + + if (!result) { + printf("Failed to dispatch addition kernel\n"); + break; + } + } else { + // throw std::invalid_argument( + // "Error: Broadcasting not supported for these dimensions!"); + printf("Broadcasting not supported for these dimensions!\n"); + } + } + + // cl_event rms_event; + ClContext::SharedPtrClKernel kernel_rmsnorm_ptr = + cl_context_ref.registerClKernel(rmsnorm_cl_kernel_new, "rmsnorm_cl"); + + if (!kernel_rmsnorm_ptr) { + printf("Failed to register rmsnorm kernel\n"); + break; + } + + // for this the input is nothing but the result from above kernels, which is + // result only + opencl::Buffer gammabuf(cl_context_ref.context_inst_, w * sizeof(float), + true, nullptr); + + opencl::Buffer resultbuf( + cl_context_ref.context_inst_, dim1_size, true, + nullptr); // to store the data of the dot, add and rms + + result = gammabuf.WriteData(cl_context_ref.command_queue_inst_, gdata); + if (!result) { + printf("Failed to write gamma data in rmsnorm\n"); + break; + } + + result = kernel_rmsnorm_ptr->SetKernelArguments(0, &inOutY, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for inputA in rmsnorm\n"); + break; + } + + result = + kernel_rmsnorm_ptr->SetKernelArguments(1, &resultbuf, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for inOutY in rmsnorm\n"); + break; + } + + result = + kernel_rmsnorm_ptr->SetKernelArguments(2, &gammabuf, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for gammabuf in rmsnorm\n"); + break; + } + + result = kernel_rmsnorm_ptr->SetKernelArguments(4, &b, sizeof(int)); + if (!result) { + printf("Failed to set argument for b in rmsnorm\n"); + break; + } + + result = kernel_rmsnorm_ptr->SetKernelArguments(3, &epsilon, sizeof(float)); + if (!result) { + printf("Failed to set argument for epsilon in rmsnorm\n"); + break; + } + + result = kernel_rmsnorm_ptr->SetKernelArguments(5, &c, sizeof(int)); + if (!result) { + printf("Failed to set argument for c in rmsnorm\n"); + break; + } + + result = kernel_rmsnorm_ptr->SetKernelArguments(6, &h, sizeof(int)); + if (!result) { + printf("Failed to set argument for h in rmsnorm\n"); + break; + } + result = kernel_rmsnorm_ptr->SetKernelArguments(7, &w, sizeof(int)); + if (!result) { + printf("Failed to set argument for w in rmsnorm\n"); + break; + } + + const int work_groups_count_rms[3] = {b * c, h, 1}; + const int work_group_size_rms[3] = {32, 32, 1}; // test-value + + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_rmsnorm_ptr, work_groups_count_rms, work_group_size_rms); + + if (!result) { + printf("Failed to dispatch rmsnorm kernel\n"); + break; + } + + // printf("Getting the output finally after dot, add, && rms sgemv!!\n"); + result = 
resultbuf.ReadData(cl_context_ref.command_queue_inst_, vecYdata); + if (!result) { + printf("Failed to read result data in the end\n"); + break; + } + } while (false); +} + void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata, bool TransA, unsigned int dim1, unsigned int dim2, unsigned int lda) { @@ -105,6 +370,79 @@ void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata, } while (false); } +// void fused_dot_cl_rms(const float *vecAdata, const float *vecXdata, float +// *rdata, const float *gdata, float epsilon, unsigned int dim1){ +// bool result = false; + +// float cl_ret = 0; +// do { +// ClContext::SharedPtrClKernel kernel_dot_ptr = +// cl_context_ref.registerClKernel(dot_cl_kernel_, "dot_cl"); +// if (!kernel_dot_ptr) { +// break; +// } + +// size_t dim1_size = sizeof(float) * dim1; + +// opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true, +// nullptr); + +// opencl::Buffer inputX(cl_context_ref.context_inst_, dim1_size, true, +// nullptr); + +// opencl::Buffer dotResult(cl_context_ref.context_inst_, sizeof(float), +// true, +// &cl_ret); + +// result = inputA.WriteData(cl_context_ref.command_queue_inst_, vecAdata); +// if (!result) { +// break; +// } + +// result = inputX.WriteData(cl_context_ref.command_queue_inst_, vecXdata); +// if (!result) { +// break; +// } + +// result = kernel_dot_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem)); +// if (!result) { +// break; +// } + +// result = kernel_dot_ptr->SetKernelArguments(1, &inputX, sizeof(cl_mem)); +// if (!result) { +// break; +// } + +// result = kernel_dot_ptr->SetKernelArguments(2, &dim1, sizeof(int)); +// if (!result) { +// break; +// } + +// result = kernel_dot_ptr->SetKernelArguments(3, &dotResult, +// sizeof(cl_mem)); if (!result) { +// break; +// } + +// const int work_groups_count[3] = {(int)dim1, 1, 1}; +// const int work_group_size[3] = {32, 32, 1}; // test-value + +// result = cl_context_ref.command_queue_inst_.DispatchCommand( +// kernel_dot_ptr, work_groups_count, work_group_size); +// if (!result) { +// break; +// } + +// *rdata += cl_ret; + +// result = dotResult.ReadData(cl_context_ref.command_queue_inst_, &cl_ret); +// if (!result) { +// break; +// } + +// } while (false); +// } + float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1) { bool result = false; @@ -178,6 +516,300 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1) { return cl_ret; } +void printMatrix(float *matrix, unsigned int rows, unsigned int cols) { + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + printf("%f ", matrix[i * cols + j]); + } + printf("\n"); + } +} + +void fused_sgemm_cl_rms(bool TransA, bool TransB, const float *A, + const float *B, float *C, const float *gdata, + const float *bdata, bool isAdditionPossible, + float epsilon, bool isbias, unsigned int M, + unsigned int N, unsigned int K, unsigned int lda, + unsigned int ldb, unsigned int ldc, + unsigned int bias_dim1, int b, int c, int h, int w) { + + bool result = false; + + do { + std::string kernel_func_; + std::string sgemm_cl_kernel_; + + if (!TransA && !TransB) { + kernel_func_ = "sgemm_cl_noTrans"; + sgemm_cl_kernel_ = sgemm_cl_noTrans_kernel_; + } else if (TransA && !TransB) { + kernel_func_ = "sgemm_cl_transA"; + sgemm_cl_kernel_ = sgemm_cl_transA_kernel_; + } else if (!TransA && TransB) { + kernel_func_ = "sgemm_cl_transB"; + sgemm_cl_kernel_ = sgemm_cl_transB_kernel_; + } else { + kernel_func_ = "sgemm_cl_transAB"; + 
sgemm_cl_kernel_ = sgemm_cl_transAB_kernel_; + } + + ClContext::SharedPtrClKernel kernel_sgemm_ptr = + cl_context_ref.registerClKernel(sgemm_cl_kernel_, kernel_func_); + if (!kernel_sgemm_ptr) { + printf("Failed to register sgemm kernel\n"); + break; + } + + // sizes will be same for transpose + size_t m_k_size = M * K * sizeof(float); + size_t k_n_size = K * N * sizeof(float); + size_t m_n_size = M * N * sizeof(float); + unsigned int dim1 = M * N; // result size + + opencl::Buffer inputA(cl_context_ref.context_inst_, m_k_size, true, + nullptr); + + opencl::Buffer inputB(cl_context_ref.context_inst_, k_n_size, true, + nullptr); + + opencl::Buffer inOutC(cl_context_ref.context_inst_, m_n_size, true, + nullptr); + + result = inputA.WriteData(cl_context_ref.command_queue_inst_, A); + if (!result) { + printf("Failed to write inputA data ind sgemm\n"); + break; + } + + result = inputB.WriteData(cl_context_ref.command_queue_inst_, B); + if (!result) { + printf("Failed to write inputB data in sgemm\n"); + break; + } + + result = inOutC.WriteData(cl_context_ref.command_queue_inst_, C); + if (!result) { + printf("Failed to write inOutY data in sgemm\n"); + break; + } + + result = kernel_sgemm_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument inputA in sgemm\n"); + break; + } + + result = kernel_sgemm_ptr->SetKernelArguments(1, &inputB, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument inputB in sgemm\n"); + break; + } + + result = kernel_sgemm_ptr->SetKernelArguments(2, &inOutC, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument inOutY in sgemm\n"); + break; + } + + result = kernel_sgemm_ptr->SetKernelArguments(3, &K, sizeof(int)); + if (!result) { + printf("Failed to set argument K in sgemm\n"); + break; + } + + result = kernel_sgemm_ptr->SetKernelArguments(4, &lda, sizeof(int)); + if (!result) { + printf("Failed to set argument lda in sgemm\n"); + break; + } + + result = kernel_sgemm_ptr->SetKernelArguments(5, &ldb, sizeof(int)); + if (!result) { + printf("Failed to set argument ldb in sgemm\n"); + break; + } + + result = kernel_sgemm_ptr->SetKernelArguments(6, &ldc, sizeof(int)); + if (!result) { + printf("Failed to set argument ldc in sgemm\n"); + break; + } + + const int work_groups_count[3] = {(int)M, (int)N, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_sgemm_ptr, work_groups_count, work_group_size); + + if (!result) { + printf("Failed to dispatch sgemm kernel\n"); + break; + } + + result = inOutC.ReadData(cl_context_ref.command_queue_inst_, C); + if (!result) { + printf("Failed to read result data after dptCL\n"); + break; + } + // printMatrix(C, M, N); + // printf("Done with sgemm dotcl\n"); + + if (isbias) { + if (isAdditionPossible) { + // cl_event add_event; + ClContext::SharedPtrClKernel kernel_addition_ptr = + cl_context_ref.registerClKernel(addition_cl_kernel_, "addition_cl"); + + if (!kernel_addition_ptr) { + printf("Failed to register addition kernel\n"); + break; + } + + size_t bias_size = sizeof(float) * bias_dim1; + // size_t dim2_size = sizeof(float) * size_res; // result size -> dim1 + + opencl::Buffer inputC(cl_context_ref.context_inst_, bias_size, true, + nullptr); + + result = inputC.WriteData(cl_context_ref.command_queue_inst_, bdata); + if (!result) { + printf("Failed to write inputC data in addition\n"); + break; + } + + result = + kernel_addition_ptr->SetKernelArguments(0, &inputC, 
sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for inputC in addition\n"); + break; + } + + result = + kernel_addition_ptr->SetKernelArguments(1, &inOutC, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for inOutY in addition\n"); + break; + } + + result = + kernel_addition_ptr->SetKernelArguments(2, &bias_dim1, sizeof(int)); + if (!result) { + printf("Failed to set argument for bias_dim1 in addition\n"); + break; + } + + result = kernel_addition_ptr->SetKernelArguments(3, &dim1, sizeof(int)); + if (!result) { + printf("Failed to set argument for dim1 in addition\n"); + break; + } + + const int work_groups_count_add[3] = {(int)bias_dim1, 1, 1}; + const int work_group_size_add[3] = {32, 32, 1}; // test-value + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_addition_ptr, work_groups_count_add, work_group_size_add); + + if (!result) { + printf("Failed to dispatch addition kernel\n"); + break; + } + } else { + // throw std::invalid_argument( + // "Error: Broadcasting not supported for these dimensions!"); + printf("Broadcasting not supported for these dimensions!\n"); + } + } + + // cl_event rms_event; + ClContext::SharedPtrClKernel kernel_rmsnorm_ptr = + cl_context_ref.registerClKernel(rmsnorm_cl_kernel_new, "rmsnorm_cl"); + + if (!kernel_rmsnorm_ptr) { + printf("Failed to register rmsnorm kernel\n"); + break; + } + + // for this the input is nothing but the result from above kernels, which is + // result only + opencl::Buffer gammabuf(cl_context_ref.context_inst_, w * sizeof(float), + true, nullptr); + + opencl::Buffer resultbuf( + cl_context_ref.context_inst_, m_n_size, true, + nullptr); // to store the data of the dot, add and rms + + result = gammabuf.WriteData(cl_context_ref.command_queue_inst_, gdata); + if (!result) { + printf("Failed to write gamma data in rmsnorm\n"); + break; + } + result = kernel_rmsnorm_ptr->SetKernelArguments(0, &inOutC, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for inputA in rmsnorm\n"); + break; + } + + result = + kernel_rmsnorm_ptr->SetKernelArguments(1, &resultbuf, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for inOutY in rmsnorm\n"); + break; + } + + result = + kernel_rmsnorm_ptr->SetKernelArguments(2, &gammabuf, sizeof(cl_mem)); + if (!result) { + printf("Failed to set argument for gammabuf in rmsnorm\n"); + break; + } + + result = kernel_rmsnorm_ptr->SetKernelArguments(4, &b, sizeof(int)); + if (!result) { + printf("Failed to set argument for b in rmsnorm\n"); + break; + } + + result = kernel_rmsnorm_ptr->SetKernelArguments(3, &epsilon, sizeof(float)); + if (!result) { + printf("Failed to set argument for epsilon in rmsnorm\n"); + break; + } + + result = kernel_rmsnorm_ptr->SetKernelArguments(5, &c, sizeof(int)); + if (!result) { + printf("Failed to set argument for c in rmsnorm\n"); + break; + } + + result = kernel_rmsnorm_ptr->SetKernelArguments(6, &h, sizeof(int)); + if (!result) { + printf("Failed to set argument for h in rmsnorm\n"); + break; + } + result = kernel_rmsnorm_ptr->SetKernelArguments(7, &w, sizeof(int)); + if (!result) { + printf("Failed to set argument for w in rmsnorm\n"); + break; + } + const int work_groups_count_rms[3] = {b * c, h, 1}; + const int work_group_size_rms[3] = {32, 32, 1}; // test-value + + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_rmsnorm_ptr, work_groups_count_rms, work_group_size_rms); + if (!result) { + printf("Failed to dispatch rmsnorm kernel\n"); + break; + } + + result = 
resultbuf.ReadData(cl_context_ref.command_queue_inst_, C); + if (!result) { + printf("Failed to read result data in the end\n"); + break; + } + // printMatrix(C, M, N); + } while (false); +} + void sgemm_cl(bool TransA, bool TransB, const float *A, const float *B, float *C, unsigned int M, unsigned int N, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) { diff --git a/nntrainer/tensor/cl_operations/blas_kernels.h b/nntrainer/tensor/cl_operations/blas_kernels.h index 7a62148888..0f8d29b375 100644 --- a/nntrainer/tensor/cl_operations/blas_kernels.h +++ b/nntrainer/tensor/cl_operations/blas_kernels.h @@ -24,6 +24,66 @@ namespace nntrainer { // get global cl_context to use in kernels static ClContext cl_context_ref; +/** + * @brief fused_sgemv computation : Y = A*X + Y + * @param[in] matAdata float * for Matrix A + * @param[in] vecXdata float * for Vector X + * @param[in] vecYdata float * for Vector Y + * @param[in] gdata float * for vector gamma + * @param[in] bdata float * for vector bias for addition + * @param[in] isAdditionPossible bool if addition possible + * @param[in] epsilon float epsilon value for RMSprop + * @param[in] TransA bool transpose + * @param[in] isbias bool if bias is present + * @param[in] dim1 number of A's columns + * @param[in] dim2 number of A's rows + * @param[in] bias_dim1 bias dimensions + * @param[in] lda number of X's columns + * @param[in] b batch of result + * @param[in] c channel of result + * @param[in] h height of result + * @param[in] w width of result + */ +void fused_sgemv_cl_rms(const float *matAdata, const float *vecXdata, + float *vecYdata, const float *gdata, const float *bdata, + bool isAdditionPossible, float epsilon, bool TransA, + bool isbias, unsigned int dim1, unsigned int dim2, + unsigned int bias_dim1, unsigned int lda, int b, int c, + int h, int w); + +/** + * @brief sgemm computation : Y = op(A)*op(B) + C, + * where op(X) is one of X or X**T + * @param[in] TransA bool transpose + * @param[in] TransB bool transpose + * @param[in] A float * for Matrix A + * @param[in] B float * for Matrix B + * @param[in] C float * for Matrix C + * @param[in] gdata float * for vector gamma + * @param[in] bdata float * for vector bias for addition + * @param[in] isAdditionPossible bool if addition possible + * @param[in] epsilon float epsilon value for RMSprop + * @param[in] isbias bool if bias is present + * @param[in] M number of op(A)'s and C's row + * @param[in] N number of op(B)'s and C's columns + * @param[in] K number of op(A)'s and columns and op(B)'s rows + * @param[in] lda number of A's columns + * @param[in] ldb number of B's columns + * @param[in] ldc number of C's columns + * @param[in] bias_dim1 bias dimensions + * @param[in] b batch of result + * @param[in] c channel of result + * @param[in] h height of result + * @param[in] w width of result + * @param[in] context RunLayerContext reference + */ +void fused_sgemm_cl_rms(bool TransA, bool TransB, const float *A, + const float *B, float *C, const float *gdata, + const float *bdata, bool isAdditionPossible, + float epsilon, bool isbias, unsigned int M, + unsigned int N, unsigned int K, unsigned int lda, + unsigned int ldb, unsigned int ldc, + unsigned int bias_dim1, int b, int c, int h, int w); /** * @brief sgemv computation : Y = A*X + Y * @param[in] matAdata float * for Matrix A diff --git a/test/include/nntrainer_test_util.h b/test/include/nntrainer_test_util.h index ecf266d817..d0ebfa0a84 100644 --- a/test/include/nntrainer_test_util.h +++ 
b/test/include/nntrainer_test_util.h @@ -191,6 +191,45 @@ class ScopedIni { } \ } while (0) +#define GEN_TEST_INPUT_RES(input, val) \ + do { \ + for (int i = 0; i < batch_res; ++i) { \ + for (int j = 0; j < channel_res; ++j) { \ + for (int k = 0; k < height_res; ++k) { \ + for (int l = 0; l < width_res; ++l) { \ + input.setValue(i, j, k, l, val); \ + } \ + } \ + } \ + } \ + } while (0) + +#define GEN_TEST_BIAS(input, val) \ + do { \ + for (int i = 0; i < batch_res; ++i) { \ + for (int j = 0; j < channel_res; ++j) { \ + for (int k = 0; k < height_res; ++k) { \ + for (int l = 0; l < width_res; ++l) { \ + input.setValue(i, j, k, l, val); \ + } \ + } \ + } \ + } \ + } while (0) + +#define GEN_TEST_INPUT_GAMMA(input, val) \ + do { \ + for (int i = 0; i < 1; ++i) { \ + for (int j = 0; j < 1; ++j) { \ + for (int k = 0; k < 1; ++k) { \ + for (int l = 0; l < width_res; ++l) { \ + input.setValue(i, j, k, l, val); \ + } \ + } \ + } \ + } \ + } while (0) + /** * @brief return a tensor filled with contant value with dimension */ diff --git a/test/input_gen/gen_layer_tests.py b/test/input_gen/gen_layer_tests.py index 180402fc62..aa92b50b31 100644 --- a/test/input_gen/gen_layer_tests.py +++ b/test/input_gen/gen_layer_tests.py @@ -21,7 +21,7 @@ @author Niket Agarwal @author Thummala Pallavi """ - +import numpy as np import warnings from recorder import ( record_single, @@ -978,3 +978,89 @@ def transpose_axis2(tensor, batch_size, input_channel, input_height, input_width transpose_layer_axis2 = tf.keras.layers.Lambda(lambda x: transpose_axis2(x, 2, 3, 3, 3)) record_single(transpose_layer_axis2, (2, 3, 3, 3), "transpose_axis2", input_type="float") record_single_fp16(transpose_layer_axis2, (2, 3, 3, 3), "transpose_fp16_axis2", input_type="float") + + class FusedLayer_Ind(tf.keras.layers.Layer): + def __init__(self, units, epsilon=1e-3, **kwargs): + super(FusedLayer_Ind, self).__init__(**kwargs) + self.units = units + self.epsilon = epsilon + # self.rms = RMSNorm() + + def build(self, input_shape): + self.fc_weights = self.add_weight( + shape=(input_shape[-1], self.units), + initializer='ones', + trainable=False, + name='fc_weights' + ) + self.fc_biases = self.add_weight( + shape=(self.units, ), + initializer='zeros', + trainable=False, + name='fc_biases' + ) + self.rms_gamma = self.add_weight( + shape=(self.units,), + initializer='ones', + trainable=False, + name='rms_gamma' + ) + super(FusedLayer_Ind, self).build(input_shape) + + def call(self, inputs): + print("Inputs: ", inputs) + print("\nWeights: ", np.array(self.fc_weights)) + fc_output = tf.matmul(inputs, self.fc_weights) + print("\nFC Output: ", fc_output) + print("\nBias: ", np.array(self.fc_biases)) + addition_output = fc_output + self.fc_biases + print("\nAddition output: ", addition_output) + + # **** Implement RMS normalization + mean_square = tf.reduce_mean(tf.square(addition_output), axis=[-1], keepdims=True) + # Compute the RMS value with epsilon for numerical stability + rms_value = tf.sqrt(mean_square + self.epsilon) + # Normalize inputs and scale by gamma + normalized_output = addition_output / rms_value * self.rms_gamma + print("\nFinal normalized output: ", normalized_output) + return normalized_output + # normalized_output = self.rms(addition_output) + # print("\nFinal normalized output: ", normalized_output) + # return normalized_output + + class FusedLayer(tf.keras.layers.Layer): + def __init__(self, units, epsilon=1e-3, **kwargs): + super(FusedLayer, self).__init__(**kwargs) + self.units = units + self.epsilon = epsilon + self.fc = 
K.layers.Dense(units, kernel_initializer='ones', bias_initializer=self.custom_bias_initializer) + # , activation=None, use_bias=True, kernel_initializer='ones', bias_initializer='zeros' + self.rms = RMSNorm(epsilon=self.epsilon) + + def custom_bias_initializer(self, shape, dtype=None): + # Create a bias vector with the same shape as the weight matrix + return tf.constant(np.random.uniform(0.00, 0.00, size=(1, 1, 2,2)), dtype=dtype) + + def call(self, inputs): + print("\nInitial input: ", inputs) + x = self.fc(inputs) + print("\nFC weights shape: ", self.fc.kernel.shape) + print("\nFC weights: ", self.fc.kernel.numpy()) + print("\nFC bias shape: ", self.fc.bias.shape) + print("\nFC bias: ", self.fc.bias.numpy()) + print("\nFC output: ", x) + x = self.rms(x) + print("\nRMS output: ", x) + return x + + # fused_layer_Ind = FusedLayer_Ind(units=1) + # fused_layer_single_plain = FusedLayer(units=2) + + fused_layer = FusedLayer_Ind(units=3) + record_single(fused_layer, (1, 1, 2, 3), "fused_plain") + inspect_file("fused_plain.nnlayergolden") + + # fused_layer_single_plain = FusedLayer(units=2) + # record_single(fused_layer_Ind, (1, 1, 2, 3), "fused_plain_Ind") + # record_single_fp16(fused_layer_single_plain, (1, 1, 1, 3), "fused_single_batch1") + diff --git a/test/jni/Android.mk b/test/jni/Android.mk index afc9a64780..ed2bab3b43 100644 --- a/test/jni/Android.mk +++ b/test/jni/Android.mk @@ -442,6 +442,7 @@ LOCAL_SRC_FILES := \ ../unittest/layers/unittest_layer_node.cpp \ ../unittest/layers/unittest_layers.cpp \ ../unittest/layers/unittest_layers_impl.cpp \ + ../unittest/layers/unittest_layers_fused_fc_rms_cl.cpp \ ../unittest/layers/unittest_layers_transpose_cl.cpp \ ../unittest/layers/unittest_layers_concat_cl.cpp \ ../unittest/layers/unittest_layers_swiglu_cl.cpp \ diff --git a/test/nntrainer_test_util.cpp b/test/nntrainer_test_util.cpp index 7ff307558d..ec646ff829 100644 --- a/test/nntrainer_test_util.cpp +++ b/test/nntrainer_test_util.cpp @@ -299,7 +299,8 @@ void sizeCheckedReadTensor(nntrainer::Tensor &t, std::ifstream &file, throw std::invalid_argument("Error: enable-fp16 is not enabled"); #endif } - + std::cout << "Inside sizeCheckedReadTensor func && string: " << error_msg + << std::endl; NNTR_THROW_IF(t.getDim().getDataLen() != sz, std::invalid_argument) << "[ReadFail] dimension does not match at " << error_msg << " sz: " << sz << " dimsize: " << t.getDim().getDataLen() << '\n'; diff --git a/test/unittest/layers/layers_golden_tests.cpp b/test/unittest/layers/layers_golden_tests.cpp index accf158b28..a19b3e40e0 100644 --- a/test/unittest/layers/layers_golden_tests.cpp +++ b/test/unittest/layers/layers_golden_tests.cpp @@ -134,6 +134,7 @@ static TensorPacks prepareTensors(const InitLayerContext &context, str_converter:: from_string(tensor_type[1])); weights_.emplace_back(spec_, true); + // std::cout<<"allocate weights if"< + * + * @file unittest_layers_fused_fc_rms_cl.cpp + * @date 7 June 2024 + * @brief Fused Fully Connected + RMS Norm Layer Test + * @see https://github.com/nnstreamer/nntrainer + * @author Debadri Samaddar + * @bug No known bugs except for NYI items + */ +#include + +#include + +#include +#include +#include + +auto semantic_fused_gpu = LayerSemanticsParamType( + nntrainer::createLayer, + nntrainer::FullyConnectedRMSNormLayerCl::type, {"unit=1"}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1); + +GTEST_PARAMETER_TEST(FusedLayerGPU, LayerSemanticsGpu, + ::testing::Values(semantic_fused_gpu)); + +auto fused_gpu_plain = LayerGoldenTestParamType( + 
nntrainer::createLayer, {"unit=3"}, + "1:1:2:3", "fused_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, + "nchw", "fp32", "fp32"); + +// auto fused_gpu_plain_Ind = LayerGoldenTestParamType( +// nntrainer::createLayer, {"unit=1", +// "epsilon=0.001"}, "1:1:2:3", "fused_plain_Ind.nnlayergolden", +// LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32"); + +// auto fused_gpu_single_batch = LayerGoldenTestParamType( +// nntrainer::createLayer, {"unit=2", +// "epsilon=0.001"}, "1:1:1:3", "fused_single_batch1.nnlayergolden", +// LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32"); + +// auto fc_gpu_no_decay = LayerGoldenTestParamType( +// nntrainer::createLayer, +// {"unit=5", "weight_decay=0.0", "bias_decay=0.0"}, "3:1:1:10", +// "fc_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw", +// "fp32", "fp32"); + +// auto fc_gpu_plain_nhwc = LayerGoldenTestParamType( +// nntrainer::createLayer, {"unit=5"}, +// "3:10:1:1", "fc_plain.nnlayergolden", +// LayerGoldenTestParamOptions::SKIP_CALC_DERIV | +// LayerGoldenTestParamOptions::SKIP_CALC_GRAD | +// LayerGoldenTestParamOptions::USE_INC_FORWARD, +// "nhwc", "fp32", "fp32"); + +// auto fc_gpu_single_batch_nhwc = LayerGoldenTestParamType( +// nntrainer::createLayer, {"unit=4"}, +// "1:10:1:1", "fc_single_batch.nnlayergolden", +// LayerGoldenTestParamOptions::SKIP_CALC_DERIV | +// LayerGoldenTestParamOptions::SKIP_CALC_GRAD, +// "nhwc", "fp32", "fp32"); + +// auto fc_gpu_no_decay_nhwc = LayerGoldenTestParamType( +// nntrainer::createLayer, +// {"unit=5", "weight_decay=0.0", "bias_decay=0.0"}, "3:10:1:1", +// "fc_plain.nnlayergolden", +// LayerGoldenTestParamOptions::SKIP_CALC_DERIV | +// LayerGoldenTestParamOptions::SKIP_CALC_GRAD, +// "nhwc", "fp32", "fp32"); + +GTEST_PARAMETER_TEST(FusedLayerGPU, LayerGoldenTest, + ::testing::Values(fused_gpu_plain)); + +// #ifdef ENABLE_FP16 +// auto fc_gpu_basic_plain_w16a16 = LayerGoldenTestParamType( +// nntrainer::createLayer, {"unit=5"}, +// "3:1:1:10", "fc_plain_w16a16.nnlayergolden", +// LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +// auto fc_gpu_basic_single_batch_w16a16 = LayerGoldenTestParamType( +// nntrainer::createLayer, {"unit=4"}, +// "1:1:1:10", "fc_single_batch_w16a16.nnlayergolden", +// LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +// auto fc_gpu_basic_no_decay_w16a16 = LayerGoldenTestParamType( +// nntrainer::createLayer, +// {"unit=5", "weight_decay=0.0", "bias_decay=0.0"}, "3:1:1:10", +// "fc_plain_w16a16.nnlayergolden", +// LayerGoldenTestParamOptions::SKIP_CALC_DERIV | +// LayerGoldenTestParamOptions::SKIP_CALC_GRAD | +// LayerGoldenTestParamOptions::USE_INC_FORWARD, +// "nchw", "fp16", "fp16"); + +// GTEST_PARAMETER_TEST(FullyConnectedGPU16, LayerGoldenTest, +// ::testing::Values(fc_gpu_basic_plain_w16a16, +// fc_gpu_basic_single_batch_w16a16, +// fc_gpu_basic_no_decay_w16a16)); +// #endif diff --git a/test/unittest/unittest_blas_kernels_cl.cpp b/test/unittest/unittest_blas_kernels_cl.cpp index ab1c8a03fa..399fd3a020 100644 --- a/test/unittest/unittest_blas_kernels_cl.cpp +++ b/test/unittest/unittest_blas_kernels_cl.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #define EXPECT_IN_RANGE(VAL, MIN, MAX) \ @@ -29,7 +30,9 @@ using namespace nntrainer; static void setUpGpuContext() { auto &ac = nntrainer::ClContext::Global(); } -TEST(blas_kernels, dotCL_sgemv_M_1_1) { +// else if m == 1, and either trans true +TEST(blas_kernels, fused_M_1_1) { + // //if M == 1 and trans, Batch: 1, Channel: 1, 
Height: 1, Width: 2048 setUpGpuContext(); int batch = 1; int channel = 1; @@ -39,6 +42,81 @@ TEST(blas_kernels, dotCL_sgemv_M_1_1) { int height_b = 2048; int width_b = 768; + int batch_res = 1; + int channel_res = 1; + int height_res = 1; + int width_res = height_b; + + bool transA = false; + bool transB = true; + + const float epsilon = 1e-3 * width; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + nntrainer::Tensor bias_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor outRMS_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor gamma_fp32(1, 1, 1, width_res, t_type_nchw_fp32); + + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + GEN_TEST_BIAS(bias_fp32, 1); + GEN_TEST_INPUT_GAMMA(gamma_fp32, 1); + + nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); + add_i_cl(C, bias_fp32); + RMSNormLayerCl obj; + obj.rmsnormProcess(C, outRMS_fp32, gamma_fp32, epsilon); + + bool disable_bias_value = true; + nntrainer::Tensor res_fp32 = + fusedProcess(A_fp32, B_fp32, bias_fp32, disable_bias_value, gamma_fp32, + epsilon, transA, transB); + + float mseErrorNeon = mse( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + double cosSimNeon = cosine_similarity( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + // printf("Size of res_fp32: %zu, Size of outRMS_fp32: %zu\n", + // res_fp32.size(), outRMS_fp32.size()); std::cout << "\res_fp32 and + // outRMS_fp32 before rotary embedding:" << std::endl; for (unsigned int i = + // 0; i < res_fp32.size(); ++i) { + // std::cout << "Element " << i << " -> " << *(res_fp32.getData() + + // i) + // <<"\t"<<*(outRMS_fp32.getData() + i)<< std::endl; + // } + + EXPECT_IN_RANGE(mseErrorNeon, 0, 1e-6); + EXPECT_IN_RANGE((float)cosSimNeon, 0.999999, 1); +} + +TEST(blas_kernels, dotCL_sgemv_M_1_1) { + // if M == 1 and trans, Batch: 1, Channel: 1, Height: 1, Width: 2048 + // setUpGpuContext(); + int batch = 1; + int channel = 1; + int height = 1; + int width = 768; + + int height_b = 2048; + int width_b = 768; + bool transA = false; bool transB = true; @@ -61,7 +139,10 @@ TEST(blas_kernels, dotCL_sgemv_M_1_1) { alpha); nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); + nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB); + // printf("if M == 1 and trans, Batch: %zu, Channel: %zu, Height: %zu, Width: + // %zu\n", C_fp32.batch(), C_fp32.channel(), C_fp32.height(), C_fp32.width()); float mseErrorNeon = mse(C.getData(), C_fp32.getData(), C.size()); @@ -75,7 +156,79 @@ TEST(blas_kernels, dotCL_sgemv_M_1_1) { EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1); } +// else if m == 1, and trans false +TEST(blas_kernels, fused_M_1_2) { + // //if M == 1 and trans false, Batch: 1, Channel: 1, Height: 1, Width: 2048 + int batch = 1; + int channel = 1; + int height = 1; + int width = 768; + + int height_b = 768; + int width_b = 2048; + + int batch_res = 1; + int channel_res = 1; + int height_res = 1; + int width_res = width_b; + + bool transA = false; + bool transB = false; 
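+  // NOTE: each of these fused_* tests checks the fused GPU path against the
+  // three-step reference computed below, under the assumption that
+  // fusedProcess() fuses GEMM, bias addition and RMS normalization into one
+  // call:
+  //   nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); // GEMM
+  //   add_i_cl(C, bias_fp32);                                      // bias add
+  //   obj.rmsnormProcess(C, outRMS_fp32, gamma_fp32, epsilon);     // RMS norm
+  // i.e. out = (A*B + bias) / sqrt(mean((A*B + bias)^2, last axis) + epsilon) * gamma,
+  // matching the RMS norm used by the golden data generator.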
+ + const float epsilon = 1e-3 * width; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + nntrainer::Tensor bias_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor outRMS_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor gamma_fp32(1, 1, 1, width_res, t_type_nchw_fp32); + + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + + GEN_TEST_INPUT_RES(bias_fp32, + ((i * (batch_res * height_res * channel_res) + + j * (batch_res * height_res) + k * (width_res) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_GAMMA(gamma_fp32, 1); + + nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); + add_i_cl(C, bias_fp32); + RMSNormLayerCl obj; + obj.rmsnormProcess(C, outRMS_fp32, gamma_fp32, epsilon); + + bool disable_bias_value = true; + nntrainer::Tensor res_fp32 = + fusedProcess(A_fp32, B_fp32, bias_fp32, disable_bias_value, gamma_fp32, + epsilon, transA, transB); + + float mseErrorNeon = mse( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + double cosSimNeon = cosine_similarity( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + EXPECT_IN_RANGE(mseErrorNeon, 0, 1e-6); + EXPECT_IN_RANGE((float)cosSimNeon, 0.999999, 1); +} + TEST(blas_kernels, dotCL_sgemv_M_1_2) { + // if M == 1 and !trans, Batch: 1, Channel: 1, Height: 1, Width: 2048 int batch = 1; int channel = 1; int height = 1; @@ -120,6 +273,84 @@ TEST(blas_kernels, dotCL_sgemv_M_1_2) { EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1); } +// else if n == 1, and trans true +TEST(blas_kernels, fused_N_1_1) { + // if N == 1 and trans, Batch: 1, Channel: 1, Height: 2048, Width: 1 + // setUpGpuContext(); + int batch = 1; + int channel = 1; + int height = 768; + int width = 2048; + + int height_b = 768; + int width_b = 1; + + int batch_res = 1; + int channel_res = 1; + int height_res = width; + int width_res = 1; + + bool transA = true; + bool transB = false; + + const float epsilon = 1e-3 * width; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + nntrainer::Tensor bias_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor outRMS_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor gamma_fp32(1, 1, 1, width_res, t_type_nchw_fp32); + + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + + GEN_TEST_BIAS(bias_fp32, 1); + GEN_TEST_INPUT_GAMMA(gamma_fp32, 1); + + bool disable_bias_value = true; + nntrainer::Tensor res_fp32 = + fusedProcess(A_fp32, B_fp32, bias_fp32, disable_bias_value, gamma_fp32, + epsilon, 
transA, transB); + + nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); + add_i_cl(C, bias_fp32); + RMSNormLayerCl obj; + obj.rmsnormProcess(C, outRMS_fp32, gamma_fp32, epsilon); + + float mseErrorNeon = mse( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + double cosSimNeon = cosine_similarity( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + // printf("Size of res_fp32: %zu, Size of outRMS_fp32: %zu\n", + // res_fp32.size(), outRMS_fp32.size()); std::cout << "\res_fp32 and + // outRMS_fp32 after everything" << std::endl; for (unsigned int i = 0; i < + // res_fp32.size(); ++i) { + // std::cout << "Element " << i << " -> " << *(res_fp32.getData() + + // i) + // <<"\t"<<*(outRMS_fp32.getData() + i)<< std::endl; + // } + + EXPECT_IN_RANGE(mseErrorNeon, 0, 1e-6); + EXPECT_IN_RANGE((float)cosSimNeon, 0.999999, 1); +} + +// if N == 1 and trans, Batch: 1, Channel: 1, Height: 2048, Width: 1 TEST(blas_kernels, dotCL_sgemv_N_1_1) { int batch = 1; int channel = 1; @@ -165,6 +396,78 @@ TEST(blas_kernels, dotCL_sgemv_N_1_1) { EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1); } +// // else if n == 1, and trans false +TEST(blas_kernels, fused_N_1_2) { + // if N == 1 and !trans, Batch: 1, Channel: 1, Height: 768, Width: 1 + int batch = 1; + int channel = 1; + int height = 768; + int width = 2048; + + int height_b = 2048; + int width_b = 1; + + int batch_res = 1; + int channel_res = 1; + int height_res = height; + int width_res = 1; + + bool transA = false; + bool transB = false; + + const float epsilon = 1e-3 * width; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + nntrainer::Tensor bias_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor outRMS_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor gamma_fp32(1, 1, 1, width_res, t_type_nchw_fp32); + + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + + GEN_TEST_INPUT_RES(bias_fp32, + ((i * (batch_res * height_res * channel_res) + + j * (batch_res * height_res) + k * (width_res) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_GAMMA(gamma_fp32, 1); + + bool disable_bias_value = true; + nntrainer::Tensor res_fp32 = + fusedProcess(A_fp32, B_fp32, bias_fp32, disable_bias_value, gamma_fp32, + epsilon, transA, transB); + + nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); + add_i_cl(C, bias_fp32); + RMSNormLayerCl obj; + obj.rmsnormProcess(C, outRMS_fp32, gamma_fp32, epsilon); + + // nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB); + + float mseErrorNeon = mse( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + double cosSimNeon = cosine_similarity( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + EXPECT_IN_RANGE(mseErrorNeon, 0, 1e-6); + EXPECT_IN_RANGE((float)cosSimNeon, 0.999999, 1); +} TEST(blas_kernels, dotCL_sgemv_N_1_2) { int batch = 1; int channel = 1; @@ -197,6 +500,9 @@ TEST(blas_kernels, dotCL_sgemv_N_1_2) { nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); nntrainer::Tensor 
C_fp32 = A_fp32.dot(B_fp32, transA, transB); + // if N == 1 and !trans, Batch: 1, Channel: 1, Height: 768, Width: 1 + // printf("if N == 1 and !trans, Batch: %zu, Channel: %zu, Height: %zu, Width: + // %zu\n", C_fp32.batch(), C_fp32.channel(), C_fp32.height(), C_fp32.width()); float mseErrorNeon = mse(C.getData(), C_fp32.getData(), C.size()); @@ -286,7 +592,100 @@ TEST(nntrainer_Tensor, multiply_i) { EXPECT_IN_RANGE(cosSimNeon, 0.99, 1); } +TEST(nntrainer_Tensor, fused_gemm_50_768_1024_noTrans) { + // if N == 1 and !transA & !transB, Batch: 1, Channel: 1, Height: 50, Width: + // 1024 + /// @note GEMM : A X B = C + + int batch = 1; + int channel = 1; + int height = 50; + int width = 768; + + int height_b = 768; + int width_b = 1024; + + // int batch = 1; + // int channel = 1; + // int height = 2; + // int width = 3; + + // int height_b = 3; + // int width_b = 4; + + int batch_res = batch; + int channel_res = channel; + int height_res = height; + int width_res = width_b; + + bool transA = false; + bool transB = false; + + const float epsilon = 1e-3 * width; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp16 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + + nntrainer::Tensor bias_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor outRMS_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor gamma_fp32(1, 1, 1, width_res, t_type_nchw_fp32); + + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + + GEN_TEST_BIAS(bias_fp32, 1); + GEN_TEST_INPUT_GAMMA(gamma_fp32, 1); + + // auto t1 = std::chrono::high_resolution_clock::now(); + // nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); + // // for(int i=0;i<999;i++) + // // C = dotCl(A_fp32, B_fp32, transA, transB); + // auto t2 = std::chrono::high_resolution_clock::now(); + // auto ms_int = std::chrono::duration_cast(t2 - + // t1); std::cout<<"Timing SGEMV dotCl: "<( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + double cosSimNeon = cosine_similarity( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + EXPECT_IN_RANGE(mseErrorNeon, 0, 1e-9); + EXPECT_IN_RANGE((float)cosSimNeon, 0.999999, 1); +} + TEST(nntrainer_Tensor, dot_gemm_50_768_1024_noTrans) { + // if N == 1 and !transA & !transB, Batch: 1, Channel: 1, Height: 50, Width: + // 1024 /// @note GEMM : A X B = C int batch = 1; @@ -334,6 +733,8 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_1024_noTrans) { nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB); + // printf("if N == 1 and !transA & !transB, Batch: %zu, Channel: %zu, Height: + // %zu, Width: %zu\n", C.batch(), C.channel(), C.height(), C.width()); float mseErrorNeon = mse(C.getData(), C_fp32.getData(), C.size()); @@ -347,7 +748,81 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_1024_noTrans) { EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1); } +TEST(nntrainer_Tensor, fused_gemm_50_768_2048_transB) { + // if N == 1 
and !transA & transB, Batch: 1, Channel: 1, Height: 50, Width: + // 2048 + /// @note GEMM : A X B = C + + int batch = 1; + int channel = 1; + int height = 50; + int width = 768; + + int height_b = 2048; + int width_b = 768; + + int batch_res = batch; + int channel_res = channel; + int height_res = height; + int width_res = height_b; + + bool transA = false; + bool transB = true; + + const float epsilon = 1e-3 * width; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp16 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + nntrainer::Tensor bias_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor outRMS_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor gamma_fp32(1, 1, 1, width_res, t_type_nchw_fp32); + + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + + GEN_TEST_BIAS(bias_fp32, 1); + GEN_TEST_INPUT_GAMMA(gamma_fp32, 1); + + bool disable_bias_value = true; + + nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); + add_i_cl(C, bias_fp32); + RMSNormLayerCl obj; + obj.rmsnormProcess(C, outRMS_fp32, gamma_fp32, epsilon); + + nntrainer::Tensor res_fp32 = + fusedProcess(A_fp32, B_fp32, bias_fp32, disable_bias_value, gamma_fp32, + epsilon, transA, transB); + float mseErrorNeon = mse( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + double cosSimNeon = cosine_similarity( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + EXPECT_IN_RANGE(mseErrorNeon, 0, 1e-9); + EXPECT_IN_RANGE((float)cosSimNeon, 0.999999, 1); +} + TEST(nntrainer_Tensor, dot_gemm_50_768_2048_transB) { + // if N == 1 and !transA & transB, Batch: 1, Channel: 1, Height: 50, Width: + // 2048 /// @note GEMM : A X B = C int batch = 1; @@ -358,6 +833,11 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_2048_transB) { int height_b = 2048; int width_b = 768; + int batch_res = batch; + int channel_res = channel; + int height_res = height; + int width_res = height_b; + bool transA = false; bool transB = true; @@ -395,6 +875,148 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_2048_transB) { nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB); + // printf("if N == 1 and !transA & transB, Batch: %zu, Channel: %zu, Height: + // %zu, Width: %zu\n", C.batch(), C.channel(), C.height(), C.width()); + + float mseErrorNeon = + mse(C.getData(), C_fp32.getData(), C.size()); + + double cosSimNeon = cosine_similarity( + C.getData(), C_fp32.getData(), C.size()); + + const float epsilon = 1e-3 * width; + + EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon); + EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1); +} + +TEST(nntrainer_Tensor, fused_ResultThere) { + // if N == 1 and transA & !transB, Batch: 1, Channel: 1, Height: 50, Width: + // 1024 + /// @note GEMM : A X B = C + + int batch = 1; + int channel = 1; + int height = 3; + int width = 2; + + int height_b = 3; + int width_b = 4; + + int batch_res = batch; + int channel_res = channel; + int 
height_res = width; + int width_res = width_b; + + bool transA = true; + bool transB = false; + + const float epsilon = 1e-3 * width; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp16 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + nntrainer::Tensor res_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor out_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor bias_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor outRMS_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor gamma_fp32(1, 1, 1, width_res, t_type_nchw_fp32); + + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + + GEN_TEST_INPUT_RES(res_fp32, 2); + + GEN_TEST_INPUT_RES(out_fp32, 2); + + GEN_TEST_BIAS(bias_fp32, 1); + GEN_TEST_INPUT_GAMMA(gamma_fp32, 1); + + bool disable_bias_value = true; + fusedProcess(A_fp32, B_fp32, res_fp32, bias_fp32, disable_bias_value, + gamma_fp32, epsilon, transA, transB); + + dotCl(A_fp32, B_fp32, out_fp32, transA, transB); + add_i_cl(out_fp32, bias_fp32); + RMSNormLayerCl obj; + obj.rmsnormProcess(out_fp32, outRMS_fp32, gamma_fp32, epsilon); + + float mseErrorNeon = mse( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + double cosSimNeon = cosine_similarity( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + EXPECT_IN_RANGE(mseErrorNeon, 0, 1e-9); + EXPECT_IN_RANGE((float)cosSimNeon, 0.999999, 1); +} +TEST(nntrainer_Tensor, dot_resultThere) { + // if N == 1 and transA & !transB, Batch: 1, Channel: 1, Height: 50, Width: + // 1024 + /// @note GEMM : A X B = C + + int batch = 1; + int channel = 1; + int height = 768; + int width = 50; + + int height_b = 768; + int width_b = 1024; + + int batch_res = batch; + int channel_res = channel; + int height_res = width; + int width_res = width_b; + + bool transA = true; + bool transB = false; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp16 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + nntrainer::Tensor C(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor out_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + + dotCl(A_fp32, B_fp32, C, transA, transB); + nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, out_fp32, transA, transB); + // printf("if N == 1 and transA & 
!transB, Batch: %zu, Channel: %zu, Height: + // %zu, Width: %zu\n", C.batch(), C.channel(), C.height(), C.width()); float mseErrorNeon = mse(C.getData(), C_fp32.getData(), C.size()); @@ -408,7 +1030,80 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_2048_transB) { EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1); } +TEST(nntrainer_Tensor, fused_gemm_50_768_1024_transA) { + // if N == 1 and transA & !transB, Batch: 1, Channel: 1, Height: 50, Width: + // 1024 + /// @note GEMM : A X B = C + + int batch = 1; + int channel = 1; + int height = 768; + int width = 50; + + int height_b = 768; + int width_b = 1024; + + int batch_res = batch; + int channel_res = channel; + int height_res = width; + int width_res = width_b; + + bool transA = true; + bool transB = false; + + const float epsilon = 1e-3 * width; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp16 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + nntrainer::Tensor bias_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor outRMS_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor gamma_fp32(1, 1, 1, width_res, t_type_nchw_fp32); + + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + + GEN_TEST_BIAS(bias_fp32, 1); + GEN_TEST_INPUT_GAMMA(gamma_fp32, 1); + + bool disable_bias_value = true; + nntrainer::Tensor res_fp32 = + fusedProcess(A_fp32, B_fp32, bias_fp32, disable_bias_value, gamma_fp32, + epsilon, transA, transB); + + nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); + add_i_cl(C, bias_fp32); + RMSNormLayerCl obj; + obj.rmsnormProcess(C, outRMS_fp32, gamma_fp32, epsilon); + + float mseErrorNeon = mse( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + double cosSimNeon = cosine_similarity( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + EXPECT_IN_RANGE(mseErrorNeon, 0, 1e-9); + EXPECT_IN_RANGE((float)cosSimNeon, 0.999999, 1); +} TEST(nntrainer_Tensor, dot_gemm_50_768_1024_transA) { + // if N == 1 and transA & !transB, Batch: 1, Channel: 1, Height: 50, Width: + // 1024 /// @note GEMM : A X B = C int batch = 1; @@ -419,6 +1114,11 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_1024_transA) { int height_b = 768; int width_b = 1024; + // int batch_res = batch; + // int channel_res = channel; + // int height_res = width; + // int width_res = width_b; + bool transA = true; bool transB = false; @@ -456,6 +1156,8 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_1024_transA) { nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB); + // printf("if N == 1 and transA & !transB, Batch: %zu, Channel: %zu, Height: + // %zu, Width: %zu\n", C.batch(), C.channel(), C.height(), C.width()); float mseErrorNeon = mse(C.getData(), C_fp32.getData(), C.size()); @@ -469,7 +1171,81 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_1024_transA) { EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1); } +TEST(nntrainer_Tensor, fused_gemm_50_768_2048_transAB) { + // if N 
== 1 and transA & transB, Batch: 1, Channel: 1, Height: 50, Width: + // 2048 + /// @note GEMM : A X B = C + + int batch = 1; + int channel = 1; + int height = 768; + int width = 50; + + int height_b = 2048; + int width_b = 768; + + int batch_res = batch; + int channel_res = channel; + int height_res = width; + int width_res = height_b; + + bool transA = true; + bool transB = true; + + const float epsilon = 1e-3 * width; + + const float alpha = 1e-1; + const int MOD = 10; + + nntrainer::TensorDim::TensorType t_type_nchw_fp16 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + nntrainer::Tensor bias_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor outRMS_fp32(batch_res, channel_res, height_res, width_res, + t_type_nchw_fp32); + nntrainer::Tensor gamma_fp32(1, 1, 1, width_res, t_type_nchw_fp32); + + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + + GEN_TEST_BIAS(bias_fp32, 1); + GEN_TEST_INPUT_GAMMA(gamma_fp32, 1); + + nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); + add_i_cl(C, bias_fp32); + RMSNormLayerCl obj; + obj.rmsnormProcess(C, outRMS_fp32, gamma_fp32, epsilon); + + bool disable_bias_value = true; + nntrainer::Tensor res_fp32 = + fusedProcess(A_fp32, B_fp32, bias_fp32, disable_bias_value, gamma_fp32, + epsilon, transA, transB); + + float mseErrorNeon = mse( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + double cosSimNeon = cosine_similarity( + res_fp32.getData(), outRMS_fp32.getData(), res_fp32.size()); + + EXPECT_IN_RANGE(mseErrorNeon, 0, 1e-9); + EXPECT_IN_RANGE((float)cosSimNeon, 0.999999, 1); +} + TEST(nntrainer_Tensor, dot_gemm_50_768_2048_transAB) { + // if N == 1 and transA & transB, Batch: 1, Channel: 1, Height: 50, Width: + // 2048 /// @note GEMM : A X B = C int batch = 1; @@ -480,6 +1256,11 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_2048_transAB) { int height_b = 2048; int width_b = 768; + int batch_res = batch; + int channel_res = channel; + int height_res = width; + int width_res = height_b; + bool transA = true; bool transB = true; @@ -517,6 +1298,8 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_2048_transAB) { nntrainer::Tensor C = dotCl(A_fp32, B_fp32, transA, transB); nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB); + // printf("if N == 1 and transA & transB, Batch: %zu, Channel: %zu, Height: + // %zu, Width: %zu\n", C.batch(), C.channel(), C.height(), C.width()); float mseErrorNeon = mse(C.getData(), C_fp32.getData(), C.size());
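+  // NOTE: with transA && transB, op(A) is 50x768 (A is 1:1:768:50) and op(B)
+  // is 768x2048 (B is 1:1:2048:768), so C from dotCl() and the CPU reference
+  // C_fp32 are both 1:1:50:2048 and are compared element-wise, as in the
+  // other GEMM tests.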