[ GPU/OpenCL ] change rmsnorm_layer_cl to inherit LayerImplCl
- This commit updates rmsnorm_layer_cl.cpp/h to inherit LayerImplCl.
- This commit implements registerClKernels() of the rmsnorm layer.
- This commit updates cl_context.cpp (applying rmsnorm_layer_cl's update).
- This commit updates common_properties.h (adding a property for the
  rmsnorm layer).

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Eunju Yang <[email protected]>
EunjuYang committed Dec 2, 2024
1 parent bbda540 commit 8702596
Showing 5 changed files with 88 additions and 51 deletions.
8 changes: 5 additions & 3 deletions nntrainer/cl_context.cpp
@@ -53,9 +53,11 @@ static void add_default_object(ClContext &cc) {
ml::train::LayerType::LAYER_RESHAPE);
}

// @todo rmsnormlayercl also needs to be updated.
cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
RMSNormLayerCl::type, ml::train::LayerType::LAYER_RMSNORM);
if (RMSNormLayerCl::registerClKernels()) {
cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
RMSNormLayerCl::type,
ml::train::LayerType::LAYER_RMSNORM);
}

if (ConcatLayerCl::registerClKernels()) {
cc.registerFactory(nntrainer::createLayer<ConcatLayerCl>,
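With this change, RMSNormLayerCl follows the same guard already used for ConcatLayerCl in the same function: the factory is only registered when the layer's OpenCL kernels compile and register successfully. As a rough illustration of the pattern (not part of this commit; the layer name and LayerType value below are hypothetical), a new CL layer would be wired into add_default_object() the same way:

// Illustrative sketch only: hooking a hypothetical GPU layer into the context.
if (MyNewLayerCl::registerClKernels()) {
  cc.registerFactory(nntrainer::createLayer<MyNewLayerCl>, MyNewLayerCl::type,
                     ml::train::LayerType::LAYER_MYNEW /* hypothetical enum value */);
}
// If kernel registration fails, the factory is never registered, so the layer
// is simply unavailable on the GPU backend instead of failing at runtime.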
69 changes: 52 additions & 17 deletions nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
@@ -91,13 +91,12 @@ static constexpr size_t SINGLE_INOUT_IDX = 0;

enum RMSParams { gamma };

RMSNormLayerCl::RMSNormLayerCl() : LayerImpl() { wt_idx.fill(0); }
RMSNormLayerCl::RMSNormLayerCl() : LayerImplCl() { wt_idx.fill(0); }

void RMSNormLayerCl::finalize(InitLayerContext &context) {
std::vector<TensorDim> dim = context.getInputDimensions();
context.setOutputDimensions(dim);
auto &rmsparams_gamma =
std::get<props::RMS_NORM_GAMMA_INIT_GPU>(rmsnorm_props);
auto &rmsparams_gamma = std::get<props::RMS_NORM_GAMMA_INIT>(rmsnorm_props);

TensorDim gamma_dim(
1, 1, 1, dim[0].width(),
@@ -123,9 +122,6 @@ void RMSNormLayerCl::forwarding(RunLayerContext &context, bool training) {
}
}

opencl::Kernel RMSNormLayerCl::kernel_rmsnorm;
opencl::Kernel RMSNormLayerCl::kernel_rmsnorm_fp16;

void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
Tensor const &gamma, const float epsilon) {
bool ret = false;
@@ -138,11 +134,8 @@ void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
int w = input.width();

do {
ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
cl_context_ref.registerClKernel(rmsnorm_cl_kernel_, "rmsnorm_cl");
if (!kernel_rmsnorm_ptr) {
break;
}

auto kernel_rmsnorm_ptr = layer_kernel_ptrs[Kernels::RMSNORM_CL];

opencl::Buffer inputbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
true, nullptr);
@@ -219,6 +212,7 @@ void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
} while (false);
}

#ifdef ENABLE_FP16
void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
Tensor const &gamma,
const float epsilon) {
@@ -232,12 +226,8 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
int h = input.height();
int w = input.width();
do {
ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
cl_context_ref.registerClKernel(rmsnorm_cl_kernel_fp16_,
"rmsnorm_cl_fp16");
if (!kernel_rmsnorm_ptr) {
break;
}
auto kernel_rmsnorm_ptr = layer_kernel_ptrs[Kernels::RMSNORM_CL_FP16];

opencl::Buffer inputbuf(cl_context_ref.context_inst_,
dim1 * sizeof(cl_half), true, nullptr);

@@ -308,6 +298,7 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
}
} while (false);
}
#endif

void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
unsigned int from, unsigned int to,
@@ -339,7 +330,11 @@ void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) {
rmsnormProcess(in, out, gamma, epsilon);
} else {
#ifdef ENABLE_FP16
rmsnormProcess_fp16(in, out, gamma, epsilon);
#else
throw std::runtime_error("enable-fp16 is not enabled");
#endif
}
}

@@ -362,4 +357,44 @@ void RMSNormLayerCl::setProperty(const std::vector<std::string> &values) {
LayerImpl::setProperty(remain_props);
}

bool RMSNormLayerCl::registerClKernels() {

// check if already registered
if (!layer_kernel_ptrs.empty()) {
ml_loge("kernels for concat layer are already registered.");
return false;
}

do {

ClContext::SharedPtrClKernel kernel_rmsnorm_ptr = nullptr;

kernel_rmsnorm_ptr =
cl_context_ref.registerClKernel(rmsnorm_cl_kernel_, "rmsnorm_cl");
if (!kernel_rmsnorm_ptr) {
ml_loge("OpenCL Error: Fail to register rmsnorm_cl kernel");
break;
}
layer_kernel_ptrs.emplace_back(kernel_rmsnorm_ptr);

#ifdef ENABLE_FP16
kernel_rmsnorm_ptr = cl_context_ref.registerClKernel(
rmsnorm_cl_kernel_fp16_, "rmsnorm_cl_fp16");
if (!kernel_rmsnorm_ptr) {
ml_loge("OpenCL Error: Fail to register rmsnorm_cl_fp16 kernel");
break;
}
layer_kernel_ptrs.emplace_back(kernel_rmsnorm_ptr);
#endif

return true;

} while (false);

// clear all registered kernels if any error occurs during registration
layer_kernel_ptrs.clear();

return false;
}

} // namespace nntrainer
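Note that the emplace_back() order in registerClKernels() must match the Kernels enum declared in the header, because rmsnormProcess() and rmsnormProcess_fp16() index layer_kernel_ptrs directly. A condensed view of that relationship (assuming ENABLE_FP16 is defined, so the FP16 kernel is actually pushed):

// layer_kernel_ptrs[0] <- "rmsnorm_cl"       == Kernels::RMSNORM_CL
// layer_kernel_ptrs[1] <- "rmsnorm_cl_fp16"  == Kernels::RMSNORM_CL_FP16 (ENABLE_FP16 builds only)
auto kernel_fp32 = layer_kernel_ptrs[Kernels::RMSNORM_CL]; // lookup used in rmsnormProcess()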
48 changes: 17 additions & 31 deletions nntrainer/layers/cl_layers/rmsnorm_layer_cl.h
@@ -16,7 +16,7 @@
#ifdef __cplusplus

#include <common_properties.h>
#include <layer_impl.h>
#include <layer_impl_cl.h>
#include <nntrainer_log.h>

#include <cl_context.h>
@@ -25,36 +25,11 @@

namespace nntrainer {

namespace props {

/**
* @brief RMS_NORM_GAMMA_INIT_GPU Initialization Enumeration Information
*
*/
class RMS_NORM_GAMMA_INIT_GPU final
: public ::nntrainer::EnumProperty<::nntrainer::props::InitializerInfo> {
public:
/**
* @brief Construct a RMS_NORM_GAMMA_INIT object
*/
RMS_NORM_GAMMA_INIT_GPU(
::nntrainer::Initializer value = ::nntrainer::Initializer::ONES) {
set(value);
};
using prop_tag = enum_class_prop_tag;
static constexpr const char *key = "gamma_initializer";
};
}; // namespace props

/**
* @class RMSNormLayer
* @brief RMS Norm layer
*/

class RMSNormLayerCl : public LayerImpl {

private:
inline static ClContext cl_context_ref;
class RMSNormLayerCl : public LayerImplCl {

public:
/**
@@ -118,9 +93,6 @@ class RMSNormLayerCl : public LayerImpl {
*/
const std::string getType() const override { return RMSNormLayerCl::type; };

static opencl::Kernel kernel_rmsnorm;
static opencl::Kernel kernel_rmsnorm_fp16;

/**
* @brief Process data and dimensions for rms norm operation
* @param[in] input Tensor
@@ -153,12 +125,26 @@
*/
void setProperty(const std::vector<std::string> &values) override;

/**
* @brief registerClKernels
*/
static bool registerClKernels();

inline static const std::string type = "rmsnorm";

private:
std::array<unsigned int, 1> wt_idx;
std::tuple<props::RMS_NORM_GAMMA_INIT_GPU, props::Epsilon>

std::tuple<props::RMS_NORM_GAMMA_INIT, props::Epsilon>
rmsnorm_props; /**< rmsnorm layer properties */

inline static std::vector<ClContext::SharedPtrClKernel>
layer_kernel_ptrs; /**< kernel list relevant with this layer */

enum Kernels {
RMSNORM_CL,
RMSNORM_CL_FP16,
};
};
} // namespace nntrainer

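layer_impl_cl.h itself is not part of this diff. Judging from the member removed above, LayerImplCl is assumed to supply the shared cl_context_ref that RMSNormLayerCl previously declared for itself, on top of the LayerImpl behavior it extends; a rough sketch of that assumption (the real header may differ):

// Assumed shape of the new base class; not taken from this commit.
class LayerImplCl : public LayerImpl {
protected:
  inline static ClContext cl_context_ref; /**< shared by all CL layers */
};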
13 changes: 13 additions & 0 deletions nntrainer/layers/common_properties.h
@@ -1075,6 +1075,19 @@ class BNPARAMS_BETA_INIT final : public EnumProperty<InitializerInfo> {
static constexpr const char *key = "beta_initializer";
};

/**
* @brief RMS_NORM_GAMMA_INIT Initialization Enumeration Information
*/
class RMS_NORM_GAMMA_INIT final : public EnumProperty<InitializerInfo> {
public:
/**
* @brief Construct a RMS_NORM_GAMMA_INIT object
*/
RMS_NORM_GAMMA_INIT(Initializer value = Initializer::ONES) { set(value); };
using prop_tag = enum_class_prop_tag;
static constexpr const char *key = "gamma_initializer";
};

/**
* @brief Enumeration of tensor regularization type
*/
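The new RMS_NORM_GAMMA_INIT property keeps the same key ("gamma_initializer") as the GPU-specific class it replaces, so user-facing configuration is unchanged. A hypothetical usage sketch, assuming the usual ccapi creation call and an illustrative epsilon value:

// Illustrative only: configuring the CL rmsnorm layer through its properties.
auto rmsnorm = ml::train::createLayer(
  "rmsnorm", {"gamma_initializer=ones", "epsilon=0.001"});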
1 change: 1 addition & 0 deletions test/jni/Android.mk
@@ -444,6 +444,7 @@ LOCAL_SRC_FILES := \
../unittest/layers/unittest_layers_impl.cpp \
../unittest/layers/unittest_layers_transpose_cl.cpp \
../unittest/layers/unittest_layers_concat_cl.cpp \
../unittest/layers/unittest_layers_swiglu_cl.cpp \
../unittest/layers/unittest_layers_fully_connected_cl.cpp \
../unittest/layers/unittest_layers_input.cpp \
../unittest/layers/unittest_layers_loss.cpp \
