[LoRA] revise type of LoraRank property & fix error in fc_layer
- update type of LoraRank property : Property<int> -> PositiveIntegerProperty
- fix typo dot_batched_deriv_wrt_1 -> dot_deriv_wrt_1
- update code with add -> add_i
- apply clang-format

Signed-off-by: Eunju Yang <[email protected]>
EunjuYang committed Feb 2, 2024
1 parent eba4257 commit 3b660f3
Showing 3 changed files with 40 additions and 38 deletions.
12 changes: 6 additions & 6 deletions nntrainer/layers/common_properties.h
@@ -1335,12 +1335,12 @@ class AverageAttentionWeight : public Property<bool> {
 /**
  * @brief LoRA rank property, it is used to set rank of LoRA weight.
  * @details
-*/
-class LoraRank : public Property<int>{
-public:
-  static constexpr const char *key =
-    "lora_rank"; /**< unique key to access */
-  using prop_tag = int_prop_tag;
+ */
+class LoraRank : public PositiveIntegerProperty {
+public:
+  static constexpr const char *key = "lora_rank"; /**< unique key to access */
+  using prop_tag = uint_prop_tag; /**< property type */
+  ;
 };
 
 /**
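Note on the LoraRank hunk above: the commit message says the property type moves from Property<int> to PositiveIntegerProperty with uint_prop_tag, i.e. the rank is now constrained to a strictly positive integer at property-parsing time rather than reaching finalize() as a raw int. The standalone C++ sketch below only illustrates that guarantee; it is not the nntrainer property framework, and the class name, method names, and error message are made up for illustration.

#include <stdexcept>
#include <string>

// Toy stand-in for a positive-integer property: parse a string value and
// refuse anything that is not a strictly positive integer.
class PositiveIntSketch {
public:
  void set(const std::string &value) {
    long long parsed = std::stoll(value); // throws std::invalid_argument on junk
    if (parsed <= 0)
      throw std::invalid_argument("lora_rank must be a positive integer");
    rank_ = static_cast<unsigned int>(parsed);
  }
  unsigned int get() const { return rank_; }

private:
  unsigned int rank_ = 0; // 0 only while unset
};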
52 changes: 23 additions & 29 deletions nntrainer/layers/fc_layer.cpp
@@ -41,7 +41,7 @@ enum LORAParams { loraA, loraB, loraW };
 
 FullyConnectedLayer::FullyConnectedLayer() :
   LayerImpl(),
-  fc_props(props::Unit(), props::LoraRank()){
+  fc_props(props::Unit(), props::LoraRank()) {
   weight_idx.fill(std::numeric_limits<unsigned>::max());
 }
 
@@ -58,7 +58,9 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
   auto &disable_bias = std::get<props::DisableBias>(*layer_impl_props);
 
   auto unit = std::get<props::Unit>(fc_props).get();
-  auto lora_rank = (std::get<props::LoraRank>(fc_props).empty())? 0 : std::get<props::LoraRank>(fc_props).get();
+  auto lora_rank = (std::get<props::LoraRank>(fc_props).empty())
+                     ? 0
+                     : std::get<props::LoraRank>(fc_props).get();
 
   NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
     << "Fully connected layer takes only one input";
@@ -106,11 +108,12 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
   }
 
   // create weights for LoRA
-  if(lora_rank){
+  if (lora_rank) {
 
     /** set LoRA specifications */
-    // loraA is a row vector (in_dim.width, lora_rank), loraB is a col vector of weight (lora_rank, out_dim),
-    // where shape of (loraA @ loraB) = shape of (W)
+    // loraA is a row vector (in_dim.width, lora_rank), loraB is a col vector of
+    // weight (lora_rank, out_dim), where shape of (loraA @ loraB) = shape of
+    // (W)
     TensorDim loraA_dim(
       1, is_nchw ? 1 : lora_rank, is_nchw ? in_dim.width() : 1,
       is_nchw ? lora_rank : in_dim.channel(),
@@ -133,10 +136,9 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
 
     // save weight_lora as a tensor to be updated by (loraA @ loraB)
     lora_idx[LORAParams::loraW] = context.requestTensor(
-      weight_dim, "weight_lora", Tensor::Initializer::NONE, true,
+      weight_dim, "weight_lora", Tensor::Initializer::NONE, true,
       TensorLifespan::ITERATION_LIFESPAN);
   }
-
 }
 
 void FullyConnectedLayer::exportTo(
@@ -150,16 +152,13 @@ void FullyConnectedLayer::setProperty(const std::vector<std::string> &values) {
   LayerImpl::setProperty(remain_props);
 }
 
-/**
- * @brief forwarding function for lora.
- * It would be called during `forwarding` of FC layer.
- */
-void FullyConnectedLayer::forwarding_lora(RunLayerContext &context, Tensor &weight){
+void FullyConnectedLayer::forwarding_lora(RunLayerContext &context,
+                                          Tensor &weight) {
   Tensor &loraA = context.getWeight(lora_idx[LORAParams::loraA]);
   Tensor &loraB = context.getWeight(lora_idx[LORAParams::loraB]);
   Tensor &weight_lora = context.getTensor(lora_idx[LORAParams::loraW]);
-  loraA.dot(loraB, weight_lora); // weight_lora = loraA @ loraB
-  weight.add(weight_lora, weight); // weight += weight_lora
+  loraA.dot(loraB, weight_lora); // weight_lora = loraA @ loraB
+  weight.add_i(weight_lora);
 }
 
 void FullyConnectedLayer::forwarding(RunLayerContext &context, bool training) {
@@ -180,11 +179,11 @@ void FullyConnectedLayer::forwarding(RunLayerContext &context, bool training) {
       context.getWeightObject(weight_idx[FCParams::weight]).getOutputAxis();
 
     weight.dequantize(weight_, axis);
-    if(!std::get<props::LoraRank>(fc_props).empty())
+    if (!std::get<props::LoraRank>(fc_props).empty())
       forwarding_lora(context, weight_);
     input_.dot(weight_, hidden_, false, false);
   } else {
-    if(!std::get<props::LoraRank>(fc_props).empty())
+    if (!std::get<props::LoraRank>(fc_props).empty())
       forwarding_lora(context, weight);
     input_.dot(weight, hidden_, false, false);
   }
@@ -237,11 +236,6 @@ void FullyConnectedLayer::incremental_forwarding(RunLayerContext &context,
   }
 }
 
-/**
- * @note
- * [note for LoRA] implicit calcDerivative is implicitly applied.
- * The weight is already updated with the LoRA's (W = W + W_lora)
- */
 void FullyConnectedLayer::calcDerivative(RunLayerContext &context) {
   Tensor &weight = context.getWeight(weight_idx[FCParams::weight]);
 
@@ -254,7 +248,7 @@ void FullyConnectedLayer::calcGradient(RunLayerContext &context) {
 void FullyConnectedLayer::calcGradient(RunLayerContext &context) {
 
   // (baseline) calcGradient
-  if(std::get<props::LoraRank>(fc_props).empty()){
+  if (std::get<props::LoraRank>(fc_props).empty()) {
     Tensor &djdw = context.getWeightGrad(weight_idx[FCParams::weight]);
 
     const Tensor &derivative_ = context.getIncomingDerivative(SINGLE_INOUT_IDX);
@@ -278,26 +272,26 @@ void FullyConnectedLayer::calcGradient(RunLayerContext &context) {
       !context.isGradientFirstAccess(weight_idx[FCParams::weight]));
   } else {
     // (LoRA) calcGradient
-    Tensor &djdla = context.getWeightGrad(lora_idx[LORAParams::loraA]);
-    Tensor &djdlb = context.getWeightGrad(lora_idx[LORAParams::loraB]);
-    Tensor &djdlora_w = context.getTensorGrad(lora_idx[LORAParams::loraW]);
+    Tensor &djdla = context.getWeightGrad(lora_idx[LORAParams::loraA]);
+    Tensor &djdlb = context.getWeightGrad(lora_idx[LORAParams::loraB]);
+    Tensor &djdlora_w = context.getTensorGrad(lora_idx[LORAParams::loraW]);
 
     const Tensor &derivative_ = context.getIncomingDerivative(SINGLE_INOUT_IDX);
     Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
     Tensor &lora_A = context.getWeight(lora_idx[LORAParams::loraA]);
     Tensor &lora_B = context.getWeight(lora_idx[LORAParams::loraB]);
 
     // (cf) forward
-    // input_.dot(lora_weight, hidden) : hidden = input @ lora_weight
-    // lora_A.dot(lora_B, lora_weight) : lora_weight = lora_A @ lora_B
+    // input_.dot(lora_weight, hidden) : hidden = input @ lora_weight
+    // lora_A.dot(lora_B, lora_weight) : lora_weight = lora_A @ lora_B
     input_.dot_deriv_wrt_2(
       djdlora_w, derivative_, false, false,
       !context.isGradientFirstAccess(lora_idx[LORAParams::loraW]));
     lora_A.dot_deriv_wrt_2(
       djdlb, djdlora_w, false, false,
       !context.isGradientFirstAccess(lora_idx[LORAParams::loraB]));
-    djdla.dot_batched_deriv_wrt_1(
-      lora_B, djdlora_w,false, false,
+    djdla.dot_deriv_wrt_1(
+      lora_B, djdlora_w, false, false,
       !context.isGradientFirstAccess(lora_idx[LORAParams::loraA]));
   }
 }
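Note on the fc_layer.cpp hunks above: forwarding_lora() computes weight_lora = loraA @ loraB and now folds it into the weight in place (add_i replaces the aliased weight.add(weight_lora, weight) form, with the same effect as the old "weight += weight_lora" comment), and the LoRA branch of calcGradient() backpropagates through that product, now using dot_deriv_wrt_1 for loraA. The sketch below restates the same math with plain matrices so the shapes and gradient formulas are easy to check. It is a shape-level illustration only, not nntrainer code; the mapping comments to the dot_deriv_wrt_* calls follow the standard matmul-gradient pattern of this layer rather than the Tensor API documentation.

#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<float>>;

// out = x @ y (row-major, no error checking)
static Mat matmul(const Mat &x, const Mat &y) {
  Mat out(x.size(), std::vector<float>(y[0].size(), 0.0f));
  for (std::size_t i = 0; i < x.size(); ++i)
    for (std::size_t k = 0; k < y.size(); ++k)
      for (std::size_t j = 0; j < y[0].size(); ++j)
        out[i][j] += x[i][k] * y[k][j];
  return out;
}

static Mat transpose(const Mat &x) {
  Mat out(x[0].size(), std::vector<float>(x.size(), 0.0f));
  for (std::size_t i = 0; i < x.size(); ++i)
    for (std::size_t j = 0; j < x[0].size(); ++j)
      out[j][i] = x[i][j];
  return out;
}

// X: (batch, in), W: (in, out), A: (in, rank), B: (rank, out), dH: (batch, out)
void lora_fc_sketch(const Mat &X, Mat &W, const Mat &A, const Mat &B,
                    const Mat &dH, Mat &dA, Mat &dB) {
  // forwarding_lora: weight_lora = loraA @ loraB; weight.add_i(weight_lora)
  Mat W_lora = matmul(A, B); // (in, rank) @ (rank, out) -> (in, out), same shape as W
  for (std::size_t i = 0; i < W.size(); ++i)
    for (std::size_t j = 0; j < W[0].size(); ++j)
      W[i][j] += W_lora[i][j];

  // forwarding: hidden = input @ weight (output unused in this sketch)
  Mat H = matmul(X, W);
  (void)H;

  // calcGradient, LoRA branch, read as standard matmul gradients:
  Mat dW_lora = matmul(transpose(X), dH); // input_.dot_deriv_wrt_2(djdlora_w, derivative_)
  dB = matmul(transpose(A), dW_lora);     // lora_A.dot_deriv_wrt_2(djdlb, djdlora_w)
  dA = matmul(dW_lora, transpose(B));     // djdla.dot_deriv_wrt_1(lora_B, djdlora_w)
}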
14 changes: 11 additions & 3 deletions nntrainer/layers/fc_layer.h
@@ -72,6 +72,9 @@ class FullyConnectedLayer : public LayerImpl {
 
   /**
    * @copydoc Layer::calcGradient(RunLayerContext &context)
+   * @note
+   * [note for LoRA] implicit calcDerivative is implicitly applied.
+   * The weight is already updated with the LoRA's (W = W + W_lora)
    */
   void calcGradient(RunLayerContext &context) override;
 
@@ -99,16 +102,21 @@
    * &value)
    */
   void setProperty(const std::vector<std::string> &values) override;
-
 
+  /**
+   * @brief forwarding function for lora.
+   * It would be called during `forwarding` of FC layer.
+   */
+  void forwarding_lora(RunLayerContext &context, Tensor &weight);
 
   inline static const std::string type = "fully_connected";
 
 private:
   std::tuple<props::Unit, props::LoraRank>
-    fc_props; /**< fc layer properties : unit - number of output neurons, lora_rank : rank of lora (optional) */
+    fc_props; /**< fc layer properties : unit - number of output neurons,
+                 lora_rank : rank of lora (optional) */
   std::array<unsigned int, 2> weight_idx; /**< indices of the weights */
-  std::array<unsigned int, 2> lora_idx; /**< indices of the lora weights */
+  std::array<unsigned int, 2> lora_idx;   /**< indices of the lora weights */
 };
 } // namespace nntrainer
 
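Note on the fc_layer.h hunks above: setProperty(), declared with a std::vector<std::string> parameter, is the entry point that consumes the "lora_rank" key, and the @note moved onto calcGradient() records that no LoRA-specific calcDerivative branch is needed because forwarding already folded W_lora into the weight (W = W + W_lora). The snippet below is a hypothetical usage sketch only: the include path, helper name, and the unit/rank values are assumptions, and a layer instance would normally come from nntrainer's layer factory rather than be handled directly like this.

#include <string>
#include <vector>

#include <fc_layer.h> // assumed include path

// Hypothetical helper: configure an existing FullyConnectedLayer so the LoRA
// path is taken. "unit" and "lora_rank" are the keys of props::Unit and
// props::LoraRank; leaving lora_rank unset keeps the plain FC path.
void enable_lora(nntrainer::FullyConnectedLayer &fc) {
  fc.setProperty({"unit=64", "lora_rank=4"});
}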
