-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[GPU/OpenCL/Update] Initial version of LM Head layer with OpencCl ops…
… and Update Addition Layer on GPU with latest Pipeline changes Added initial version of LM head layer fpr GPU and removed dependencies of cl_context for addition_layer. Signed-off-by: Yash Singh <[email protected]>
- Loading branch information
1 parent
2643cbf
commit 282d797
Showing
13 changed files
with
770 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
* @see https://github.com/nnstreamer/nntrainer | ||
* @author Debadri Samaddar <[email protected]> | ||
* @author Niket Agarwal <[email protected]> | ||
* @author Yash Singh <[email protected]> | ||
* @author Thummala Pallavi <[email protected]> | ||
* @bug No known bugs except for NYI items | ||
* @brief This file contains app context related functions and classes that | ||
|
@@ -20,6 +21,7 @@ | |
#include <cl_context.h> | ||
#include <concat_cl.h> | ||
#include <fc_layer_cl.h> | ||
#include <lm_head_layer_cl.h> | ||
#include <reshape_cl.h> | ||
#include <rmsnorm_layer_cl.h> | ||
#include <swiglu_cl.h> | ||
|
@@ -66,6 +68,10 @@ static void add_default_object(ClContext &cc) { | |
cc.registerFactory(nntrainer::createLayer<TransposeLayerCl>, | ||
TransposeLayerCl::type, | ||
ml::train::LayerType::LAYER_TRANSPOSE); | ||
|
||
cc.registerFactory(nntrainer::createLayer<CustomLMHeadLayerCl>, | ||
CustomLMHeadLayerCl::type, | ||
ml::train::LayerType::LAYER_LM_HEAD); | ||
} | ||
|
||
static void registerer(ClContext &cc) noexcept { | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
// SPDX-License-Identifier: Apache-2.0 | ||
/** | ||
* Copyright (C) 2024 Yash Singh <[email protected]>> | ||
* | ||
* @file custom_vocab_selection.cpp | ||
* @date 1 Oct 2024 | ||
* @see https://github.com/nnstreamer/nntrainer | ||
* @author Yash Singh <[email protected]> | ||
* @bug No known bugs except for NYI items | ||
* @brief Implementation of custom vocab selection | ||
*/ | ||
|
||
#include <algorithm> | ||
#include <custom_vocab_selection.h> | ||
|
||
nntrainer::VocabSelection::VocabSelection(LshType lshType, int lshChoices, | ||
int hiddenSize, int vocabCnt) : | ||
lshType(lshType), | ||
lshChoices(lshChoices), | ||
vocabCnt(vocabCnt), | ||
hiddenSize(hiddenSize), | ||
lshBlockNum(0), | ||
lshBits(0) {} | ||
|
||
nntrainer::VocabSelection::~VocabSelection() {} | ||
|
||
nntrainer::VocabSelectionNNTrainer::VocabSelectionNNTrainer( | ||
LshType lshType, int lshChoices, int hiddenSize, int vocabCnt, | ||
nntrainer::Tensor &weights) : | ||
VocabSelection(lshType, lshChoices, hiddenSize, vocabCnt) { | ||
this->lshBlockNum = (hiddenSize + lshBlockSize - 1) / lshBlockSize; | ||
this->lshBits = lshBlockNum * lshBlockSize; | ||
this->lshData = std::vector<lshDataBlock>(this->vocabCnt * lshBlockNum); | ||
|
||
// for (unsigned int i = 0; i < vocabCnt; ++i) { | ||
// for (unsigned int j = 0; j < lshBlockNum; ++j) { | ||
// unsigned int actualSize = std::min(lshBlockSize, hiddenSize - | ||
// (int)j * lshBlockSize); lshDataBlock d; for (unsigned int k = 0; k | ||
// < actualSize; ++k) { | ||
// d[k] = weights.getValue<_FP16>(0, 0, i, j * lshBlockSize + k) > | ||
// 0 ? 1 : 0; | ||
// } | ||
// for (unsigned int k = actualSize; k < lshBlockSize; ++k) { | ||
// d[k] = 0; | ||
// } | ||
// this->lshData[i * lshBlockNum + j] = d; | ||
// } | ||
// } | ||
|
||
for (unsigned int i = 0; i < lshBlockNum; ++i) { | ||
unsigned int actualSize = | ||
std::min(lshBlockSize, hiddenSize - (int)i * lshBlockSize); | ||
for (unsigned int j = 0; j < vocabCnt; ++j) { | ||
lshDataBlock d; | ||
for (unsigned int k = 0; k < actualSize; ++k) { | ||
if (weights.getDataType() == nntrainer::TensorDim::DataType::FP32) { | ||
d[k] = weights.getValue(0, 0, i * lshBlockSize + k, j) > 0 ? 1 : 0; | ||
} else if (weights.getDataType() == | ||
nntrainer::TensorDim::DataType::FP16) { | ||
d[k] = | ||
weights.getValue<_FP16>(0, 0, i * lshBlockSize + k, j) > 0 ? 1 : 0; | ||
} | ||
} | ||
for (unsigned int k = actualSize; k < lshBlockSize; ++k) { | ||
d[k] = 0; | ||
} | ||
this->lshData[j * lshBlockNum + i] = d; | ||
} | ||
} | ||
} | ||
|
||
std::vector<std::vector<int>> | ||
nntrainer::VocabSelectionNNTrainer::getVocabs(const nntrainer::Tensor &input) { | ||
unsigned int batchSize = input.height(); | ||
|
||
std::vector<std::vector<int>> res = std::vector<std::vector<int>>(batchSize); | ||
for (int i = 0; i < batchSize; i++) { | ||
std::vector<lshDataBlock> d(lshBlockNum); | ||
for (int k = 0; k < lshBlockNum; k++) { | ||
int actualSize = std::min(lshBlockSize, hiddenSize - k * lshBlockSize); | ||
for (int j = 0; j < actualSize; j++) { | ||
if (input.getDataType() == nntrainer::TensorDim::DataType::FP32) { | ||
d[k][j] = input.getValue(0, 0, i, j + k * lshBlockSize) >= 0 ? 1 : 0; | ||
} else if (input.getDataType() == | ||
nntrainer::TensorDim::DataType::FP16) { | ||
d[k][j] = | ||
input.getValue<_FP16>(0, 0, i, j + k * lshBlockSize) >= 0 ? 1 : 0; | ||
} | ||
} | ||
for (int j = actualSize; j < lshBlockSize; j++) { | ||
d[k][j] = 0; | ||
} | ||
} | ||
std::vector<int> simResult(vocabCnt, 0); | ||
std::vector<int> simCount(lshBits + 1, 0); | ||
for (int j = 0; j < vocabCnt; j++) { | ||
for (int k = 0; k < lshBlockNum; k++) { | ||
simResult[j] += (d[k] ^ lshData[j * lshBlockNum + k]).count(); | ||
} | ||
simCount[simResult[j]]++; | ||
} | ||
int cut = lshBits + 1; | ||
int leftover = 0; | ||
int countSum = 0; | ||
for (int j = 0; j <= lshBits; j++) { | ||
countSum += simCount[j]; | ||
if (countSum > lshChoices) { | ||
cut = j; | ||
leftover = simCount[j] - (countSum - lshChoices); | ||
break; | ||
} | ||
} | ||
std::vector<int> selectedVocabs(lshChoices); | ||
int pos = 0; | ||
for (int j = 0; j < vocabCnt; j++) { | ||
if (simResult[j] <= cut) { | ||
if (simResult[j] < cut) { | ||
selectedVocabs[pos] = j; | ||
pos++; | ||
} else if (leftover > 0) { | ||
selectedVocabs[pos] = j; | ||
pos++; | ||
leftover--; | ||
} | ||
} | ||
} | ||
res[i] = selectedVocabs; | ||
} | ||
return res; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
// SPDX-License-Identifier: Apache-2.0 | ||
/** | ||
* Copyright (C) 2024 Yash Singh <[email protected]>> | ||
* | ||
* @file custom_vocab_selection.h | ||
* @date 1 Oct 2024 | ||
* @see https://github.com/nnstreamer/nntrainer | ||
* @author Yash Singh <[email protected]> | ||
* @bug No known bugs except for NYI items | ||
* @brief Implementation of custom vocab selection | ||
*/ | ||
|
||
#ifndef VOCAB_SELECTION_H | ||
#define VOCAB_SELECTION_H | ||
|
||
#include <tensor.h> | ||
|
||
#ifndef LSH_BLOCK_SIZE | ||
#define LSH_BLOCK_SIZE 256 | ||
#endif | ||
|
||
using namespace std; | ||
|
||
namespace nntrainer { | ||
|
||
/** | ||
* @brief Enumeration for different types of LSH algorithms used in vocab | ||
* selection | ||
* | ||
*/ | ||
enum LshType { NONE = 0, SIMHASH = 1, ORTHOSIMHASH = 2 }; | ||
typedef std::bitset<LSH_BLOCK_SIZE> lshDataBlock; | ||
|
||
/** | ||
* @brief Vocab Selection class to select the vocabs from model output using LSH | ||
* | ||
*/ | ||
class VocabSelection { | ||
protected: | ||
int hiddenSize; | ||
int vocabCnt; | ||
const int lshBlockSize = LSH_BLOCK_SIZE; | ||
int lshBlockNum; | ||
int lshBits; // lshBlockSize * lshBlockNum | ||
int lshChoices; | ||
LshType lshType; | ||
std::vector<lshDataBlock> lshData; | ||
|
||
public: | ||
/** | ||
* @brief Constructor of VocabSelection class | ||
* | ||
*/ | ||
VocabSelection(LshType lshType, int lshChoices, int hiddenSize, int vocabCnt); | ||
virtual std:: | ||
vector<std::vector<int>> | ||
|
||
/** | ||
* @brief Get the Vocabs object | ||
*/ | ||
getVocabs(const nntrainer::Tensor &modelOutput) = 0; | ||
|
||
/** | ||
* @brief Destructor of VocabSelection class | ||
*/ | ||
~VocabSelection(); | ||
}; | ||
|
||
/** | ||
* @brief Vocab Selection NNTrainer class to select the vocabs from model output | ||
* using LSH | ||
* | ||
*/ | ||
class VocabSelectionNNTrainer : public VocabSelection { | ||
protected: | ||
nntrainer::Tensor lshWeight; | ||
|
||
public: | ||
/** | ||
* @brief Constructor of VocabSelectionNNTrainer class | ||
*/ | ||
VocabSelectionNNTrainer(LshType lshType, int lshChoices, int hiddenSize, | ||
int vocabCnt, nntrainer::Tensor &weights); | ||
virtual std:: | ||
vector<std::vector<int>> | ||
|
||
/** | ||
* @brief Get the Vocabs object | ||
* | ||
*/ | ||
getVocabs(const nntrainer::Tensor &modelOutput); | ||
|
||
/** | ||
* @brief Destructor of VocabSelectionNNTrainer class | ||
*/ | ||
~VocabSelectionNNTrainer() {}; | ||
}; | ||
|
||
} // namespace nntrainer | ||
|
||
#endif |
Oops, something went wrong.