[Wait for 2759][GPU/OpenCL] Initial version of LM Head layer with OpenCL ops and updated Addition Layer with latest pipeline changes. @open sesame 12/02 16:21 #2752

Open · wants to merge 4 commits into base: main

12 changes: 11 additions & 1 deletion api/ccapi/include/layer.h
@@ -111,7 +111,8 @@ enum LayerType {
LAYER_UPSAMPLE2D, /**< Upsample 2D Layer type */
LAYER_RMSNORM = ML_TRAIN_LAYER_TYPE_RMSNORM, /**<RMS NORM Layer */
LAYER_TRANSPOSE = ML_TRAIN_LAYER_TYPE_TRANSPOSE, /**< Transpose Layer type */
LAYER_UNKNOWN = ML_TRAIN_LAYER_TYPE_UNKNOWN /**< Unknown */
LAYER_LM_HEAD = ML_TRAIN_LAYER_TYPE_LM_HEAD, /**< LM Head Layer */
Contributor comment: This is not a big deal, but I think setting the LAYER_UNKNOWN at the end would be better.

LAYER_UNKNOWN = ML_TRAIN_LAYER_TYPE_UNKNOWN, /**< Unknown */
};

/**
@@ -442,6 +443,15 @@ Addition(const std::vector<std::string> &properties = {},
return createLayer(LayerType::LAYER_ADDITION, properties, compute_engine);
}

/**
* @brief Helper function to create lm_head layer
*/
inline std::unique_ptr<Layer>
LmHead(const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
return createLayer(LayerType::LAYER_LM_HEAD, properties, compute_engine);
}
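
Note (illustration, not part of the diff): with this helper in place, building the layer from the ccapi would look roughly like the sketch below, assuming the same ml::train::layer namespace the other helpers live in. The "unit" property is an assumption borrowed from the FC-layer convention; use whatever properties CustomLMHeadLayerCl actually registers.

// Hypothetical usage sketch of the new LmHead helper; the "unit" property
// name is an assumption, not confirmed by this PR.
auto model = ml::train::createModel(ml::train::ModelType::NEURAL_NET);
model->addLayer(ml::train::layer::LmHead(
  {"unit=32000"}, ml::train::LayerComputeEngine::GPU));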

/**
* @brief Helper function to create concat layer
*/
1 change: 1 addition & 0 deletions api/nntrainer-api-common.h
@@ -72,6 +72,7 @@ typedef enum {
ML_TRAIN_LAYER_TYPE_TRANSPOSE = 36, /**< Transpose Layer type */
ML_TRAIN_LAYER_TYPE_CONV2D_TRANSPOSE =
37, /**< Convolution 2D Transpose Layer (Since 9.0) */
ML_TRAIN_LAYER_TYPE_LM_HEAD = 38, /**< LM Head Layer type */
ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP =
300, /**< Preprocess flip Layer (Since 6.5) */
ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE =
6 changes: 6 additions & 0 deletions nntrainer/cl_context.cpp
@@ -7,6 +7,7 @@
* @see https://github.com/nnstreamer/nntrainer
* @author Debadri Samaddar <[email protected]>
* @author Niket Agarwal <[email protected]>
* @author Yash Singh <[email protected]>
* @author Thummala Pallavi <[email protected]>
* @bug No known bugs except for NYI items
* @brief This file contains app context related functions and classes that
@@ -20,6 +21,7 @@
#include <cl_context.h>
#include <concat_cl.h>
#include <fc_layer_cl.h>
#include <lm_head_layer_cl.h>
#include <reshape_cl.h>
#include <rmsnorm_layer_cl.h>
#include <swiglu_cl.h>
@@ -66,6 +68,10 @@ static void add_default_object(ClContext &cc) {
cc.registerFactory(nntrainer::createLayer<TransposeLayerCl>,
TransposeLayerCl::type,
ml::train::LayerType::LAYER_TRANSPOSE);

cc.registerFactory(nntrainer::createLayer<CustomLMHeadLayerCl>,
CustomLMHeadLayerCl::type,
ml::train::LayerType::LAYER_LM_HEAD);
}

static void registerer(ClContext &cc) noexcept {
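
For readers new to cl_context: registerFactory keys each creator by its type string and LayerType enum so ClContext can instantiate layers by name. A toy sketch of the pattern follows; it is not nntrainer's actual implementation, which also handles integer keys and duplicate-registration checks.

// Toy sketch of the factory-registration pattern used above; details
// (int keys, error handling) differ in the real ClContext.
#include <functional>
#include <map>
#include <memory>
#include <string>

struct LayerBase { virtual ~LayerBase() = default; };
using Creator = std::function<std::unique_ptr<LayerBase>()>;

std::map<std::string, Creator> factory_map;

template <typename T>
void registerFactory(const std::string &type) {
  factory_map[type] = [] { return std::make_unique<T>(); };
}

std::unique_ptr<LayerBase> createByType(const std::string &type) {
  return factory_map.at(type)(); // throws if the type was never registered
}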
116 changes: 116 additions & 0 deletions nntrainer/layers/cl_layers/custom_vocab_selection.cpp
@@ -0,0 +1,116 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 Hyeonseok Lee <[email protected]>
* Copyright (C) 2024 Yash Singh <[email protected]>
*
* @file custom_vocab_selection.cpp
* @date 1 Oct 2024
* @see https://github.com/nnstreamer/nntrainer
* @author Yash Singh <[email protected]>
* @bug No known bugs except for NYI items
* @brief Implementation of custom vocab selection
*/

#include <algorithm>
#include <custom_vocab_selection.h>

nntrainer::VocabSelection::VocabSelection(LshType lshType, int lshChoices,
int hiddenSize, int vocabCnt) :
hiddenSize(hiddenSize),
vocabCnt(vocabCnt),
lshBlockNum(0),
lshBits(0),
lshChoices(lshChoices),
lshType(lshType) {}

nntrainer::VocabSelection::~VocabSelection() {}

nntrainer::VocabSelectionNNTrainer::VocabSelectionNNTrainer(
LshType lshType, int lshChoices, int hiddenSize, int vocabCnt,
nntrainer::Tensor &weights) :
VocabSelection(lshType, lshChoices, hiddenSize, vocabCnt) {
this->lshBlockNum = (hiddenSize + lshBlockSize - 1) / lshBlockSize;
this->lshBits = lshBlockNum * lshBlockSize;
this->lshData = std::vector<lshDataBlock>(this->vocabCnt * lshBlockNum);

// Pack the sign bit of each weight entry into fixed-width bitset blocks:
// one binary signature per vocab entry, zero-padded in the last block
// when hiddenSize is not a multiple of lshBlockSize.
for (unsigned int i = 0; i < lshBlockNum; ++i) {
unsigned int actualSize =
std::min(lshBlockSize, hiddenSize - (int)i * lshBlockSize);
for (unsigned int j = 0; j < vocabCnt; ++j) {
lshDataBlock d;
for (unsigned int k = 0; k < actualSize; ++k) {
if (weights.getDataType() == nntrainer::TensorDim::DataType::FP32) {
d[k] = weights.getValue(0, 0, i * lshBlockSize + k, j) > 0 ? 1 : 0;
} else if (weights.getDataType() ==
nntrainer::TensorDim::DataType::FP16) {
d[k] =
weights.getValue<_FP16>(0, 0, i * lshBlockSize + k, j) > 0 ? 1 : 0;
}
}
for (unsigned int k = actualSize; k < lshBlockSize; ++k) {
d[k] = 0;
}
this->lshData[j * lshBlockNum + i] = d;
}
}
}
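
What this constructor builds, in isolation: every vocab column of the LM-head weight is reduced to a binary signature of its sign bits, so candidate selection becomes XOR plus popcount instead of a full matmul. (One thing reviewers may want to double-check: the constructor thresholds weights with > 0 while getVocabs below thresholds the input with >= 0, so exact zeros hash differently on the two sides.) A standalone sketch of the encoding, under those same conventions:

// Standalone sketch of the sign-bit (SimHash-style) encoding above: a float
// vector becomes a bitset signature, and XOR + popcount gives the Hamming
// distance used as a similarity proxy in getVocabs().
#include <bitset>
#include <cstddef>
#include <vector>

constexpr std::size_t kBlockSize = 256; // mirrors LSH_BLOCK_SIZE

std::bitset<kBlockSize> signBits(const std::vector<float> &v) {
  std::bitset<kBlockSize> b; // unused trailing bits stay 0, like the padding loop
  for (std::size_t i = 0; i < v.size() && i < kBlockSize; ++i)
    b[i] = v[i] > 0;
  return b;
}

int hammingDistance(const std::bitset<kBlockSize> &a,
                    const std::bitset<kBlockSize> &b) {
  return static_cast<int>((a ^ b).count());
}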

std::vector<std::vector<int>>
nntrainer::VocabSelectionNNTrainer::getVocabs(const nntrainer::Tensor &input) {
unsigned int batchSize = input.height();

std::vector<std::vector<int>> res = std::vector<std::vector<int>>(batchSize);
for (int i = 0; i < batchSize; i++) {
std::vector<lshDataBlock> d(lshBlockNum);
for (int k = 0; k < lshBlockNum; k++) {
int actualSize = std::min(lshBlockSize, hiddenSize - k * lshBlockSize);
for (int j = 0; j < actualSize; j++) {
if (input.getDataType() == nntrainer::TensorDim::DataType::FP32) {
d[k][j] = input.getValue(0, 0, i, j + k * lshBlockSize) >= 0 ? 1 : 0;
} else if (input.getDataType() ==
nntrainer::TensorDim::DataType::FP16) {
d[k][j] =
input.getValue<_FP16>(0, 0, i, j + k * lshBlockSize) >= 0 ? 1 : 0;
}
}
for (int j = actualSize; j < lshBlockSize; j++) {
d[k][j] = 0;
}
}
// Hamming distance between the input's signature and every vocab
// signature; simCount histograms the distances for the cutoff scan below.
std::vector<int> simResult(vocabCnt, 0);
std::vector<int> simCount(lshBits + 1, 0);
for (int j = 0; j < vocabCnt; j++) {
for (int k = 0; k < lshBlockNum; k++) {
simResult[j] += (d[k] ^ lshData[j * lshBlockNum + k]).count();
}
simCount[simResult[j]]++;
}
// Smallest cutoff whose cumulative count covers lshChoices: keep every
// vocab strictly below `cut`, plus `leftover` ties exactly at `cut`.
int cut = lshBits + 1;
int leftover = 0;
int countSum = 0;
for (int j = 0; j <= lshBits; j++) {
countSum += simCount[j];
if (countSum > lshChoices) {
cut = j;
leftover = simCount[j] - (countSum - lshChoices);
break;
}
}
std::vector<int> selectedVocabs(lshChoices);
int pos = 0;
for (int j = 0; j < vocabCnt; j++) {
if (simResult[j] <= cut) {
if (simResult[j] < cut) {
selectedVocabs[pos] = j;
pos++;
} else if (leftover > 0) {
selectedVocabs[pos] = j;
pos++;
leftover--;
}
}
}
res[i] = selectedVocabs;
}
return res;
}
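
The selection above is a counting-based top-k: instead of sorting all vocabCnt distances, it histograms them, scans for the smallest cutoff whose cumulative count exceeds lshChoices, then keeps everything strictly below the cutoff plus the first `leftover` ties at it. A self-contained toy re-run of that logic:

// Toy re-run of the cutoff scan above: lshChoices = 3 and Hamming
// distances {2, 0, 1, 1, 3} should select vocab indices {1, 2, 3}.
#include <cassert>
#include <vector>

int main() {
  const int lshChoices = 3;
  std::vector<int> simResult = {2, 0, 1, 1, 3};
  const int maxDist = 3; // stands in for lshBits in the real code
  std::vector<int> simCount(maxDist + 1, 0);
  for (int d : simResult)
    simCount[d]++;
  int cut = maxDist + 1, leftover = 0, countSum = 0;
  for (int j = 0; j <= maxDist; j++) {
    countSum += simCount[j];
    if (countSum > lshChoices) {
      cut = j;                                   // here: cut = 2
      leftover = simCount[j] - (countSum - lshChoices); // here: 0
      break;
    }
  }
  std::vector<int> selected;
  for (int j = 0; j < (int)simResult.size(); j++) {
    if (simResult[j] < cut) {
      selected.push_back(j);
    } else if (simResult[j] == cut && leftover > 0) {
      selected.push_back(j);
      leftover--;
    }
  }
  assert((selected == std::vector<int>{1, 2, 3}));
}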
102 changes: 102 additions & 0 deletions nntrainer/layers/cl_layers/custom_vocab_selection.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 Hyeonseok Lee <[email protected]>
* Copyright (C) 2024 Yash Singh <[email protected]>
*
* @file custom_vocab_selection.h
* @date 1 Oct 2024
* @see https://github.com/nnstreamer/nntrainer
* @author Yash Singh <[email protected]>
* @bug No known bugs except for NYI items
* @brief Implementation of custom vocab selection
*/

#ifndef VOCAB_SELECTION_H
#define VOCAB_SELECTION_H

#include <bitset>
#include <tensor.h>
#include <vector>

#ifndef LSH_BLOCK_SIZE
#define LSH_BLOCK_SIZE 256
#endif

namespace nntrainer {

/**
* @brief Enumeration for different types of LSH algorithms used in vocab
* selection
*
*/
enum LshType { NONE = 0, SIMHASH = 1, ORTHOSIMHASH = 2 };
typedef std::bitset<LSH_BLOCK_SIZE> lshDataBlock;

/**
* @brief Vocab Selection class to select the vocabs from model output using LSH
*
*/
class VocabSelection {
protected:
int hiddenSize;
int vocabCnt;
const int lshBlockSize = LSH_BLOCK_SIZE;
int lshBlockNum;
int lshBits; // lshBlockSize * lshBlockNum
int lshChoices;
LshType lshType;
std::vector<lshDataBlock> lshData;

public:
/**
* @brief Constructor of VocabSelection class
*
*/
VocabSelection(LshType lshType, int lshChoices, int hiddenSize, int vocabCnt);

/**
* @brief Get the Vocabs object
*/
virtual std::vector<std::vector<int>>
getVocabs(const nntrainer::Tensor &modelOutput) = 0;

/**
* @brief Destructor of VocabSelection class
*/
virtual ~VocabSelection();
};

/**
* @brief Vocab Selection NNTrainer class to select the vocabs from model output
* using LSH
*
*/
class VocabSelectionNNTrainer : public VocabSelection {
protected:
nntrainer::Tensor lshWeight;

public:
/**
* @brief Constructor of VocabSelectionNNTrainer class
*/
VocabSelectionNNTrainer(LshType lshType, int lshChoices, int hiddenSize,
int vocabCnt, nntrainer::Tensor &weights);
/**
* @brief Get the Vocabs object
*/
virtual std::vector<std::vector<int>>
getVocabs(const nntrainer::Tensor &modelOutput) override;

/**
* @brief Destructor of VocabSelectionNNTrainer class
*/
~VocabSelectionNNTrainer() {}
};

} // namespace nntrainer

#endif
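
A minimal usage sketch of these classes. The tensor layouts are assumptions inferred from the getValue() indexing in custom_vocab_selection.cpp (weights as [1, 1, hiddenSize, vocabCnt], inputs as [1, 1, tokens, hiddenSize]), and the sizes are placeholders:

// Minimal usage sketch; layouts and sizes are assumptions, not confirmed
// by this PR.
#include <custom_vocab_selection.h>

std::vector<std::vector<int>>
selectCandidates(nntrainer::Tensor &lm_head_weights,
                 const nntrainer::Tensor &hidden_states) {
  const int hiddenSize = 768, vocabCnt = 32000, lshChoices = 64;
  nntrainer::VocabSelectionNNTrainer vs(nntrainer::LshType::SIMHASH,
                                        lshChoices, hiddenSize, vocabCnt,
                                        lm_head_weights);
  // One list of lshChoices candidate vocab ids per input row (token).
  return vs.getVocabs(hidden_states);
}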