From 75f76e97f95d5a80959920a69e9d8060adbefe80 Mon Sep 17 00:00:00 2001
From: ThummalaPallavi
Date: Tue, 11 Jun 2024 14:54:49 +0530
Subject: [PATCH] [GPU/OpenCL] Initial version of RMSNorm Layer

Added a naive OpenCL implementation of the RMSNorm layer.
Incorporated the kernels for the ops used.
Added a unit test for rmsnorm_layer_cl.

Signed-off-by: ThummalaPallavi
---
 api/ccapi/include/layer.h | 10 +
 api/nntrainer-api-common.h | 1 +
 nntrainer/cl_context.cpp | 4 +-
 nntrainer/layers/cl_layers/meson.build | 1 +
 .../layers/cl_layers/rmsnorm_layer_cl.cpp | 398 ++++++++++++++++++
 nntrainer/layers/cl_layers/rmsnorm_layer_cl.h | 176 ++++++++
 nntrainer/layers/layer_context.cpp | 4 +
 nntrainer/layers/layer_context.h | 2 +
 test/input_gen/gen_layer_tests.py | 32 ++
 test/jni/Android.mk | 1 +
 .../layers/unittest_layers_rmsnorm_cl.cpp | 51 +++
 11 files changed, 679 insertions(+), 1 deletion(-)
 create mode 100644 nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
 create mode 100644 nntrainer/layers/cl_layers/rmsnorm_layer_cl.h
 create mode 100644 test/unittest/layers/unittest_layers_rmsnorm_cl.cpp

diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h
index ca0ae19f62..4ef7db08d6 100644
--- a/api/ccapi/include/layer.h
+++ b/api/ccapi/include/layer.h
@@ -98,6 +98,7 @@ enum LayerType {
 LAYER_REDUCE_MEAN, /**< Reduce mean Layer type */
 LAYER_LOSS_CONSTANT_DERIVATIVE, /**< Synthetic loss layer to feed constant derivative */
+ LAYER_RMSNORM = ML_TRAIN_LAYER_TYPE_RMSNORM, /**< RMS Norm Layer type */
 FullyConnected( return createLayer(LayerType::LAYER_FC, properties, compute_engine); }
+/**
+ * @brief Helper function to create RMS normalization layer for GPU
+ */
+inline std::unique_ptr RMSNormCl(
+ const std::vector &properties = {},
+ const LayerComputeEngine &compute_engine = LayerComputeEngine::GPU) {
+ return createLayer(LayerType::LAYER_RMSNORM, properties, compute_engine);
+}
+
 /**
 * @brief Helper function to create batch normalization layer
 */
diff --git a/api/nntrainer-api-common.h b/api/nntrainer-api-common.h
index b37a3a750d..3f6b597a70 100644
--- a/api/nntrainer-api-common.h
+++ b/api/nntrainer-api-common.h
@@ -75,6 +75,7 @@ typedef enum {
 Sigmoid Loss Layer type (Since 6.5) */
 ML_TRAIN_LAYER_TYPE_LOSS_CROSS_ENTROPY_SOFTMAX = 502, /**< Cross Entropy with Softmax Loss Layer type (Since 6.5) */
+ ML_TRAIN_LAYER_TYPE_RMSNORM = 503, /**< RMS Normalization Layer type */
 ML_TRAIN_LAYER_TYPE_UNKNOWN = 999 /**< Unknown Layer */
 } ml_train_layer_type_e;
diff --git a/nntrainer/cl_context.cpp b/nntrainer/cl_context.cpp
index be7345eed0..83db98848a 100644
--- a/nntrainer/cl_context.cpp
+++ b/nntrainer/cl_context.cpp
@@ -14,7 +14,7 @@
 #include
 #include
-
+#include
 namespace nntrainer {

 std::mutex cl_factory_mutex;

@@ -26,6 +26,8 @@ static void add_default_object(ClContext &cc) {
 cc.registerFactory(nntrainer::createLayer, FullyConnectedLayerCl::type,
 ml::train::LayerType::LAYER_FC);
+ cc.registerFactory(nntrainer::createLayer,
+ RMSNormLayerCl::type, ml::train::LayerType::LAYER_RMSNORM);
 }

 static void registerer(ClContext &cc) noexcept {
diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build
index fd8ed3cae9..4b5f902afa 100644
--- a/nntrainer/layers/cl_layers/meson.build
+++ b/nntrainer/layers/cl_layers/meson.build
@@ -1,6 +1,7 @@
 cl_layer_sources = [
 'fc_layer_cl.cpp',
 'blas_kernels.cpp',
+ 'rmsnorm_layer_cl.cpp'
 ]

 if get_option('enable-fp16')
diff --git a/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
new file mode 100644
index 0000000000..3a8a32c441
--- /dev/null
+++ b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
@@ -0,0 +1,398 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Thummala Pallavi
+ *
+ * @file rmsnorm_layer_cl.cpp
+ * @date 8 June 2024
+ * @brief This is the RMS Norm Layer Class for Neural Network with OpenCL
+ * implementation
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Thummala Pallavi
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+std::string rmsnorm_cl_kernel_fp16_ =
+ R"(
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+ __kernel void rmsnorm_cl_fp16(
+ __global const half *input, // Input tensor
+ __global half *output, // Output tensor
+ __global const half *alpha, // Alpha values (one for each channel)
+ float epsilon,
+ int B, // Number of batches
+ int C, // Number of channels
+ int H, // Height of feature map
+ int W // Width of feature map
+) {
+ int global_id = get_global_id(0); // Get the global work item index
+
+ // Compute the corresponding batch, height, and channel indices
+ int n = global_id / (H * C); // Batch index
+ int hc = global_id % (H * C); // Combined height and channel index
+ int h = hc / C; // Height index
+ int c = hc % C; // Channel index
+ int index = ((n * C + c) * H + h) * W;
+
+ // Calculate RMS norm for the current channel, height, and batch
+ half sum_squares = 0.0f;
+ for (int j = 0; j < W; ++j) {
+ sum_squares += input[index+j] * input[index+j];
+ }
+ sum_squares /= W;
+ half rms_norm = sqrt(sum_squares + epsilon);
+ // Each work item processes all width elements for its specific n, h, c
+ for (int w = 0; w < W; ++w) {
+ output[index+w] = (input[index+w] / rms_norm) * alpha[c];
+ }
+}
+)";
+
+std::string rmsnorm_cl_kernel_ =
+ R"(__kernel void rmsnorm_cl(
+ __global const float *input, // Input tensor
+ __global float *output, // Output tensor
+ __global const float *alpha, // Alpha values (one for each channel)
+ float epsilon,
+ int B, // Number of batches
+ int C, // Number of channels
+ int H, // Height of feature map
+ int W // Width of feature map
+) {
+ int global_id = get_global_id(0); // Get the global work item index
+
+ // Compute the corresponding batch, height, and channel indices
+ int n = global_id / (H * C); // Batch index
+ int hc = global_id % (H * C); // Combined height and channel index
+ int h = hc / C; // Height index
+ int c = hc % C; // Channel index
+ int index = ((n * C + c) * H + h) * W;
+
+ // Calculate RMS norm for the current channel, height, and batch
+ float sum_squares = 0.0f;
+ for (int j = 0; j < W; ++j) {
+ sum_squares += input[index+j] * input[index+j];
+ }
+ sum_squares /= W;
+ float rms_norm = sqrt(sum_squares + epsilon);
+ // Each work item processes all width elements for its specific n, h, c
+ for (int w = 0; w < W; ++w) {
+ output[index+w] = (input[index+w] / rms_norm) * alpha[c];
+ }
+}
+)";
+
+namespace nntrainer {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+enum RMSParams { gamma };
+
+RMSNormLayerCl::RMSNormLayerCl() :
+ LayerImpl() {
+ wt_idx.fill(0);
+}
+
+void RMSNormLayerCl::finalize(InitLayerContext &context) {
+ std::vector dim = context.getInputDimensions();
+ context.setOutputDimensions(dim);
+ auto &rmsparams_gamma = std::get(rmsnorm_props);
+
+ TensorDim gamma_dim(
+ 1, 1, 1, dim[0].width(),
+ TensorDim::TensorType(context.getFormat(),
+ context.getWeightDataType()));
+ wt_idx[RMSParams::gamma] = context.requestWeight(
+ gamma_dim, rmsparams_gamma, WeightRegularizer::NONE, 1.0f, 0.0f,
"gamma", false); +} + + +void RMSNormLayerCl::forwarding(RunLayerContext &context, + bool training) { + Tensor &in = context.getInput(SINGLE_INOUT_IDX); + Tensor &out = context.getOutput(SINGLE_INOUT_IDX); + Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]); + + auto &epsilon = std::get(rmsnorm_props).get(); + + if (in.getDataType() == ml::train::TensorDim::DataType::FP32) { + rmsnormProcess(in,out,gamma,epsilon,context); + } + else{ + rmsnormProcess_fp16(in,out,gamma,epsilon,context); + } +} + +opencl::Kernel RMSNormLayerCl::kernel_rmsnorm; +opencl::Kernel RMSNormLayerCl::kernel_rmsnorm_fp16; + + +void RMSNormLayerCl::rmsnormProcess(Tensor const &input, + Tensor &result,Tensor const &gamma,const float epsilon, + RunLayerContext &context){ + + + bool ret = false; + int dim1 = input.batch() * input.height() * input.width()* input.channel(); + CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), input.height(), + input.width(), input.getTensorType()); + int b=input.batch(); + int c = input.channel(); + int h = input.height(); + int w = input.width(); + do { + ret = + context.clCreateKernel(rmsnorm_cl_kernel_, context.LayerKernel::RMSNORM, + RMSNormLayerCl::kernel_rmsnorm); + if (!ret) { + break; + } + opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(float), true, + nullptr); + + opencl::Buffer gammabuf(context.context_inst_, input.width() * sizeof(float), true, + nullptr); + opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(float), true, nullptr); + + const float *data = input.getData(); + float *rdata = result.getData(); + const float *gdata = gamma.getData(); + ret = inputbuf.WriteData(context.command_queue_inst_, data); + if (!ret) { + break; + } + + ret = gammabuf.WriteData(context.command_queue_inst_, gdata); + if (!ret) { + break; + } + ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments( + 0, &inputbuf, sizeof(cl_mem)); + if (!ret) { + break; + } + + ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments( + 1, &resultbuf, sizeof(cl_mem)); + if (!ret) { + break; + } + + + ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments( + 2, &gammabuf, sizeof(cl_mem)); + if (!ret) { + break; + } + ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments( + 4, &b, sizeof(int)); + if (!ret) { + break; + } + + ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments( + 3, &epsilon, sizeof(float)); + if (!ret) { + break; + } + + ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments( + 5, &c, sizeof(int)); + if (!ret) { + break; + } + ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments( + 6, &h, sizeof(int)); + if (!ret) { + break; + } + ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments( + 7, &w, sizeof(int)); + if (!ret) { + break; + } + const int work_groups_count[3] = {b*h*c, 1, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + ret = context.command_queue_inst_.DispatchCommand( + RMSNormLayerCl::kernel_rmsnorm, work_groups_count, work_group_size); + if (!ret) { + break; + } + + ret = resultbuf.ReadData(context.command_queue_inst_, rdata); + if (!ret) { + break; + } + + } while (false); + +} + +void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, + Tensor &result,Tensor const &gamma,const float epsilon, + RunLayerContext &context){ + + + bool ret = false; + int dim1 = input.batch() * input.height() * input.width()* input.channel(); + CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), input.height(), + input.width(), input.getTensorType()); + int b=input.batch(); + int c = input.channel(); + int h = 
input.height();
+ int w = input.width();
+ do {
+ ret =
+ context.clCreateKernel(rmsnorm_cl_kernel_fp16_, context.LayerKernel::RMSNORM_FP16,
+ RMSNormLayerCl::kernel_rmsnorm_fp16);
+ if (!ret) {
+ break;
+ }
+ opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(cl_half), true,
+ nullptr);
+
+ opencl::Buffer gammabuf(context.context_inst_, input.width() * sizeof(cl_half), true,
+ nullptr);
+ opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(cl_half), true, nullptr);
+
+ const __fp16 *data = input.getData<__fp16>();
+ __fp16 *rdata = result.getData<__fp16>();
+ const __fp16 *gdata = gamma.getData<__fp16>();
+ ret = inputbuf.WriteData(context.command_queue_inst_, data);
+ if (!ret) {
+ break;
+ }
+
+ ret = gammabuf.WriteData(context.command_queue_inst_, gdata);
+ if (!ret) {
+ break;
+ }
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+ 0, &inputbuf, sizeof(cl_mem));
+ if (!ret) {
+ break;
+ }
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+ 1, &resultbuf, sizeof(cl_mem));
+ if (!ret) {
+ break;
+ }
+
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+ 2, &gammabuf, sizeof(cl_mem));
+ if (!ret) {
+ break;
+ }
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+ 4, &b, sizeof(int));
+ if (!ret) {
+ break;
+ }
+
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+ 3, &epsilon, sizeof(float));
+ if (!ret) {
+ break;
+ }
+
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+ 5, &c, sizeof(int));
+ if (!ret) {
+ break;
+ }
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+ 6, &h, sizeof(int));
+ if (!ret) {
+ break;
+ }
+ ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
+ 7, &w, sizeof(int));
+ if (!ret) {
+ break;
+ }
+ const int work_groups_count[3] = {(b*h*c), 1, 1};
+ const int work_group_size[3] = {32, 32, 1}; // test-value
+
+ ret = context.command_queue_inst_.DispatchCommand(
+ RMSNormLayerCl::kernel_rmsnorm_fp16, work_groups_count, work_group_size);
+ if (!ret) {
+ break;
+ }
+
+ ret = resultbuf.ReadData(context.command_queue_inst_, rdata);
+ if (!ret) {
+ break;
+ }
+
+ } while (false);
+
+}
+
+void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
+ unsigned int from, unsigned int to,
+ bool training) {
+ Tensor &in = context.getInput(SINGLE_INOUT_IDX);
+ Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
+ Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
+ ml::train::TensorDim in_dim = in.getDim();
+ ml::train::TensorDim out_dim = out.getDim();
+
+ ml::train::TensorDim in_step_dim = in_dim;
+ ml::train::TensorDim out_step_dim = out_dim;
+
+ if (from) {
+ NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+ << "incremental step size is not 1";
+ from = 0;
+ to = 1;
+ }
+
+ in_step_dim.height(to - from);
+ out_step_dim.height(to - from);
+
+ Tensor in_step = in.getSharedDataTensor(in_step_dim, 0, true);
+ Tensor out_step = out.getSharedDataTensor(out_step_dim, 0, true);
+
+ auto &epsilon = std::get(rmsnorm_props).get();
+
+ if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) {
+ rmsnormProcess(in_step, out_step, gamma, epsilon, context);
+ }
+ else {
+ rmsnormProcess_fp16(in_step, out_step, gamma, epsilon, context);
+ }
+}
+
+void RMSNormLayerCl::calcDerivative(RunLayerContext &context) {
+ ml_logi("Training not supported");
+}
+
+void RMSNormLayerCl::calcGradient(RunLayerContext &context) {
+ ml_logi("Training not supported");
+}
+
+void RMSNormLayerCl::exportTo(
+ Exporter &exporter, const ml::train::ExportMethods &method) const {
+ LayerImpl::exportTo(exporter, method);
+ exporter.saveResult(rmsnorm_props, method, this);
+}
+
+void RMSNormLayerCl::setProperty(
+ const std::vector &values) {
+ auto remain_props = loadProperties(values, rmsnorm_props);
+ LayerImpl::setProperty(remain_props);
+}
+
+}
+
diff --git a/nntrainer/layers/cl_layers/rmsnorm_layer_cl.h b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.h
new file mode 100644
index 0000000000..335c29e41b
--- /dev/null
+++ b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.h
@@ -0,0 +1,176 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Thummala Pallavi
+ *
+ * @file rmsnorm_layer_cl.h
+ * @date 8 June 2024
+ * @brief This is RMS Norm Layer Class of Neural Network
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Thummala Pallavi
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#ifndef __RMSNORM_LAYER_CL_H__
+#define __RMSNORM_LAYER_CL_H__
+#ifdef __cplusplus
+
+#include
+#include
+#include
+
+#include
+#include
+
+#define CREATE_IF_EMPTY_DIMS(tensor, ...) \
+ do { \
+ if (tensor.empty()) \
+ tensor = Tensor(__VA_ARGS__); \
+ } while (0);
+
+namespace nntrainer {
+
+namespace props {
+
+/**
+ * @brief RMS_NORM_GAMMA_INIT_GPU Initialization Enumeration Information
+ *
+ */
+class RMS_NORM_GAMMA_INIT_GPU final
+ : public ::nntrainer::EnumProperty<::nntrainer::props::InitializerInfo> {
+public:
+ /**
+ * @brief Construct a RMS_NORM_GAMMA_INIT_GPU object
+ */
+ RMS_NORM_GAMMA_INIT_GPU(::nntrainer::Tensor::Initializer value =
+ ::nntrainer::Tensor::Initializer::ONES) {
+ set(value);
+ };
+ using prop_tag = enum_class_prop_tag;
+ static constexpr const char *key = "gamma_initializer";
+};
+};
+
+
+/**
+ * @class RMSNormLayerCl
+ * @brief RMS Norm layer
+ */
+class RMSNormLayerCl : public LayerImpl {
+public:
+ /**
+ * @brief Constructor of RMS Norm Layer
+ */
+ RMSNormLayerCl();
+
+ /**
+ * @brief Destructor of RMS Norm Layer
+ */
+ ~RMSNormLayerCl() = default;
+
+ /**
+ * @brief Move constructor.
+ * @param[in] RMSNorm &&
+ */
+ RMSNormLayerCl(RMSNormLayerCl &&rhs) noexcept = default;
+
+ /**
+ * @brief Move assignment operator.
+ * @param[in] rhs RMS Norm to be moved.
+ */ + RMSNormLayerCl &operator=(RMSNormLayerCl &&rhs) = default; + + /** + * @copydoc Layer::finalize(InitLayerContext &context) + */ + void finalize(InitLayerContext &context) override; + + /** + * @copydoc Layer::forwarding(RunLayerContext &context, bool training) + */ + void forwarding(RunLayerContext &context, bool training) override; + + /** + * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned + * int from, unsigned int to, bool training) + */ + void incremental_forwarding(RunLayerContext &context, unsigned int from, + unsigned int to, bool training) override; + + /** + * @copydoc Layer::calcDerivative(RunLayerContext &context) + */ + void calcDerivative(RunLayerContext &context) override; + + /** + * @copydoc Layer::calcGradient(RunLayerContext &context) + */ + void calcGradient(RunLayerContext &context) override; + + /** + * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods + * method) + */ + void exportTo(Exporter &exporter, + const ml::train::ExportMethods &method) const override; + + /** + * @copydoc Layer::getType() + */ + const std::string getType() const override { + return RMSNormLayerCl::type; + }; + + static opencl::Kernel kernel_rmsnorm; + static opencl::Kernel kernel_rmsnorm_fp16; + + /** + * @brief Process data and dimensions for rms norm operation + * @param[in] input Tensor + * @param[in] result Tensor + * @param[in] gamma Tensor + * @param[in] epsilon float + * @param[in] RunLayerContext reference + */ + + + void rmsnormProcess(Tensor const &input, Tensor &result, Tensor const &gamma, const float epsilon, + RunLayerContext &context); + + + /** + * @brief Process data and dimensions for FP16 rms norm operation + * @param[in] input Tensor + * @param[in] result Tensor + * @param[in] gamma Tensor + * @param[in] epsilon float + * @param[in] RunLayerContext reference + */ + + + void rmsnormProcess_fp16(Tensor const &input, Tensor &result, Tensor const &gamma, const float epsilon, + RunLayerContext &context); + /** + * @copydoc Layer::supportBackwarding() + */ + bool supportBackwarding() const override { + return false; + } + + /** + * @copydoc Layer::setProperty(const std::vector &values) + */ + void setProperty(const std::vector &values) override; + + inline static const std::string type = "rmsnorm"; + +private: + std::array wt_idx; + std::tuple + rmsnorm_props; /**< rmsnorm layer properties */ +}; +} // namespace nntrainer + +#endif /* __cplusplus */ +#endif /* __RMSNORM_LAYER_CL__ */ + diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp index 1a66aed3cd..f00d0464fd 100644 --- a/nntrainer/layers/layer_context.cpp +++ b/nntrainer/layers/layer_context.cpp @@ -662,6 +662,10 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) { return "dot_cl_fp16"; case LayerKernel::SGEMM_FP16: return "sgemm_cl_fp16"; + case LayerKernel::RMSNORM: + return "rmsnorm_cl"; + case LayerKernel::RMSNORM_FP16: + return "rmsnorm_cl_fp16"; default: return ""; } diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h index 43e9d8eaf8..e74176a869 100644 --- a/nntrainer/layers/layer_context.h +++ b/nntrainer/layers/layer_context.h @@ -836,6 +836,8 @@ class RunLayerContext { SGEMV_FP16 = 1 << 3, /**< placeholder for kernel name */ DOT_FP16 = 1 << 4, /**< placeholder for kernel name */ SGEMM_FP16 = 1 << 5, /**< placeholder for kernel name */ + RMSNORM = 1 << 6, + RMSNORM_FP16 = 1 << 7 }; /** diff --git a/test/input_gen/gen_layer_tests.py b/test/input_gen/gen_layer_tests.py index 
7a1ed18ec6..16ab0697f5 100644
--- a/test/input_gen/gen_layer_tests.py
+++ b/test/input_gen/gen_layer_tests.py
@@ -883,3 +883,35 @@ def swiglu(inputs):
 "swiglu",
 input_type="float",
 )
+
+ class RMSNorm(tf.keras.layers.Layer):
+ def __init__(self, epsilon=1e-3, **kwargs):
+ super(RMSNorm, self).__init__(**kwargs)
+ self.epsilon = epsilon
+
+ def build(self, input_shape):
+ # Initialize gamma as a non-trainable weight of ones
+ self.gamma = self.add_weight(
+ shape=input_shape[-1:],
+ initializer=tf.keras.initializers.Ones(),
+ trainable=False,
+ name='gamma'
+ )
+ super(RMSNorm, self).build(input_shape)
+
+ def call(self, inputs):
+ # Compute the mean of the squares of the inputs along the last dimension
+ mean_square = tf.reduce_mean(tf.square(inputs), axis=[-1], keepdims=True)
+ print(mean_square)
+ # Compute the RMS value with epsilon for numerical stability
+ rms_value = tf.sqrt(mean_square + self.epsilon)
+ print(rms_value)
+ # Normalize inputs and scale by gamma
+ normalized_inputs = inputs / rms_value * self.gamma
+ return normalized_inputs
+
+ rms_normtest = RMSNorm()
+ rms_normtest_fp16 = RMSNorm()
+ record_single(rms_normtest, (2,3,3,3), "rms_normtest")
+ record_single_fp16(rms_normtest_fp16, (2,3,3,3), "rms_normtest_fp16_new")
+
diff --git a/test/jni/Android.mk b/test/jni/Android.mk
index 978e98bd67..83bc66e213 100644
--- a/test/jni/Android.mk
+++ b/test/jni/Android.mk
@@ -445,6 +445,7 @@ LOCAL_SRC_FILES := \
 ../unittest/layers/unittest_layers_loss.cpp \
 ../unittest/layers/unittest_layers_fully_connected_cl.cpp \
 ../unittest/layers/unittest_layers_fully_connected.cpp \
+ ../unittest/layers/unittest_layers_rmsnorm_cl.cpp \
 ../unittest/layers/unittest_layers_batch_normalization.cpp \
 ../unittest/layers/unittest_layers_layer_normalization.cpp \
 ../unittest/layers/unittest_layers_convolution2d.cpp \
diff --git a/test/unittest/layers/unittest_layers_rmsnorm_cl.cpp b/test/unittest/layers/unittest_layers_rmsnorm_cl.cpp
new file mode 100644
index 0000000000..1efe8a62b5
--- /dev/null
+++ b/test/unittest/layers/unittest_layers_rmsnorm_cl.cpp
@@ -0,0 +1,51 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Thummala Pallavi
+ *
+ * @file unittest_layers_rmsnorm_cl.cpp
+ * @date 7 June 2024
+ * @brief RMS Norm Layer Test
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Thummala Pallavi
+ * @bug No known bugs except for NYI items
+ */
+#include
+
+#include
+
+#include
+#include
+
+auto semantic_rms = LayerSemanticsParamType(
+ nntrainer::createLayer,
+ nntrainer::RMSNormLayerCl::type, {"epsilon=0.001"},
+ LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);
+
+GTEST_PARAMETER_TEST(RMSNormGPU, LayerSemanticsGpu,
+ ::testing::Values(semantic_rms));
+
+
+auto rms_plain_skip_CG = LayerGoldenTestParamType(
+ nntrainer::createLayer, {"epsilon=0.001"},
+ "2:3:3:3", "rms_normtest.nnlayergolden",
+ LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
+ LayerGoldenTestParamOptions::SKIP_CALC_GRAD |
+ LayerGoldenTestParamOptions::USE_INC_FORWARD,
+ "nchw", "fp32", "fp32");
+
+GTEST_PARAMETER_TEST(RMSNormGPU, LayerGoldenTest,
+ ::testing::Values(rms_plain_skip_CG));
+
+#ifdef ENABLE_FP16
+auto rms_plain_skip_CG_fp16 = LayerGoldenTestParamType(
+ nntrainer::createLayer, {"epsilon=0.001"},
+ "2:3:3:3", "rms_normtest_fp16_new.nnlayergolden",
+ LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
+ LayerGoldenTestParamOptions::SKIP_CALC_GRAD |
+ LayerGoldenTestParamOptions::USE_INC_FORWARD,
+ "nchw", "fp16", "fp16");
+
+GTEST_PARAMETER_TEST(RMSNormGPU16, LayerGoldenTest,
::testing::Values(rms_plain_skip_CG_fp16)); + +#endif
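
For a quick check of the kernel arithmetic outside the layer-golden harness, the computation performed by rmsnorm_cl can be reproduced on the host in plain C++. The sketch below is not part of the patch: the function and variable names are illustrative, the buffers are assumed to be contiguous NCHW float data, and gamma is assumed to scale the last (width) axis, following the TensorFlow generator added to gen_layer_tests.py; with the default all-ones gamma_initializer this choice does not change the result.

// Host-side reference for the RMS norm math used by the kernels above.
// Not part of the patch; names and the input pattern are illustrative only.
#include <cmath>
#include <cstdio>
#include <vector>

// x and y are contiguous NCHW buffers of B*C*H*W floats; gamma holds W values.
void rmsnorm_reference(const std::vector<float> &x, std::vector<float> &y,
                       const std::vector<float> &gamma, float epsilon, int B,
                       int C, int H, int W) {
  for (int n = 0; n < B; ++n)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h) {
        // Same flattened offset the kernel computes: ((n * C + c) * H + h) * W
        const int index = ((n * C + c) * H + h) * W;
        float sum_squares = 0.0f;
        for (int w = 0; w < W; ++w)
          sum_squares += x[index + w] * x[index + w];
        sum_squares /= static_cast<float>(W);
        const float rms = std::sqrt(sum_squares + epsilon);
        for (int w = 0; w < W; ++w)
          y[index + w] = x[index + w] / rms * gamma[w]; // gamma over last axis
      }
}

int main() {
  const int B = 2, C = 3, H = 3, W = 3; // same shape as the golden test (2:3:3:3)
  std::vector<float> x(B * C * H * W), y(x.size()), gamma(W, 1.0f);
  for (int i = 0; i < B * C * H * W; ++i)
    x[i] = 0.01f * static_cast<float>(i); // arbitrary deterministic input
  rmsnorm_reference(x, y, gamma, 0.001f, B, C, H, W);
  std::printf("y[0..2] = %f %f %f\n", y[0], y[1], y[2]);
  return 0;
}

Comparing this host output against the data read back from resultbuf (within fp32 or fp16 tolerance) gives a quick sanity check of the kernel's indexing, which assigns one work item to each (batch, channel, height) row of W elements.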