From 0ee82e033d36c79776af6514e4e52bb49b0a035b Mon Sep 17 00:00:00 2001
From: Donghak PARK
Date: Tue, 8 Oct 2024 16:44:28 +0900
Subject: [PATCH] [Mixed Precision] Fix gradient clipping logic

Update the mixed precision gradient clipping logic.
- When clipping gradients, the gradient should be unscaled before
  computing the L2 norm.

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK
---
 nntrainer/graph/network_graph.cpp | 11 ++++++++---
 nntrainer/graph/network_graph.h   |  2 +-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp
index 80741751fa..b7c2b52a92 100644
--- a/nntrainer/graph/network_graph.cpp
+++ b/nntrainer/graph/network_graph.cpp
@@ -474,11 +474,15 @@ bool NetworkGraph::backwarding(
     Tensor global_norm_t(
       TensorDim({1u, 1u, 1u, (unsigned int)lazy_weights.size()}));
     float *global_norm_data = global_norm_t.getData();
+
     for (unsigned int idx = 0; idx < lazy_weights.size(); idx++) {
       auto const &w = lazy_weights[idx];
-      if (w->getGradientRef().getDataType() != TensorDim::DataType::FP32) {
-        Tensor grad_32 = w->getGradientRef().clone(TensorDim::DataType::FP32);
-        global_norm_data[idx] = grad_32.l2norm();
+
+      if (isMixedPrecision()) {
+        Tensor scaled_grad =
+          w->getGradientRef().clone(TensorDim::DataType::FP32);
+        scaled_grad.divide_i(loss_scale);
+        global_norm_data[idx] = scaled_grad.l2norm();
       } else {
         global_norm_data[idx] = w->getGradientNorm();
       }
@@ -1567,6 +1571,7 @@ void NetworkGraph::requestOptimizerVariable(
 }
 
 void NetworkGraph::resetLossScale(float scale) {
+  loss_scale = scale;
   for (auto iter = cbegin(); iter != cend(); iter++) {
     auto &ln = *iter;
     ln->getRunContext().setLossScale(scale);
diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h
index 38f61e21af..05aeae9193 100644
--- a/nntrainer/graph/network_graph.h
+++ b/nntrainer/graph/network_graph.h
@@ -508,7 +508,7 @@ class NetworkGraph {
     lazy_weights; /**< weights with delayed grad update, e.g., gradient
                      clipping, loss scaling */
   bool is_clip_grad;
-
+  float loss_scale;
   unsigned int nan_count;
 
   /**
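
For context on why the unscale-before-norm ordering matters: with loss scaling, every stored gradient is the true gradient multiplied by `loss_scale`, so taking the L2 norm of the scaled gradients inflates the global norm by that same factor and distorts the clipping threshold. The sketch below is a minimal, self-contained illustration of that principle only; it uses plain `std::vector` buffers and a hypothetical `global_grad_norm` helper, not NNTrainer's `Tensor`/`Weight` API.

```cpp
#include <cmath>
#include <iostream>
#include <vector>

// Minimal sketch (hypothetical helper, not NNTrainer API): with loss scaling,
// each stored gradient value equals true_gradient * loss_scale, so values are
// unscaled before accumulating the L2 norm; otherwise the global norm used
// for clipping would be inflated by loss_scale.
float global_grad_norm(const std::vector<std::vector<float>> &grads,
                       float loss_scale) {
  double sum_sq = 0.0;
  for (const auto &g : grads) {
    for (float v : g) {
      double unscaled = static_cast<double>(v) / loss_scale; // unscale first
      sum_sq += unscaled * unscaled;
    }
  }
  return static_cast<float>(std::sqrt(sum_sq));
}

int main() {
  // Two toy gradient tensors already multiplied by a loss scale of 128.
  std::vector<std::vector<float>> grads = {{384.0f, 512.0f}, {128.0f}};
  float loss_scale = 128.0f;

  // Unscaled gradients are {3, 4} and {1}; global norm = sqrt(9+16+1) ~ 5.10.
  std::cout << global_grad_norm(grads, loss_scale) << "\n";
  return 0;
}
```

The patch applies the same idea per weight: the gradient is cloned to FP32, divided in place by `loss_scale`, and only then reduced with `l2norm()`, while the new `loss_scale` member cached in `resetLossScale()` makes the current scale available to `backwarding()`.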