From 781e16e39e72044cb064ef392b65c5b25d6ccc5c Mon Sep 17 00:00:00 2001 From: Yash Singh Date: Wed, 9 Oct 2024 15:14:24 +0530 Subject: [PATCH] [GPU/OpenCL/Update] Initial version of LM Head layer with OpencCl ops and Update Addition Layer on GPU with latest Pipeline changes Added initial version of LM head layer fpr GPU and removed dependencies of cl_context for addition_layer. Signed-off-by: Yash Singh --- api/ccapi/include/layer.h | 13 +- api/nntrainer-api-common.h | 1 + nntrainer/cl_context.cpp | 12 +- .../layers/cl_layers/addition_layer_cl.cpp | 4 +- .../layers/cl_layers/addition_layer_cl.h | 3 +- .../cl_layers/custom_vocab_selection.cpp | 131 +++++++++ .../layers/cl_layers/custom_vocab_selection.h | 62 +++++ .../layers/cl_layers/lm_head_layer_cl.cpp | 260 ++++++++++++++++++ nntrainer/layers/cl_layers/lm_head_layer_cl.h | 158 +++++++++++ nntrainer/layers/cl_layers/meson.build | 4 +- nntrainer/utils/custom_properties.h | 43 +++ test/input_gen/gen_layer_tests.py | 5 + test/jni/Android.mk | 1 + .../layers/unittest_layers_lm_head_cl.cpp | 50 ++++ 14 files changed, 739 insertions(+), 8 deletions(-) create mode 100644 nntrainer/layers/cl_layers/custom_vocab_selection.cpp create mode 100644 nntrainer/layers/cl_layers/custom_vocab_selection.h create mode 100644 nntrainer/layers/cl_layers/lm_head_layer_cl.cpp create mode 100644 nntrainer/layers/cl_layers/lm_head_layer_cl.h create mode 100644 nntrainer/utils/custom_properties.h create mode 100644 test/unittest/layers/unittest_layers_lm_head_cl.cpp diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index d9f9cffdd2..8fcf82b5fe 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -8,6 +8,7 @@ * @author Parichay Kapoor * @author Debadri Samaddar * @author Niket Agarwal + * @author Yash Singh * @bug No known bugs except for NYI items * @brief This is layers interface for c++ API * @@ -103,7 +104,8 @@ enum LayerType { derivative */ LAYER_UPSAMPLE2D, /**< Upsample 2D Layer type */ LAYER_RMSNORM = ML_TRAIN_LAYER_TYPE_RMSNORM, /** &properties = {}, return createLayer(LayerType::LAYER_ADDITION, properties, compute_engine); } +/** + * @brief Helper function to create lm_head layer + */ +inline std::unique_ptr +LmHead(const std::vector &properties = {}, + const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { + return createLayer(LayerType::LAYER_LM_HEAD, properties, compute_engine); +} + /** * @brief Helper function to create concat layer */ diff --git a/api/nntrainer-api-common.h b/api/nntrainer-api-common.h index 97a5a71fad..bd1d0b5f21 100644 --- a/api/nntrainer-api-common.h +++ b/api/nntrainer-api-common.h @@ -65,6 +65,7 @@ typedef enum { ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */ ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */ ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_LM_HEAD = 32, /**< LM Head Layer type */ ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP = 300, /**< Preprocess flip Layer (Since 6.5) */ ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE = diff --git a/nntrainer/cl_context.cpp b/nntrainer/cl_context.cpp index 821a32d6fa..f7d9a7a3f1 100644 --- a/nntrainer/cl_context.cpp +++ b/nntrainer/cl_context.cpp @@ -7,6 +7,7 @@ * @see https://github.com/nnstreamer/nntrainer * @author Debadri Samaddar * @author Niket Agarwal + * @author Yash Singh * @author Thummala Pallavi * @bug No known bugs except for NYI items * @brief This file contains app context related functions and classes that @@ -19,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -35,9 +37,9 @@ static void add_default_object(ClContext &cc) { // FullyConnectedLayerCl::type, // ml::train::LayerType::LAYER_FC); - // cc.registerFactory(nntrainer::createLayer, - // AdditionLayerCL::type, - // ml::train::LayerType::LAYER_ADDITION); + cc.registerFactory(nntrainer::createLayer, + AdditionLayerCL::type, + ml::train::LayerType::LAYER_ADDITION); // cc.registerFactory(nntrainer::createLayer, // SwiGLULayerCl::type, @@ -54,6 +56,10 @@ static void add_default_object(ClContext &cc) { // cc.registerFactory(nntrainer::createLayer, // ConcatLayerCl::type, // ml::train::LayerType::LAYER_CONCAT); + + cc.registerFactory(nntrainer::createLayer, + CustomLMHeadLayerCl::type, + ml::train::LayerType::LAYER_LM_HEAD); } static void registerer(ClContext &cc) noexcept { diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.cpp b/nntrainer/layers/cl_layers/addition_layer_cl.cpp index dda2101645..3fd57e20ee 100644 --- a/nntrainer/layers/cl_layers/addition_layer_cl.cpp +++ b/nntrainer/layers/cl_layers/addition_layer_cl.cpp @@ -37,7 +37,7 @@ void AdditionLayerCL::forwarding(RunLayerContext &context, bool training) { if (!idx) { hidden_.copy(input_); } else { - add_i_cl(input_, hidden_, context); + add_i_cl(input_, hidden_); } } } @@ -77,7 +77,7 @@ void AdditionLayerCL::incremental_forwarding(RunLayerContext &context, if (!idx) { hidden_step.copy(input_step); } else { - add_i_cl(input_step, hidden_step, context); + add_i_cl(input_step, hidden_step); } } } diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.h b/nntrainer/layers/cl_layers/addition_layer_cl.h index e24354e8d0..f851e35138 100644 --- a/nntrainer/layers/cl_layers/addition_layer_cl.h +++ b/nntrainer/layers/cl_layers/addition_layer_cl.h @@ -15,6 +15,7 @@ #define __ADDITION_LAYER_CL_H__ #ifdef __cplusplus +#include #include #include @@ -40,7 +41,7 @@ class AdditionLayerCL : public Layer { /** * @brief Destructor of Addition Layer */ - ~AdditionLayerCL(){}; + ~AdditionLayerCL() {}; /** * @brief Move constructor of AdditionLayer. diff --git a/nntrainer/layers/cl_layers/custom_vocab_selection.cpp b/nntrainer/layers/cl_layers/custom_vocab_selection.cpp new file mode 100644 index 0000000000..ec3016a927 --- /dev/null +++ b/nntrainer/layers/cl_layers/custom_vocab_selection.cpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh > + * + * @file custom_vocab_selection.cpp + * @date 1 Oct 2024 + * @brief Implementation of custom vocab selection + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#include "custom_vocab_selection.h" +#include + +nntrainer::VocabSelection::VocabSelection(LshType lshType, int lshChoices, + int hiddenSize, int vocabCnt) : + lshType(lshType), + lshChoices(lshChoices), + vocabCnt(vocabCnt), + hiddenSize(hiddenSize), + lshBlockNum(0), + lshBits(0) {} + +nntrainer::VocabSelection::~VocabSelection() {} + +nntrainer::VocabSelectionNNTrainer::VocabSelectionNNTrainer( + LshType lshType, int lshChoices, int hiddenSize, int vocabCnt, + nntrainer::Tensor &weights) : + VocabSelection(lshType, lshChoices, hiddenSize, vocabCnt) { + this->lshBlockNum = (hiddenSize + lshBlockSize - 1) / lshBlockSize; + this->lshBits = lshBlockNum * lshBlockSize; + this->lshData = std::vector(this->vocabCnt * lshBlockNum); + + // for (unsigned int i = 0; i < vocabCnt; ++i) { + // for (unsigned int j = 0; j < lshBlockNum; ++j) { + // unsigned int actualSize = std::min(lshBlockSize, hiddenSize - + // (int)j * lshBlockSize); lshDataBlock d; for (unsigned int k = 0; k + // < actualSize; ++k) { + // d[k] = weights.getValue<_FP16>(0, 0, i, j * lshBlockSize + k) > + // 0 ? 1 : 0; + // } + // for (unsigned int k = actualSize; k < lshBlockSize; ++k) { + // d[k] = 0; + // } + // this->lshData[i * lshBlockNum + j] = d; + // } + // } + + for (unsigned int i = 0; i < lshBlockNum; ++i) { + unsigned int actualSize = + std::min(lshBlockSize, hiddenSize - (int)i * lshBlockSize); + for (unsigned int j = 0; j < vocabCnt; ++j) { + lshDataBlock d; + for (unsigned int k = 0; k < actualSize; ++k) { + if (weights.getDataType() == nntrainer::TensorDim::DataType::FP32) { + d[k] = weights.getValue(0, 0, i * lshBlockSize + k, j) > 0 ? 1 : 0; + } else if (weights.getDataType() == + nntrainer::TensorDim::DataType::FP16) { + d[k] = + weights.getValue<_FP16>(0, 0, i * lshBlockSize + k, j) > 0 ? 1 : 0; + } + } + for (unsigned int k = actualSize; k < lshBlockSize; ++k) { + d[k] = 0; + } + this->lshData[j * lshBlockNum + i] = d; + } + } +} + +std::vector> +nntrainer::VocabSelectionNNTrainer::getVocabs(const nntrainer::Tensor &input) { + unsigned int batchSize = input.height(); + + std::vector> res = std::vector>(batchSize); + for (int i = 0; i < batchSize; i++) { + std::vector d(lshBlockNum); + for (int k = 0; k < lshBlockNum; k++) { + int actualSize = std::min(lshBlockSize, hiddenSize - k * lshBlockSize); + for (int j = 0; j < actualSize; j++) { + if (input.getDataType() == nntrainer::TensorDim::DataType::FP32) { + d[k][j] = input.getValue(0, 0, i, j + k * lshBlockSize) >= 0 ? 1 : 0; + } else if (input.getDataType() == + nntrainer::TensorDim::DataType::FP16) { + d[k][j] = + input.getValue<_FP16>(0, 0, i, j + k * lshBlockSize) >= 0 ? 1 : 0; + } + } + for (int j = actualSize; j < lshBlockSize; j++) { + d[k][j] = 0; + } + } + std::vector simResult(vocabCnt, 0); + std::vector simCount(lshBits + 1, 0); + for (int j = 0; j < vocabCnt; j++) { + for (int k = 0; k < lshBlockNum; k++) { + simResult[j] += (d[k] ^ lshData[j * lshBlockNum + k]).count(); + } + simCount[simResult[j]]++; + } + int cut = lshBits + 1; + int leftover = 0; + int countSum = 0; + for (int j = 0; j <= lshBits; j++) { + countSum += simCount[j]; + if (countSum > lshChoices) { + cut = j; + leftover = simCount[j] - (countSum - lshChoices); + break; + } + } + std::vector selectedVocabs(lshChoices); + int pos = 0; + for (int j = 0; j < vocabCnt; j++) { + if (simResult[j] <= cut) { + if (simResult[j] < cut) { + selectedVocabs[pos] = j; + pos++; + } else if (leftover > 0) { + selectedVocabs[pos] = j; + pos++; + leftover--; + } + } + } + res[i] = selectedVocabs; + } + return res; +} diff --git a/nntrainer/layers/cl_layers/custom_vocab_selection.h b/nntrainer/layers/cl_layers/custom_vocab_selection.h new file mode 100644 index 0000000000..933e4e2ad0 --- /dev/null +++ b/nntrainer/layers/cl_layers/custom_vocab_selection.h @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh > + * + * @file custom_vocab_selection.h + * @date 1 Oct 2024 + * @brief Implementation of custom vocab selection + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#ifndef VOCAB_SELECTION_H +#define VOCAB_SELECTION_H + +#include + +#ifndef LSH_BLOCK_SIZE +#define LSH_BLOCK_SIZE 256 +#endif + +using namespace std; + +namespace nntrainer { + +enum LshType { NONE = 0, SIMHASH = 1, ORTHOSIMHASH = 2 }; +typedef std::bitset lshDataBlock; + +class VocabSelection { +protected: + int hiddenSize; + int vocabCnt; + const int lshBlockSize = LSH_BLOCK_SIZE; + int lshBlockNum; + int lshBits; // lshBlockSize * lshBlockNum + int lshChoices; + LshType lshType; + std::vector lshData; + +public: + VocabSelection(LshType lshType, int lshChoices, int hiddenSize, int vocabCnt); + virtual std::vector> + getVocabs(const nntrainer::Tensor &modelOutput) = 0; + ~VocabSelection(); +}; + +class VocabSelectionNNTrainer : public VocabSelection { +protected: + nntrainer::Tensor lshWeight; + +public: + VocabSelectionNNTrainer(LshType lshType, int lshChoices, int hiddenSize, + int vocabCnt, nntrainer::Tensor &weights); + virtual std::vector> + getVocabs(const nntrainer::Tensor &modelOutput); + ~VocabSelectionNNTrainer() {}; +}; + +} // namespace nntrainer + +#endif diff --git a/nntrainer/layers/cl_layers/lm_head_layer_cl.cpp b/nntrainer/layers/cl_layers/lm_head_layer_cl.cpp new file mode 100644 index 0000000000..d02609f15d --- /dev/null +++ b/nntrainer/layers/cl_layers/lm_head_layer_cl.cpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh > + * + * @file lm_head_layer_cl.cpp + * @date 1 Oct 2024 + * @brief Implementation of custom lm head layer + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#include +#include +#include +#include + +#include + +namespace nntrainer { + +static constexpr size_t SINGLE_INOUT_IDX = 0; + +enum LMHeadParams { weight, bias, candidate_weight, candidate_hidden_step }; + +CustomLMHeadLayerCl::CustomLMHeadLayerCl() : + LayerImpl(), + custom_lm_head_props(nntrainer::props::Unit(), props::UseVocabSelection(), + props::LshChoices(), props::SmartReply()) { + weight_idx.fill(std::numeric_limits::max()); +} + +void CustomLMHeadLayerCl::finalize(nntrainer::InitLayerContext &context) { + auto &weight_regularizer = + std::get(*layer_impl_props); + auto &weight_regularizer_constant = + std::get(*layer_impl_props); + auto weight_initializer = nntrainer::props::InitializerInfo::Enum::ZEROS; + // auto &weight_initializer = + // std::get(*layer_impl_props); + auto &weight_decay = + std::get(*layer_impl_props); + auto &bias_decay = std::get(*layer_impl_props); + auto &bias_initializer = + std::get(*layer_impl_props); + auto &disable_bias = + std::get(*layer_impl_props); + + auto unit = std::get(custom_lm_head_props).get(); + + NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) + << "lm head layer takes only one input"; + + std::vector output_dims(1); + + /// @todo fc actaully supports multidimensions. EffDimFlag shouldn't be fixed + /// like this. + context.setEffDimFlagInputDimension(0, 0b1001); + context.setDynDimFlagInputDimension(0, 0b1000); + + bool is_nchw = (context.getFormat() == nntrainer::Tformat::NCHW); + /** set output dimensions */ + auto const &in_dim = context.getInputDimensions()[0]; + output_dims[0] = in_dim; + is_nchw ? output_dims[0].width(unit) : output_dims[0].channel(unit); + + output_dims[0].setTensorType( + {context.getFormat(), context.getActivationDataType()}); + + context.setOutputDimensions(output_dims); + + /** set weight specifications */ + // @todo : This NCHW format setting is just temporal, it needs to be set by + // global configuration + ml::train::TensorDim bias_dim( + 1, is_nchw ? 1 : unit, 1, is_nchw ? unit : 1, + ml::train::TensorDim::TensorType(context.getFormat(), + context.getWeightDataType()), + is_nchw ? 0b0001 : 0b0100); + + ml::train::TensorDim weight_dim( + 1, is_nchw ? 1 : unit, is_nchw ? in_dim.width() : 1, + is_nchw ? unit : in_dim.channel(), + ml::train::TensorDim::TensorType(context.getFormat(), + context.getWeightDataType()), + is_nchw ? 0b0011 : 0b0101); + + weight_idx[LMHeadParams::weight] = context.requestWeight( + weight_dim, weight_initializer, weight_regularizer, + weight_regularizer_constant, weight_decay, "weight", true); + + if (disable_bias.empty() || disable_bias.get() == false) { + weight_idx[LMHeadParams::bias] = context.requestWeight( + bias_dim, bias_initializer, nntrainer::WeightRegularizer::NONE, 1.0f, + bias_decay, "bias", true); + } + + auto use_vocab_selection = + std::get(custom_lm_head_props).get(); + + if (use_vocab_selection) { + auto lsh_choices = std::get(custom_lm_head_props).get(); + + ml::train::TensorDim candidate_weight_dim( + 1, is_nchw ? 1 : lsh_choices, is_nchw ? lsh_choices : in_dim.channel(), + is_nchw ? in_dim.width() : 1, + ml::train::TensorDim::TensorType(context.getFormat(), + context.getWeightDataType())); + + weight_idx[LMHeadParams::candidate_weight] = context.requestTensor( + candidate_weight_dim, "candidate_weight", Initializer::NONE, false, + nntrainer::TensorLifespan::ITERATION_LIFESPAN); + + ml::train::TensorDim candidate_hidden_step_dim( + 1, 1, 1, lsh_choices, + ml::train::TensorDim::TensorType(context.getFormat(), + context.getWeightDataType())); + + weight_idx[LMHeadParams::candidate_hidden_step] = context.requestTensor( + candidate_hidden_step_dim, "candidate_hidden_step", Initializer::NONE, + false, nntrainer::TensorLifespan::ITERATION_LIFESPAN); + } +} + +void CustomLMHeadLayerCl::forwarding(nntrainer::RunLayerContext &context, + bool training) { + // NYI +} + +void CustomLMHeadLayerCl::initVocabSelection( + LshType lshType, int lshChoices, nntrainer::RunLayerContext &context) { + nntrainer::Tensor w; + nntrainer::Tensor &weight = w; + context.getWeight(weight, weight_idx[LMHeadParams::weight]); + this->vocabSelection = + std::shared_ptr(new VocabSelectionNNTrainer( + lshType, lshChoices, weight.height(), weight.width(), weight)); + weight_T = std::make_unique(weight.transpose("0:2:1")); + + weight_T->reshape({weight_T->width(), weight_T->height()}); +} + +void CustomLMHeadLayerCl::incremental_forwarding( + nntrainer::RunLayerContext &context, unsigned int from, unsigned int to, + bool training) { + nntrainer::Tensor w; + nntrainer::Tensor &weight = w; + context.getWeight(weight, weight_idx[LMHeadParams::weight]); + + nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); + nntrainer::Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + + ml::train::TensorDim input_dim = input_.getDim(); + ml::train::TensorDim hidden_dim = hidden_.getDim(); + + ml::train::TensorDim input_step_dim = input_dim; + ml::train::TensorDim hidden_step_dim = hidden_dim; + + unsigned int _from = from; + + if (from) { + NNTR_THROW_IF(to - from != 1, std::invalid_argument) + << "incremental step size is not 1"; + from = 0; + to = 1; + } + + input_step_dim.batch(1); + input_step_dim.height(1); + hidden_step_dim.batch(1); + hidden_step_dim.height(1); + + bool smart_reply = std::get(custom_lm_head_props).get(); + + unsigned int b_size = input_dim.batch(); + unsigned omp_num = 4; + if (smart_reply && !_from) { + b_size = 1; + omp_num = 1; + } + + // #pragma omp parallel for num_threads(omp_num) + for (unsigned int b = 0; b < b_size; ++b) { + nntrainer::Tensor input_step = input_.getSharedDataTensor( + input_step_dim, + b * input_dim.getFeatureLen() + + (to - from == 1 ? 0 : (to - 1) * input_.width()), + true); + nntrainer::Tensor hidden_step = hidden_.getSharedDataTensor( + hidden_step_dim, + b * hidden_dim.getFeatureLen() + + (to - from == 1 ? 0 : (to - 1) * hidden_.width()), + true); + + auto unit = std::get(custom_lm_head_props).get(); + auto use_vocab_selection = + std::get(custom_lm_head_props).get(); + + if (use_vocab_selection) { + auto lsh_choices = + std::get(custom_lm_head_props).get(); + auto vocab = vocabSelection->getVocabs(input_step); + + hidden_step.setValue(0); + + ml::train::TensorDim weight_T_ith_choice_dim = weight_T->getDim(); + weight_T_ith_choice_dim.width(1); + ml::train::TensorDim hidden_step_ith_choice_dim = hidden_step_dim; + hidden_step_ith_choice_dim.width(1); + nntrainer::Tensor weight_T_ith_choice; + + for (unsigned int i = 0; i < lsh_choices; ++i) { + weight_T_ith_choice = weight_T->getSharedDataTensor( + weight_T_ith_choice_dim, vocab[0][i] * input_step.width(), true); + nntrainer::Tensor hidden_step_ith_choice = + hidden_step.getSharedDataTensor(hidden_step_ith_choice_dim, + vocab[0][i], true); + + dotCl(input_step, weight_T_ith_choice, hidden_step_ith_choice); + } + } else { + dotCl(input_step, weight, hidden_step); + } + + if (auto &disable_bias = + std::get(*layer_impl_props); + disable_bias.empty() || disable_bias.get() == false) { + nntrainer::Tensor &bias = + context.getWeight(weight_idx[LMHeadParams::bias]); + + add_i_cl(bias, hidden_step); + } + } +} + +void CustomLMHeadLayerCl::calcDerivative(nntrainer::RunLayerContext &context) {} + +void CustomLMHeadLayerCl::setProperty(const std::vector &values) { + auto remain_props = loadProperties(values, custom_lm_head_props); + LayerImpl::setProperty(remain_props); +} + +#ifdef PLUGGABLE + +nntrainer::Layer *create_custom_lm_head_layer() { + auto layer = new CustomLMHeadLayerCl(); + return layer; +} + +void destroy_custom_lm_head_layer(nntrainer::Layer *layer) { delete layer; } + +extern "C" { +nntrainer::LayerPluggable ml_train_layer_pluggable{ + create_custom_lm_head_layer, destroy_custom_lm_head_layer}; +} + +#endif + +} // namespace nntrainer diff --git a/nntrainer/layers/cl_layers/lm_head_layer_cl.h b/nntrainer/layers/cl_layers/lm_head_layer_cl.h new file mode 100644 index 0000000000..9144f9fc85 --- /dev/null +++ b/nntrainer/layers/cl_layers/lm_head_layer_cl.h @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file lm_head_layer_cl.h + * @date 1 Oct 2024 + * @brief Implementation of custom lm head layer for GPU + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#ifndef __CUSTOM_LM_HEAD_LAYER_H__ +#define __CUSTOM_LM_HEAD_LAYER_H__ + +#include +#include +#include +#include +#include +#include + +namespace nntrainer { + +namespace props { + +/** + * @brief indicated whether do vocab selection or not + * + */ +class UseVocabSelection : public nntrainer::Property { +public: + /** + * @brief Construct a new UseVocabSelection object + * + */ + UseVocabSelection(bool value = false) { set(value); } + static constexpr const char *key = "use_vocab_selection"; + using prop_tag = nntrainer::bool_prop_tag; +}; + +/** + * @brief LshChoices property, indicate how many words will be choose + * + */ +class LshChoices : public nntrainer::PositiveIntegerProperty { +public: + /** + * @brief Construct a new LshChoices object with a default value 1 + * + */ + LshChoices(unsigned int value = 1) { set(value); }; + static constexpr const char *key = "lsh_choices"; /**< unique key to access */ + using prop_tag = nntrainer::uint_prop_tag; /**< property type */ +}; + +} // namespace props + +/** + * @brief A Custom LM Head layer for llama. + * + */ +class CustomLMHeadLayerCl : public LayerImpl { +public: + /** + * @brief Construct a new Custom LM Head layer object + * + */ + CustomLMHeadLayerCl(); + + /** + * @brief Destroy the Custom LM Head layer object + * + */ + ~CustomLMHeadLayerCl() {} + + /** + * @copydoc Layer::finalize(InitLayerContext &context) + */ + void finalize(nntrainer::InitLayerContext &context) override; + + void initialize(nntrainer::RunLayerContext &context) override { + auto use_vocab_selection = + std::get(custom_lm_head_props).get(); + + if (use_vocab_selection) { + auto lsh_choices = + std::get(custom_lm_head_props).get(); + initVocabSelection(LshType::ORTHOSIMHASH, lsh_choices, context); + } + } + + /** + * @copydoc Layer::forwarding(RunLayerContext &context, bool training) + */ + void forwarding(nntrainer::RunLayerContext &context, bool training) override; + + /** + * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned + * int from, unsigned int to, bool training) + */ + void incremental_forwarding(nntrainer::RunLayerContext &context, + unsigned int from, unsigned int to, + bool training) override; + + /** + * @copydoc Layer::calcDerivative(RunLayerContext &context) + */ + void calcDerivative(nntrainer::RunLayerContext &context) override; + + /** + * @copydoc Layer::calcGradient(RunLayerContext &context) + */ + // void calcGradient(nntrainer::RunLayerContext &context) override; + + /** + * @copydoc bool supportBackwarding() const + */ + bool supportBackwarding() const override { return true; }; + + /** + * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method) + */ + void exportTo(nntrainer::Exporter &exporter, + const ml::train::ExportMethods &method) const override {}; + + /** + * @copydoc Layer::getType() + */ + const std::string getType() const override { + return CustomLMHeadLayerCl::type; + }; + + /** + * @copydoc Layer::setProperty(const std::vector &values) + */ + void setProperty(const std::vector &values) override; + + void initVocabSelection(LshType lshType, int lshChoices, + nntrainer::RunLayerContext &context); + + inline static const std::string type = "custom_lm_head"; + + std::shared_ptr vocabSelection; + +private: + std::tuple + custom_lm_head_props; + std::array weight_idx; /**< indices of the weights */ + std::unique_ptr + weight_T; // temporary weight. will be removed +}; + +} // namespace nntrainer + +#endif /* __LM_HEAD_LAYER_CL_H__ */ diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build index 906b8a85ac..3f7d9f291d 100644 --- a/nntrainer/layers/cl_layers/meson.build +++ b/nntrainer/layers/cl_layers/meson.build @@ -1,10 +1,12 @@ cl_layer_sources = [ # 'fc_layer_cl.cpp', - # 'addition_layer_cl.cpp', + 'addition_layer_cl.cpp', # 'swiglu_cl.cpp', # 'reshape_cl.cpp', # 'rmsnorm_layer_cl.cpp', # 'concat_cl.cpp', + 'lm_head_layer_cl.cpp', + 'custom_vocab_selection.cpp', ] foreach s : cl_layer_sources diff --git a/nntrainer/utils/custom_properties.h b/nntrainer/utils/custom_properties.h new file mode 100644 index 0000000000..f0e68d3987 --- /dev/null +++ b/nntrainer/utils/custom_properties.h @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Debadri Samaddar + * + * @file custom_properties.h + * @date 1 October 2024 + * @brief This file contains list of custom properties widely + * used across custom layers + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#ifndef __CUSTOM_PROPERTIES_H__ +#define __CUSTOM_PROPERTIES_H__ + +#include + +namespace nntrainer { + +namespace props { + +/** + * @brief indicated this layer is for smart reply application + * + */ +class SmartReply : public Property { +public: + /** + * @brief Construct a new SmartReply object + * + */ + SmartReply(bool value = false) { set(value); } + static constexpr const char *key = "smart_reply"; + using prop_tag = bool_prop_tag; +}; + +} // namespace props + +} // namespace nntrainer + +#endif /* __CUSTOM_PROPERTIES_H__ */ diff --git a/test/input_gen/gen_layer_tests.py b/test/input_gen/gen_layer_tests.py index 1300dcd8d7..85ab74f9b0 100644 --- a/test/input_gen/gen_layer_tests.py +++ b/test/input_gen/gen_layer_tests.py @@ -954,3 +954,8 @@ def call(self, inputs): rms_normtest_fp16 = RMSNorm() record_single(rms_normtest,(2,3,3,3),"rms_normtest") record_single_fp16(rms_normtest_fp16,(2,3,3,3),"rms_normtest_fp16_new") + + lm_head = K.layers.Dense(5) + record_single(lm_head, (3, 1, 1, 10), "lm_head_GPU1") + lm_head1616 = K.layers.Dense(5) + record_single_fp16(lm_head1616, (3, 1, 1, 10), "lm_head_GPU1_w16a16") diff --git a/test/jni/Android.mk b/test/jni/Android.mk index 153b4eb840..e8921e800d 100644 --- a/test/jni/Android.mk +++ b/test/jni/Android.mk @@ -442,6 +442,7 @@ LOCAL_SRC_FILES := \ ../unittest/layers/unittest_layer_node.cpp \ ../unittest/layers/unittest_layers.cpp \ ../unittest/layers/unittest_layers_impl.cpp \ + ../unittest/layers/unittest_layers_lm_head_cl.cpp \ ../unittest/layers/unittest_layers_concat_cl.cpp \ ../unittest/layers/unittest_layers_swiglu_cl.cpp \ ../unittest/layers/unittest_layers_fully_connected_cl.cpp \ diff --git a/test/unittest/layers/unittest_layers_lm_head_cl.cpp b/test/unittest/layers/unittest_layers_lm_head_cl.cpp new file mode 100644 index 0000000000..c0224a5d9b --- /dev/null +++ b/test/unittest/layers/unittest_layers_lm_head_cl.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file unittest_layers_lm_head_cl.cpp + * @date 1 Oct 2024 + * @brief LM Head Layer Test + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + */ + +#include +#include +#include +#include + +auto semantic_lm_head_gpu = LayerSemanticsParamType( + nntrainer::createLayer, + nntrainer::CustomLMHeadLayerCl::type, {"unit=1"}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1); + +GTEST_PARAMETER_TEST(LM_HeadGPU, LayerSemanticsGpu, + ::testing::Values(semantic_lm_head_gpu)); + +auto lm_head_gpu = LayerGoldenTestParamType( + nntrainer::createLayer, + {"unit=5", "use_vocab_selection=false"}, "3:10:1:1", + "lm_head_GPU1.nnlayergolden", + LayerGoldenTestParamOptions::SKIP_CALC_DERIV | + LayerGoldenTestParamOptions::SKIP_CALC_GRAD | + LayerGoldenTestParamOptions::USE_INC_FORWARD, + "nhwc", "fp32", "fp32"); + +GTEST_PARAMETER_TEST(LM_HeadGPU, LayerGoldenTest, + ::testing::Values(lm_head_gpu)); + +#ifdef ENABLE_FP16 +auto lm_head_gpu_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"unit=5", "use_vocab_selection=false"}, "3:1:1:10", + "lm_head_GPU1_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::SKIP_CALC_DERIV | + LayerGoldenTestParamOptions::SKIP_CALC_GRAD | + LayerGoldenTestParamOptions::USE_INC_FORWARD, + "nchw", "fp16", "fp16"); + +GTEST_PARAMETER_TEST(LM_HeadGPU16, LayerGoldenTest, + ::testing::Values(lm_head_gpu_w16a16)); +#endif