[GPU/OpenCL] Fused DotCL, Addition and RMS for optimization #2831

Draft
wants to merge 1 commit into base: main
4 changes: 3 additions & 1 deletion api/ccapi/include/layer.h
@@ -108,7 +108,9 @@ enum LayerType {
LAYER_UPSAMPLE2D, /**< Upsample 2D Layer type */
LAYER_RMSNORM = ML_TRAIN_LAYER_TYPE_RMSNORM, /**<RMS NORM Layer */
LAYER_TRANSPOSE = ML_TRAIN_LAYER_TYPE_TRANSPOSE, /**< Transpose Layer type */
LAYER_UNKNOWN = ML_TRAIN_LAYER_TYPE_UNKNOWN /**< Unknown */
LAYER_FUSED_FC_RMS =
ML_TRAIN_LAYER_TYPE_FUSED_FC_RMS, /**< Fused FC and RMS layer*/
LAYER_UNKNOWN = ML_TRAIN_LAYER_TYPE_UNKNOWN /**< Unknown */
};

/**
17 changes: 9 additions & 8 deletions api/nntrainer-api-common.h
@@ -62,14 +62,15 @@ typedef enum {
27, /**< Layer Normalization Layer type (Since 7.0) */
ML_TRAIN_LAYER_TYPE_POSITIONAL_ENCODING =
28, /**< Positional Encoding Layer type (Since 7.0) */
ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */
ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */
ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_ADD = 32, /**< Add Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_SUBTRACT = 33, /**< Subtract Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_MULTIPLY = 34, /**< Multiply Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_DIVIDE = 35, /**< Divide Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_TRANSPOSE = 36, /**< Transpose Layer type */
ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */
ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */
ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_ADD = 32, /**< Add Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_SUBTRACT = 33, /**< Subtract Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_MULTIPLY = 34, /**< Multiply Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_DIVIDE = 35, /**< Divide Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_TRANSPOSE = 36, /**< Transpose Layer type */
ML_TRAIN_LAYER_TYPE_FUSED_FC_RMS = 37, /**< Fused fc and RMS layer type */
ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP =
300, /**< Preprocess flip Layer (Since 6.5) */
ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE =
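The two enum values added above (ML_TRAIN_LAYER_TYPE_FUSED_FC_RMS in the C API header and LAYER_FUSED_FC_RMS in the ccapi) are what an application would use to request the fused layer. Below is a minimal sketch of creating it through the usual ml::train::createLayer factory; the property strings "unit" and "epsilon" are assumptions inferred from the props::Unit and props::Epsilon properties used by fused_fc_norm_cl.cpp further down in this diff, not confirmed option names.

#include <layer.h>  // api/ccapi/include/layer.h
#include <memory>
#include <string>
#include <vector>

// Illustrative only: request the fused FC + RMS norm layer from the ccapi
// factory. "unit" and "epsilon" are assumed property keys; the numeric
// values are placeholders.
void create_fused_layer_example() {
  std::unique_ptr<ml::train::Layer> fused =
    ml::train::createLayer(ml::train::LayerType::LAYER_FUSED_FC_RMS,
                           {"unit=1024", "epsilon=0.001"});
  (void)fused;
}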
8 changes: 8 additions & 0 deletions nntrainer/cl_context.cpp
@@ -20,6 +20,7 @@
#include <cl_context.h>
#include <concat_cl.h>
#include <fc_layer_cl.h>
#include <fused_fc_norm_cl.h>
#include <reshape_cl.h>
#include <rmsnorm_layer_cl.h>
#include <swiglu_cl.h>
@@ -56,6 +57,10 @@ static void add_default_object(ClContext &cc) {
cc.registerFactory(nntrainer::createLayer<TransposeLayerCl>,
TransposeLayerCl::type,
ml::train::LayerType::LAYER_TRANSPOSE);

cc.registerFactory(nntrainer::createLayer<FullyConnectedRMSNormLayerCl>,
FullyConnectedRMSNormLayerCl::type,
ml::train::LayerType::LAYER_FUSED_FC_RMS);
}

static void registerer(ClContext &cc) noexcept {
@@ -143,6 +148,7 @@ void ClContext::initBlasClKernels() {
registerClKernel(sgemm_cl_transAB_kernel_, "sgemm_cl_transAB");
registerClKernel(addition_cl_kernel_, "addition_cl");
registerClKernel(sscal_cl_kernel_, "sscal_cl");
registerClKernel(rmsnorm_cl_kernel_new, "rmsnorm_cl");

#ifdef ENABLE_FP16
registerClKernel(sgemv_cl_kernel_fp16_, "sgemv_cl_fp16");
@@ -154,6 +160,8 @@ void ClContext::initBlasClKernels() {
registerClKernel(sgemm_cl_transAB_kernel_fp16_, "sgemm_cl_transAB_fp16");
registerClKernel(addition_cl_kernel_fp16_, "addition_cl_fp16");
registerClKernel(sscal_cl_kernel_fp16_, "sscal_cl_fp16");
registerClKernel(rmsnorm_cl_kernel_fp16_new, "rmsnorm_cl_fp16");

#endif
blas_kernels_initialized = true;
}
243 changes: 243 additions & 0 deletions nntrainer/layers/cl_layers/fused_fc_norm_cl.cpp
@@ -0,0 +1,243 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 Debadri Samaddar <[email protected]>
*
* @file fused_fc_norm_cl.cpp
* @date 7 May 2024
 * @brief This is the fused Fully Connected + RMS Normalization Layer class
 * for Neural Network with OpenCL implementation
* @see https://github.com/nnstreamer/nntrainer
* @author Debadri Samaddar <[email protected]>
* @bug No known bugs except for NYI items
*
*/

#include <blas_kernel_interface.h>
#include <common_properties.h>
#include <fused_fc_norm_cl.h>
#include <layer_context.h>
#include <lazy_tensor.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>
#include <util_func.h>

namespace nntrainer {

static constexpr size_t SINGLE_INOUT_IDX = 0;

enum FC_RMSParams { weight, bias, gamma };

FullyConnectedRMSNormLayerCl::FullyConnectedRMSNormLayerCl() :
LayerImpl(),
fc_rms_props(props::Unit(), props::FUSED_FC_RMS_NORM_GAMMA_INIT_GPU(),
props::Epsilon()) {
weight_idx.fill(std::numeric_limits<unsigned>::max());
}

void FullyConnectedRMSNormLayerCl::finalize(InitLayerContext &context) {
auto &weight_regularizer =
std::get<props::WeightRegularizer>(*layer_impl_props);
auto &weight_regularizer_constant =
std::get<props::WeightRegularizerConstant>(*layer_impl_props);
auto &weight_initializer =
std::get<props::WeightInitializer>(*layer_impl_props);
auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
auto &bias_initializer = std::get<props::BiasInitializer>(*layer_impl_props);
auto &disable_bias = std::get<props::DisableBias>(*layer_impl_props);

auto unit = std::get<props::Unit>(fc_rms_props).get();

NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
<< "Fully connected layer takes only one input";

std::vector<TensorDim> output_dims(1);

/// @todo fc actually supports multidimensions. EffDimFlag shouldn't be fixed
/// like this.
context.setEffDimFlagInputDimension(0, 0b1001);
context.setDynDimFlagInputDimension(0, 0b1000);

bool is_nchw = (context.getFormat() == Tformat::NCHW);
/** set output dimensions */
auto const &in_dim = context.getInputDimensions()[0];
output_dims[0] = in_dim;
is_nchw ? output_dims[0].width(unit) : output_dims[0].channel(unit);

output_dims[0].setTensorType(
{context.getFormat(), context.getActivationDataType()});

context.setOutputDimensions(output_dims);

/** set weight specifications */
// @todo : This NCHW format setting is just temporary; it needs to be set by
// global configuration
TensorDim bias_dim(
1, is_nchw ? 1 : unit, 1, is_nchw ? unit : 1,
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
is_nchw ? 0b0001 : 0b0100);

TensorDim weight_dim(
1, is_nchw ? 1 : unit, is_nchw ? in_dim.width() : 1,
is_nchw ? unit : in_dim.channel(),
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
is_nchw ? 0b0011 : 0b0101);

weight_idx[FC_RMSParams::weight] = context.requestWeight(
weight_dim, weight_initializer, weight_regularizer,
weight_regularizer_constant, weight_decay, "weight", true);

if (disable_bias.empty() || disable_bias.get() == false) {
weight_idx[FC_RMSParams::bias] =
context.requestWeight(bias_dim, bias_initializer, WeightRegularizer::NONE,
1.0f, bias_decay, "bias", true);
}

// for the RMS norm part: the output size has already been set above for the fc part
auto &rmsparams_gamma =
std::get<props::FUSED_FC_RMS_NORM_GAMMA_INIT_GPU>(fc_rms_props);

TensorDim gamma_dim(
1, 1, 1, output_dims[0].width(),
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()));
weight_idx[FC_RMSParams::gamma] =
context.requestWeight(gamma_dim, rmsparams_gamma, WeightRegularizer::NONE,
1.0f, 0.0f, "gamma", false);
}

// TO-DO
/////////////////////////////////////////////////////////////////////////
// fc
void FullyConnectedRMSNormLayerCl::exportTo(
Exporter &exporter, const ml::train::ExportMethods &method) const {
LayerImpl::exportTo(exporter, method);
exporter.saveResult(fc_rms_props, method, this);
}

void FullyConnectedRMSNormLayerCl::setProperty(
const std::vector<std::string> &values) {
auto remain_props = loadProperties(values, fc_rms_props);
LayerImpl::setProperty(remain_props);
}

void FullyConnectedRMSNormLayerCl::forwarding(RunLayerContext &context,
bool training) {

// for fc layer
Tensor &weight = context.getWeight(weight_idx[FC_RMSParams::weight]);
Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);

// for rms
Tensor &gamma = context.getWeight(weight_idx[FC_RMSParams::gamma]);
auto &epsilon = std::get<props::Epsilon>(fc_rms_props).get();

auto disable_bias = std::get<props::DisableBias>(*layer_impl_props);
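  // note: the flag below is true when the bias is enabled (DisableBias unset or false)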
bool disable_bias_value = disable_bias.empty() || disable_bias.get() == false;
const Tensor &bias = context.getWeight(weight_idx[FC_RMSParams::bias]);
// printf("\n*************************************************************************************************************************************\n");
Reviewer comment (Contributor): do we need these comments?
// printf("Bias value : %s\n", disable_bias_value ? "true" : "false");
// printf("\nInput Tensor Batch: %u, Channel: %u, Height: %u, Width: %u\n",
// input_.batch(), input_.channel(), input_.height(), input_.width()); for
// (unsigned int i = 0; i < input_.size(); ++i) {
// printf("Element %u -> %f\n", i, *(input_.getData<float>() + i));
// }
// printf("\nWeight Tensor Batch: %u, Channel: %u, Height: %u, Width: %u\n",
// weight.batch(), weight.channel(), weight.height(), weight.width());
// printf("\nHidden Tensor Batch: %u, Channel: %u, Height: %u, Width: %u\n",
// hidden_.batch(), hidden_.channel(), hidden_.height(), hidden_.width());
// printf("\nGamma Tensor Batch: %u, Channel: %u, Height: %u, Width: %u\n",
// gamma.batch(), gamma.channel(), gamma.height(), gamma.width());
// printf("\nEpsilon value : %f\n", epsilon);
// printf("\n-----------------------------------------starting with fusion
// process from layer side-----------------------------------------------\n");

fusedProcess(input_, weight, hidden_, bias, disable_bias_value, gamma,
epsilon);
}
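For reference, the fused call above is expected to perform the fully connected projection, the optional bias addition, and RMS normalization in a single GPU pass, i.e. hidden = RMSNorm(input * W + bias) scaled element-wise by gamma. The CPU sketch below is illustrative only: it uses the conventional RMSNorm definition and assumes plain row-major float buffers, and it is not the OpenCL kernel invoked by fusedProcess.

#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative CPU reference for what the fused path is expected to compute:
// out = RMSNorm(input * W + bias) * gamma. "out" must be pre-sized to
// rows * unit; layout and the use_bias flag are assumptions for this sketch.
void fused_fc_rms_reference(const std::vector<float> &input,  // rows x in_dim
                            const std::vector<float> &weight, // in_dim x unit
                            const std::vector<float> &bias,   // unit
                            const std::vector<float> &gamma,  // unit
                            std::vector<float> &out,          // rows x unit
                            bool use_bias, float epsilon,
                            std::size_t rows, std::size_t in_dim,
                            std::size_t unit) {
  for (std::size_t r = 0; r < rows; ++r) {
    // fully connected projection (+ optional bias)
    for (std::size_t u = 0; u < unit; ++u) {
      float acc = use_bias ? bias[u] : 0.0f;
      for (std::size_t k = 0; k < in_dim; ++k)
        acc += input[r * in_dim + k] * weight[k * unit + u];
      out[r * unit + u] = acc;
    }
    // RMS normalization over the unit dimension, scaled by gamma
    float sum_sq = 0.0f;
    for (std::size_t u = 0; u < unit; ++u)
      sum_sq += out[r * unit + u] * out[r * unit + u];
    const float inv_rms = 1.0f / std::sqrt(sum_sq / unit + epsilon);
    for (std::size_t u = 0; u < unit; ++u)
      out[r * unit + u] *= inv_rms * gamma[u];
  }
}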

/// @todo generalize incremental forwarding; the implementation below only
/// supports a step size of 1 (and, per the note further down, batch size 1)
void FullyConnectedRMSNormLayerCl::incremental_forwarding(
RunLayerContext &context, unsigned int from, unsigned int to, bool training) {
Tensor w;
Tensor &weight = w;
context.getWeight(weight, weight_idx[FC_RMSParams::weight]);

Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);

// rms
Tensor &gamma = context.getWeight(weight_idx[FC_RMSParams::gamma]);

TensorDim input_dim = input_.getDim();
TensorDim hidden_dim = hidden_.getDim();

TensorDim input_step_dim = input_dim;
TensorDim hidden_step_dim = hidden_dim;

if (from) {
NNTR_THROW_IF(to - from != 1, std::invalid_argument)
<< "incremental step size is not 1";
from = 0;
to = 1;
}

input_step_dim.height(to - from);
hidden_step_dim.height(to - from);

// @todo: set reset stride as false. This implementation only works when
// batch size is 1
Tensor input_step = input_.getSharedDataTensor(input_step_dim, 0, true);
Tensor hidden_step = hidden_.getSharedDataTensor(hidden_step_dim, 0, true);

auto &epsilon = std::get<props::Epsilon>(fc_rms_props).get();

auto disable_bias = std::get<props::DisableBias>(*layer_impl_props);
bool disable_bias_value = disable_bias.empty() || disable_bias.get() == false;
Tensor &bias = context.getWeight(weight_idx[FC_RMSParams::bias]);

fusedProcess(input_step, weight, hidden_step, bias, disable_bias_value, gamma,
epsilon);
}
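The (from, to) handling above means that, after the first call, only single-position steps are accepted, and they are remapped onto offset 0 of the shared step tensors, which is why the todo above limits this path to batch size 1. A tiny standalone sketch of that remapping (illustrative only, not part of the layer's API):

#include <stdexcept>
#include <utility>

// Mirrors the step-window normalization in incremental_forwarding():
// any non-zero "from" must describe exactly one position and is mapped
// back to the window [0, 1) of the shared input/output step tensors.
static std::pair<unsigned int, unsigned int>
normalize_step_window(unsigned int from, unsigned int to) {
  if (from != 0) {
    if (to - from != 1)
      throw std::invalid_argument("incremental step size is not 1");
    return {0U, 1U};
  }
  return {from, to};
}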

void FullyConnectedRMSNormLayerCl::calcDerivative(RunLayerContext &context) {
Tensor &weight = context.getWeight(weight_idx[FC_RMSParams::weight]);

const Tensor &derivative_ = context.getIncomingDerivative(SINGLE_INOUT_IDX);
Tensor &ret_ = context.getOutgoingDerivative(SINGLE_INOUT_IDX);

ret_.dot_deriv_wrt_1(weight, derivative_, false, false);
}

void FullyConnectedRMSNormLayerCl::calcGradient(RunLayerContext &context) {
Tensor &djdw = context.getWeightGrad(weight_idx[FC_RMSParams::weight]);

const Tensor &derivative_ = context.getIncomingDerivative(SINGLE_INOUT_IDX);
Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);

if (auto &disable_bias = std::get<props::DisableBias>(*layer_impl_props);
disable_bias.empty() || disable_bias.get() == false) {
Tensor &djdb = context.getWeightGrad(weight_idx[FC_RMSParams::bias]);

if (context.isGradientFirstAccess(weight_idx[FC_RMSParams::bias])) {
derivative_.sum({0, 1, 2}, djdb);
} else {
/// @todo optimize below by adding beta to Tensor::sum
Tensor t = derivative_.sum({0, 1, 2});
djdb.add_i(t);
}
}

input_.dot_deriv_wrt_2(
djdw, derivative_, false, false,
!context.isGradientFirstAccess(weight_idx[FC_RMSParams::weight]));
}

} /* namespace nntrainer */