diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index d9f9cffdd2..8fcf82b5fe 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -8,6 +8,7 @@ * @author Parichay Kapoor * @author Debadri Samaddar * @author Niket Agarwal + * @author Yash Singh * @bug No known bugs except for NYI items * @brief This is layers interface for c++ API * @@ -103,7 +104,8 @@ enum LayerType { derivative */ LAYER_UPSAMPLE2D, /**< Upsample 2D Layer type */ LAYER_RMSNORM = ML_TRAIN_LAYER_TYPE_RMSNORM, /** &properties = {}, return createLayer(LayerType::LAYER_ADDITION, properties, compute_engine); } +/** + * @brief Helper function to create lm_head layer + */ +inline std::unique_ptr +LmHead(const std::vector &properties = {}, + const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { + return createLayer(LayerType::LAYER_LM_HEAD, properties, compute_engine); +} + /** * @brief Helper function to create concat layer */ diff --git a/api/nntrainer-api-common.h b/api/nntrainer-api-common.h index 97a5a71fad..bd1d0b5f21 100644 --- a/api/nntrainer-api-common.h +++ b/api/nntrainer-api-common.h @@ -65,6 +65,7 @@ typedef enum { ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */ ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */ ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_LM_HEAD = 32, /**< LM Head Layer type */ ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP = 300, /**< Preprocess flip Layer (Since 6.5) */ ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE = diff --git a/nntrainer/cl_context.cpp b/nntrainer/cl_context.cpp index 821a32d6fa..f7d9a7a3f1 100644 --- a/nntrainer/cl_context.cpp +++ b/nntrainer/cl_context.cpp @@ -7,6 +7,7 @@ * @see https://github.com/nnstreamer/nntrainer * @author Debadri Samaddar * @author Niket Agarwal + * @author Yash Singh * @author Thummala Pallavi * @bug No known bugs except for NYI items * @brief This file contains app context related functions and classes that @@ -19,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -35,9 +37,9 @@ static void add_default_object(ClContext &cc) { // FullyConnectedLayerCl::type, // ml::train::LayerType::LAYER_FC); - // cc.registerFactory(nntrainer::createLayer, - // AdditionLayerCL::type, - // ml::train::LayerType::LAYER_ADDITION); + cc.registerFactory(nntrainer::createLayer, + AdditionLayerCL::type, + ml::train::LayerType::LAYER_ADDITION); // cc.registerFactory(nntrainer::createLayer, // SwiGLULayerCl::type, @@ -54,6 +56,10 @@ static void add_default_object(ClContext &cc) { // cc.registerFactory(nntrainer::createLayer, // ConcatLayerCl::type, // ml::train::LayerType::LAYER_CONCAT); + + cc.registerFactory(nntrainer::createLayer, + CustomLMHeadLayerCl::type, + ml::train::LayerType::LAYER_LM_HEAD); } static void registerer(ClContext &cc) noexcept { diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.cpp b/nntrainer/layers/cl_layers/addition_layer_cl.cpp index dda2101645..3fd57e20ee 100644 --- a/nntrainer/layers/cl_layers/addition_layer_cl.cpp +++ b/nntrainer/layers/cl_layers/addition_layer_cl.cpp @@ -37,7 +37,7 @@ void AdditionLayerCL::forwarding(RunLayerContext &context, bool training) { if (!idx) { hidden_.copy(input_); } else { - add_i_cl(input_, hidden_, context); + add_i_cl(input_, hidden_); } } } @@ -77,7 +77,7 @@ void AdditionLayerCL::incremental_forwarding(RunLayerContext &context, if (!idx) { hidden_step.copy(input_step); } else { - add_i_cl(input_step, hidden_step, context); + add_i_cl(input_step, hidden_step); } } } diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.h b/nntrainer/layers/cl_layers/addition_layer_cl.h index e24354e8d0..f851e35138 100644 --- a/nntrainer/layers/cl_layers/addition_layer_cl.h +++ b/nntrainer/layers/cl_layers/addition_layer_cl.h @@ -15,6 +15,7 @@ #define __ADDITION_LAYER_CL_H__ #ifdef __cplusplus +#include #include #include @@ -40,7 +41,7 @@ class AdditionLayerCL : public Layer { /** * @brief Destructor of Addition Layer */ - ~AdditionLayerCL(){}; + ~AdditionLayerCL() {}; /** * @brief Move constructor of AdditionLayer. diff --git a/nntrainer/layers/cl_layers/custom_vocab_selection.cpp b/nntrainer/layers/cl_layers/custom_vocab_selection.cpp new file mode 100644 index 0000000000..ec3016a927 --- /dev/null +++ b/nntrainer/layers/cl_layers/custom_vocab_selection.cpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh > + * + * @file custom_vocab_selection.cpp + * @date 1 Oct 2024 + * @brief Implementation of custom vocab selection + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#include "custom_vocab_selection.h" +#include + +nntrainer::VocabSelection::VocabSelection(LshType lshType, int lshChoices, + int hiddenSize, int vocabCnt) : + lshType(lshType), + lshChoices(lshChoices), + vocabCnt(vocabCnt), + hiddenSize(hiddenSize), + lshBlockNum(0), + lshBits(0) {} + +nntrainer::VocabSelection::~VocabSelection() {} + +nntrainer::VocabSelectionNNTrainer::VocabSelectionNNTrainer( + LshType lshType, int lshChoices, int hiddenSize, int vocabCnt, + nntrainer::Tensor &weights) : + VocabSelection(lshType, lshChoices, hiddenSize, vocabCnt) { + this->lshBlockNum = (hiddenSize + lshBlockSize - 1) / lshBlockSize; + this->lshBits = lshBlockNum * lshBlockSize; + this->lshData = std::vector(this->vocabCnt * lshBlockNum); + + // for (unsigned int i = 0; i < vocabCnt; ++i) { + // for (unsigned int j = 0; j < lshBlockNum; ++j) { + // unsigned int actualSize = std::min(lshBlockSize, hiddenSize - + // (int)j * lshBlockSize); lshDataBlock d; for (unsigned int k = 0; k + // < actualSize; ++k) { + // d[k] = weights.getValue<_FP16>(0, 0, i, j * lshBlockSize + k) > + // 0 ? 1 : 0; + // } + // for (unsigned int k = actualSize; k < lshBlockSize; ++k) { + // d[k] = 0; + // } + // this->lshData[i * lshBlockNum + j] = d; + // } + // } + + for (unsigned int i = 0; i < lshBlockNum; ++i) { + unsigned int actualSize = + std::min(lshBlockSize, hiddenSize - (int)i * lshBlockSize); + for (unsigned int j = 0; j < vocabCnt; ++j) { + lshDataBlock d; + for (unsigned int k = 0; k < actualSize; ++k) { + if (weights.getDataType() == nntrainer::TensorDim::DataType::FP32) { + d[k] = weights.getValue(0, 0, i * lshBlockSize + k, j) > 0 ? 1 : 0; + } else if (weights.getDataType() == + nntrainer::TensorDim::DataType::FP16) { + d[k] = + weights.getValue<_FP16>(0, 0, i * lshBlockSize + k, j) > 0 ? 1 : 0; + } + } + for (unsigned int k = actualSize; k < lshBlockSize; ++k) { + d[k] = 0; + } + this->lshData[j * lshBlockNum + i] = d; + } + } +} + +std::vector> +nntrainer::VocabSelectionNNTrainer::getVocabs(const nntrainer::Tensor &input) { + unsigned int batchSize = input.height(); + + std::vector> res = std::vector>(batchSize); + for (int i = 0; i < batchSize; i++) { + std::vector d(lshBlockNum); + for (int k = 0; k < lshBlockNum; k++) { + int actualSize = std::min(lshBlockSize, hiddenSize - k * lshBlockSize); + for (int j = 0; j < actualSize; j++) { + if (input.getDataType() == nntrainer::TensorDim::DataType::FP32) { + d[k][j] = input.getValue(0, 0, i, j + k * lshBlockSize) >= 0 ? 1 : 0; + } else if (input.getDataType() == + nntrainer::TensorDim::DataType::FP16) { + d[k][j] = + input.getValue<_FP16>(0, 0, i, j + k * lshBlockSize) >= 0 ? 1 : 0; + } + } + for (int j = actualSize; j < lshBlockSize; j++) { + d[k][j] = 0; + } + } + std::vector simResult(vocabCnt, 0); + std::vector simCount(lshBits + 1, 0); + for (int j = 0; j < vocabCnt; j++) { + for (int k = 0; k < lshBlockNum; k++) { + simResult[j] += (d[k] ^ lshData[j * lshBlockNum + k]).count(); + } + simCount[simResult[j]]++; + } + int cut = lshBits + 1; + int leftover = 0; + int countSum = 0; + for (int j = 0; j <= lshBits; j++) { + countSum += simCount[j]; + if (countSum > lshChoices) { + cut = j; + leftover = simCount[j] - (countSum - lshChoices); + break; + } + } + std::vector selectedVocabs(lshChoices); + int pos = 0; + for (int j = 0; j < vocabCnt; j++) { + if (simResult[j] <= cut) { + if (simResult[j] < cut) { + selectedVocabs[pos] = j; + pos++; + } else if (leftover > 0) { + selectedVocabs[pos] = j; + pos++; + leftover--; + } + } + } + res[i] = selectedVocabs; + } + return res; +} diff --git a/nntrainer/layers/cl_layers/custom_vocab_selection.h b/nntrainer/layers/cl_layers/custom_vocab_selection.h new file mode 100644 index 0000000000..933e4e2ad0 --- /dev/null +++ b/nntrainer/layers/cl_layers/custom_vocab_selection.h @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh > + * + * @file custom_vocab_selection.h + * @date 1 Oct 2024 + * @brief Implementation of custom vocab selection + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#ifndef VOCAB_SELECTION_H +#define VOCAB_SELECTION_H + +#include + +#ifndef LSH_BLOCK_SIZE +#define LSH_BLOCK_SIZE 256 +#endif + +using namespace std; + +namespace nntrainer { + +enum LshType { NONE = 0, SIMHASH = 1, ORTHOSIMHASH = 2 }; +typedef std::bitset lshDataBlock; + +class VocabSelection { +protected: + int hiddenSize; + int vocabCnt; + const int lshBlockSize = LSH_BLOCK_SIZE; + int lshBlockNum; + int lshBits; // lshBlockSize * lshBlockNum + int lshChoices; + LshType lshType; + std::vector lshData; + +public: + VocabSelection(LshType lshType, int lshChoices, int hiddenSize, int vocabCnt); + virtual std::vector> + getVocabs(const nntrainer::Tensor &modelOutput) = 0; + ~VocabSelection(); +}; + +class VocabSelectionNNTrainer : public VocabSelection { +protected: + nntrainer::Tensor lshWeight; + +public: + VocabSelectionNNTrainer(LshType lshType, int lshChoices, int hiddenSize, + int vocabCnt, nntrainer::Tensor &weights); + virtual std::vector> + getVocabs(const nntrainer::Tensor &modelOutput); + ~VocabSelectionNNTrainer() {}; +}; + +} // namespace nntrainer + +#endif diff --git a/nntrainer/layers/cl_layers/lm_head_layer_cl.cpp b/nntrainer/layers/cl_layers/lm_head_layer_cl.cpp new file mode 100644 index 0000000000..d02609f15d --- /dev/null +++ b/nntrainer/layers/cl_layers/lm_head_layer_cl.cpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh > + * + * @file lm_head_layer_cl.cpp + * @date 1 Oct 2024 + * @brief Implementation of custom lm head layer + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#include +#include +#include +#include + +#include + +namespace nntrainer { + +static constexpr size_t SINGLE_INOUT_IDX = 0; + +enum LMHeadParams { weight, bias, candidate_weight, candidate_hidden_step }; + +CustomLMHeadLayerCl::CustomLMHeadLayerCl() : + LayerImpl(), + custom_lm_head_props(nntrainer::props::Unit(), props::UseVocabSelection(), + props::LshChoices(), props::SmartReply()) { + weight_idx.fill(std::numeric_limits::max()); +} + +void CustomLMHeadLayerCl::finalize(nntrainer::InitLayerContext &context) { + auto &weight_regularizer = + std::get(*layer_impl_props); + auto &weight_regularizer_constant = + std::get(*layer_impl_props); + auto weight_initializer = nntrainer::props::InitializerInfo::Enum::ZEROS; + // auto &weight_initializer = + // std::get(*layer_impl_props); + auto &weight_decay = + std::get(*layer_impl_props); + auto &bias_decay = std::get(*layer_impl_props); + auto &bias_initializer = + std::get(*layer_impl_props); + auto &disable_bias = + std::get(*layer_impl_props); + + auto unit = std::get(custom_lm_head_props).get(); + + NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument) + << "lm head layer takes only one input"; + + std::vector output_dims(1); + + /// @todo fc actaully supports multidimensions. EffDimFlag shouldn't be fixed + /// like this. + context.setEffDimFlagInputDimension(0, 0b1001); + context.setDynDimFlagInputDimension(0, 0b1000); + + bool is_nchw = (context.getFormat() == nntrainer::Tformat::NCHW); + /** set output dimensions */ + auto const &in_dim = context.getInputDimensions()[0]; + output_dims[0] = in_dim; + is_nchw ? output_dims[0].width(unit) : output_dims[0].channel(unit); + + output_dims[0].setTensorType( + {context.getFormat(), context.getActivationDataType()}); + + context.setOutputDimensions(output_dims); + + /** set weight specifications */ + // @todo : This NCHW format setting is just temporal, it needs to be set by + // global configuration + ml::train::TensorDim bias_dim( + 1, is_nchw ? 1 : unit, 1, is_nchw ? unit : 1, + ml::train::TensorDim::TensorType(context.getFormat(), + context.getWeightDataType()), + is_nchw ? 0b0001 : 0b0100); + + ml::train::TensorDim weight_dim( + 1, is_nchw ? 1 : unit, is_nchw ? in_dim.width() : 1, + is_nchw ? unit : in_dim.channel(), + ml::train::TensorDim::TensorType(context.getFormat(), + context.getWeightDataType()), + is_nchw ? 0b0011 : 0b0101); + + weight_idx[LMHeadParams::weight] = context.requestWeight( + weight_dim, weight_initializer, weight_regularizer, + weight_regularizer_constant, weight_decay, "weight", true); + + if (disable_bias.empty() || disable_bias.get() == false) { + weight_idx[LMHeadParams::bias] = context.requestWeight( + bias_dim, bias_initializer, nntrainer::WeightRegularizer::NONE, 1.0f, + bias_decay, "bias", true); + } + + auto use_vocab_selection = + std::get(custom_lm_head_props).get(); + + if (use_vocab_selection) { + auto lsh_choices = std::get(custom_lm_head_props).get(); + + ml::train::TensorDim candidate_weight_dim( + 1, is_nchw ? 1 : lsh_choices, is_nchw ? lsh_choices : in_dim.channel(), + is_nchw ? in_dim.width() : 1, + ml::train::TensorDim::TensorType(context.getFormat(), + context.getWeightDataType())); + + weight_idx[LMHeadParams::candidate_weight] = context.requestTensor( + candidate_weight_dim, "candidate_weight", Initializer::NONE, false, + nntrainer::TensorLifespan::ITERATION_LIFESPAN); + + ml::train::TensorDim candidate_hidden_step_dim( + 1, 1, 1, lsh_choices, + ml::train::TensorDim::TensorType(context.getFormat(), + context.getWeightDataType())); + + weight_idx[LMHeadParams::candidate_hidden_step] = context.requestTensor( + candidate_hidden_step_dim, "candidate_hidden_step", Initializer::NONE, + false, nntrainer::TensorLifespan::ITERATION_LIFESPAN); + } +} + +void CustomLMHeadLayerCl::forwarding(nntrainer::RunLayerContext &context, + bool training) { + // NYI +} + +void CustomLMHeadLayerCl::initVocabSelection( + LshType lshType, int lshChoices, nntrainer::RunLayerContext &context) { + nntrainer::Tensor w; + nntrainer::Tensor &weight = w; + context.getWeight(weight, weight_idx[LMHeadParams::weight]); + this->vocabSelection = + std::shared_ptr(new VocabSelectionNNTrainer( + lshType, lshChoices, weight.height(), weight.width(), weight)); + weight_T = std::make_unique(weight.transpose("0:2:1")); + + weight_T->reshape({weight_T->width(), weight_T->height()}); +} + +void CustomLMHeadLayerCl::incremental_forwarding( + nntrainer::RunLayerContext &context, unsigned int from, unsigned int to, + bool training) { + nntrainer::Tensor w; + nntrainer::Tensor &weight = w; + context.getWeight(weight, weight_idx[LMHeadParams::weight]); + + nntrainer::Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); + nntrainer::Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + + ml::train::TensorDim input_dim = input_.getDim(); + ml::train::TensorDim hidden_dim = hidden_.getDim(); + + ml::train::TensorDim input_step_dim = input_dim; + ml::train::TensorDim hidden_step_dim = hidden_dim; + + unsigned int _from = from; + + if (from) { + NNTR_THROW_IF(to - from != 1, std::invalid_argument) + << "incremental step size is not 1"; + from = 0; + to = 1; + } + + input_step_dim.batch(1); + input_step_dim.height(1); + hidden_step_dim.batch(1); + hidden_step_dim.height(1); + + bool smart_reply = std::get(custom_lm_head_props).get(); + + unsigned int b_size = input_dim.batch(); + unsigned omp_num = 4; + if (smart_reply && !_from) { + b_size = 1; + omp_num = 1; + } + + // #pragma omp parallel for num_threads(omp_num) + for (unsigned int b = 0; b < b_size; ++b) { + nntrainer::Tensor input_step = input_.getSharedDataTensor( + input_step_dim, + b * input_dim.getFeatureLen() + + (to - from == 1 ? 0 : (to - 1) * input_.width()), + true); + nntrainer::Tensor hidden_step = hidden_.getSharedDataTensor( + hidden_step_dim, + b * hidden_dim.getFeatureLen() + + (to - from == 1 ? 0 : (to - 1) * hidden_.width()), + true); + + auto unit = std::get(custom_lm_head_props).get(); + auto use_vocab_selection = + std::get(custom_lm_head_props).get(); + + if (use_vocab_selection) { + auto lsh_choices = + std::get(custom_lm_head_props).get(); + auto vocab = vocabSelection->getVocabs(input_step); + + hidden_step.setValue(0); + + ml::train::TensorDim weight_T_ith_choice_dim = weight_T->getDim(); + weight_T_ith_choice_dim.width(1); + ml::train::TensorDim hidden_step_ith_choice_dim = hidden_step_dim; + hidden_step_ith_choice_dim.width(1); + nntrainer::Tensor weight_T_ith_choice; + + for (unsigned int i = 0; i < lsh_choices; ++i) { + weight_T_ith_choice = weight_T->getSharedDataTensor( + weight_T_ith_choice_dim, vocab[0][i] * input_step.width(), true); + nntrainer::Tensor hidden_step_ith_choice = + hidden_step.getSharedDataTensor(hidden_step_ith_choice_dim, + vocab[0][i], true); + + dotCl(input_step, weight_T_ith_choice, hidden_step_ith_choice); + } + } else { + dotCl(input_step, weight, hidden_step); + } + + if (auto &disable_bias = + std::get(*layer_impl_props); + disable_bias.empty() || disable_bias.get() == false) { + nntrainer::Tensor &bias = + context.getWeight(weight_idx[LMHeadParams::bias]); + + add_i_cl(bias, hidden_step); + } + } +} + +void CustomLMHeadLayerCl::calcDerivative(nntrainer::RunLayerContext &context) {} + +void CustomLMHeadLayerCl::setProperty(const std::vector &values) { + auto remain_props = loadProperties(values, custom_lm_head_props); + LayerImpl::setProperty(remain_props); +} + +#ifdef PLUGGABLE + +nntrainer::Layer *create_custom_lm_head_layer() { + auto layer = new CustomLMHeadLayerCl(); + return layer; +} + +void destroy_custom_lm_head_layer(nntrainer::Layer *layer) { delete layer; } + +extern "C" { +nntrainer::LayerPluggable ml_train_layer_pluggable{ + create_custom_lm_head_layer, destroy_custom_lm_head_layer}; +} + +#endif + +} // namespace nntrainer diff --git a/nntrainer/layers/cl_layers/lm_head_layer_cl.h b/nntrainer/layers/cl_layers/lm_head_layer_cl.h new file mode 100644 index 0000000000..9144f9fc85 --- /dev/null +++ b/nntrainer/layers/cl_layers/lm_head_layer_cl.h @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file lm_head_layer_cl.h + * @date 1 Oct 2024 + * @brief Implementation of custom lm head layer for GPU + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#ifndef __CUSTOM_LM_HEAD_LAYER_H__ +#define __CUSTOM_LM_HEAD_LAYER_H__ + +#include +#include +#include +#include +#include +#include + +namespace nntrainer { + +namespace props { + +/** + * @brief indicated whether do vocab selection or not + * + */ +class UseVocabSelection : public nntrainer::Property { +public: + /** + * @brief Construct a new UseVocabSelection object + * + */ + UseVocabSelection(bool value = false) { set(value); } + static constexpr const char *key = "use_vocab_selection"; + using prop_tag = nntrainer::bool_prop_tag; +}; + +/** + * @brief LshChoices property, indicate how many words will be choose + * + */ +class LshChoices : public nntrainer::PositiveIntegerProperty { +public: + /** + * @brief Construct a new LshChoices object with a default value 1 + * + */ + LshChoices(unsigned int value = 1) { set(value); }; + static constexpr const char *key = "lsh_choices"; /**< unique key to access */ + using prop_tag = nntrainer::uint_prop_tag; /**< property type */ +}; + +} // namespace props + +/** + * @brief A Custom LM Head layer for llama. + * + */ +class CustomLMHeadLayerCl : public LayerImpl { +public: + /** + * @brief Construct a new Custom LM Head layer object + * + */ + CustomLMHeadLayerCl(); + + /** + * @brief Destroy the Custom LM Head layer object + * + */ + ~CustomLMHeadLayerCl() {} + + /** + * @copydoc Layer::finalize(InitLayerContext &context) + */ + void finalize(nntrainer::InitLayerContext &context) override; + + void initialize(nntrainer::RunLayerContext &context) override { + auto use_vocab_selection = + std::get(custom_lm_head_props).get(); + + if (use_vocab_selection) { + auto lsh_choices = + std::get(custom_lm_head_props).get(); + initVocabSelection(LshType::ORTHOSIMHASH, lsh_choices, context); + } + } + + /** + * @copydoc Layer::forwarding(RunLayerContext &context, bool training) + */ + void forwarding(nntrainer::RunLayerContext &context, bool training) override; + + /** + * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned + * int from, unsigned int to, bool training) + */ + void incremental_forwarding(nntrainer::RunLayerContext &context, + unsigned int from, unsigned int to, + bool training) override; + + /** + * @copydoc Layer::calcDerivative(RunLayerContext &context) + */ + void calcDerivative(nntrainer::RunLayerContext &context) override; + + /** + * @copydoc Layer::calcGradient(RunLayerContext &context) + */ + // void calcGradient(nntrainer::RunLayerContext &context) override; + + /** + * @copydoc bool supportBackwarding() const + */ + bool supportBackwarding() const override { return true; }; + + /** + * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method) + */ + void exportTo(nntrainer::Exporter &exporter, + const ml::train::ExportMethods &method) const override {}; + + /** + * @copydoc Layer::getType() + */ + const std::string getType() const override { + return CustomLMHeadLayerCl::type; + }; + + /** + * @copydoc Layer::setProperty(const std::vector &values) + */ + void setProperty(const std::vector &values) override; + + void initVocabSelection(LshType lshType, int lshChoices, + nntrainer::RunLayerContext &context); + + inline static const std::string type = "custom_lm_head"; + + std::shared_ptr vocabSelection; + +private: + std::tuple + custom_lm_head_props; + std::array weight_idx; /**< indices of the weights */ + std::unique_ptr + weight_T; // temporary weight. will be removed +}; + +} // namespace nntrainer + +#endif /* __LM_HEAD_LAYER_CL_H__ */ diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build index 906b8a85ac..3f7d9f291d 100644 --- a/nntrainer/layers/cl_layers/meson.build +++ b/nntrainer/layers/cl_layers/meson.build @@ -1,10 +1,12 @@ cl_layer_sources = [ # 'fc_layer_cl.cpp', - # 'addition_layer_cl.cpp', + 'addition_layer_cl.cpp', # 'swiglu_cl.cpp', # 'reshape_cl.cpp', # 'rmsnorm_layer_cl.cpp', # 'concat_cl.cpp', + 'lm_head_layer_cl.cpp', + 'custom_vocab_selection.cpp', ] foreach s : cl_layer_sources diff --git a/nntrainer/utils/custom_properties.h b/nntrainer/utils/custom_properties.h new file mode 100644 index 0000000000..f0e68d3987 --- /dev/null +++ b/nntrainer/utils/custom_properties.h @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Debadri Samaddar + * + * @file custom_properties.h + * @date 1 October 2024 + * @brief This file contains list of custom properties widely + * used across custom layers + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + * + */ + +#ifndef __CUSTOM_PROPERTIES_H__ +#define __CUSTOM_PROPERTIES_H__ + +#include + +namespace nntrainer { + +namespace props { + +/** + * @brief indicated this layer is for smart reply application + * + */ +class SmartReply : public Property { +public: + /** + * @brief Construct a new SmartReply object + * + */ + SmartReply(bool value = false) { set(value); } + static constexpr const char *key = "smart_reply"; + using prop_tag = bool_prop_tag; +}; + +} // namespace props + +} // namespace nntrainer + +#endif /* __CUSTOM_PROPERTIES_H__ */ diff --git a/test/input_gen/gen_layer_tests.py b/test/input_gen/gen_layer_tests.py index 1300dcd8d7..85ab74f9b0 100644 --- a/test/input_gen/gen_layer_tests.py +++ b/test/input_gen/gen_layer_tests.py @@ -954,3 +954,8 @@ def call(self, inputs): rms_normtest_fp16 = RMSNorm() record_single(rms_normtest,(2,3,3,3),"rms_normtest") record_single_fp16(rms_normtest_fp16,(2,3,3,3),"rms_normtest_fp16_new") + + lm_head = K.layers.Dense(5) + record_single(lm_head, (3, 1, 1, 10), "lm_head_GPU1") + lm_head1616 = K.layers.Dense(5) + record_single_fp16(lm_head1616, (3, 1, 1, 10), "lm_head_GPU1_w16a16") diff --git a/test/jni/Android.mk b/test/jni/Android.mk index 153b4eb840..e8921e800d 100644 --- a/test/jni/Android.mk +++ b/test/jni/Android.mk @@ -442,6 +442,7 @@ LOCAL_SRC_FILES := \ ../unittest/layers/unittest_layer_node.cpp \ ../unittest/layers/unittest_layers.cpp \ ../unittest/layers/unittest_layers_impl.cpp \ + ../unittest/layers/unittest_layers_lm_head_cl.cpp \ ../unittest/layers/unittest_layers_concat_cl.cpp \ ../unittest/layers/unittest_layers_swiglu_cl.cpp \ ../unittest/layers/unittest_layers_fully_connected_cl.cpp \ diff --git a/test/unittest/layers/unittest_layers_lm_head_cl.cpp b/test/unittest/layers/unittest_layers_lm_head_cl.cpp new file mode 100644 index 0000000000..c0224a5d9b --- /dev/null +++ b/test/unittest/layers/unittest_layers_lm_head_cl.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Yash Singh + * + * @file unittest_layers_lm_head_cl.cpp + * @date 1 Oct 2024 + * @brief LM Head Layer Test + * @see https://github.com/nnstreamer/nntrainer + * @author Yash Singh + * @bug No known bugs except for NYI items + */ + +#include +#include +#include +#include + +auto semantic_lm_head_gpu = LayerSemanticsParamType( + nntrainer::createLayer, + nntrainer::CustomLMHeadLayerCl::type, {"unit=1"}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1); + +GTEST_PARAMETER_TEST(LM_HeadGPU, LayerSemanticsGpu, + ::testing::Values(semantic_lm_head_gpu)); + +auto lm_head_gpu = LayerGoldenTestParamType( + nntrainer::createLayer, + {"unit=5", "use_vocab_selection=false"}, "3:10:1:1", + "lm_head_GPU1.nnlayergolden", + LayerGoldenTestParamOptions::SKIP_CALC_DERIV | + LayerGoldenTestParamOptions::SKIP_CALC_GRAD | + LayerGoldenTestParamOptions::USE_INC_FORWARD, + "nhwc", "fp32", "fp32"); + +GTEST_PARAMETER_TEST(LM_HeadGPU, LayerGoldenTest, + ::testing::Values(lm_head_gpu)); + +#ifdef ENABLE_FP16 +auto lm_head_gpu_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"unit=5", "use_vocab_selection=false"}, "3:1:1:10", + "lm_head_GPU1_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::SKIP_CALC_DERIV | + LayerGoldenTestParamOptions::SKIP_CALC_GRAD | + LayerGoldenTestParamOptions::USE_INC_FORWARD, + "nchw", "fp16", "fp16"); + +GTEST_PARAMETER_TEST(LM_HeadGPU16, LayerGoldenTest, + ::testing::Values(lm_head_gpu_w16a16)); +#endif