Skip to content

Commit

Permalink
[Mixed Precision] Fix gradient clipping logic
Browse files Browse the repository at this point in the history
Update the mixed-precision gradient clipping logic.
- When clipping gradients, each gradient must be unscaled (divided by the loss scale) before its L2 norm is computed.

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK <[email protected]>
  • Loading branch information
DonghakPark committed Nov 26, 2024
1 parent e194ea6 commit fba7c09
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
11 changes: 8 additions & 3 deletions nntrainer/graph/network_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,11 +474,15 @@ bool NetworkGraph::backwarding(
Tensor global_norm_t(
TensorDim({1u, 1u, 1u, (unsigned int)lazy_weights.size()}));
float *global_norm_data = global_norm_t.getData();

for (unsigned int idx = 0; idx < lazy_weights.size(); idx++) {
auto const &w = lazy_weights[idx];
if (w->getGradientRef().getDataType() != TensorDim::DataType::FP32) {
Tensor grad_32 = w->getGradientRef().clone(TensorDim::DataType::FP32);
global_norm_data[idx] = grad_32.l2norm();

if (isMixedPrecision()) {
Tensor scaled_grad =
w->getGradientRef().clone(TensorDim::DataType::FP32);
scaled_grad.divide_i(loss_scale);
global_norm_data[idx] = scaled_grad.l2norm();
} else {
global_norm_data[idx] = w->getGradientNorm();
}
Expand Down Expand Up @@ -1567,6 +1571,7 @@ void NetworkGraph::requestOptimizerVariable(
}

void NetworkGraph::resetLossScale(float scale) {
loss_scale = scale;
for (auto iter = cbegin(); iter != cend(); iter++) {
auto &ln = *iter;
ln->getRunContext().setLossScale(scale);
Expand Down
2 changes: 1 addition & 1 deletion nntrainer/graph/network_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ class NetworkGraph {
lazy_weights; /**< weights with delayed grad update, e.g., gradient
clipping, loss scaling */
bool is_clip_grad;

float loss_scale;
unsigned int nan_count;

/**
Expand Down

0 comments on commit fba7c09

Please sign in to comment.