Support mixed precision training @open sesame 03/08 07:57 #2455
@@ -6,6 +6,7 @@
  * @date 19 Oct 2020
  * @see https://github.com/nnstreamer/nntrainer
  * @author Jijoong Moon <[email protected]>
+ * @author Jiho Chu <[email protected]>
  * @bug No known bugs except for NYI items
  * @brief This is Network Graph Class for Neural Network
  *
@@ -85,6 +86,15 @@ int NetworkGraph::compile(const std::string &loss_type) {
   status = checkCompiledGraph();
   NN_RETURN_STATUS();

+  /**
+   * @note This can be integrated into the addLossLayer method once the loss
+   * layer is no longer added to the model directly.
+   */
+  for (auto iter = cbegin(); iter != cend(); iter++) {
+    auto &ln = *iter;
+    ln->setLossScale(loss_scale);
+  }
+
   compiled = true;

   return status;
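For context, the loss scale that compile() now pushes into every LayerNode is the standard mixed precision trick: the loss is multiplied by a scale factor before backpropagation so that small FP16 gradients do not underflow, and the gradients are divided by the same factor before the optimizer step. A minimal standalone sketch of that idea (plain C++ with illustrative names, not the nntrainer API):

```cpp
#include <cstdio>
#include <vector>

// Illustrative only: scaling the loss by a constant scales every gradient by
// the same constant, so tiny FP16 values no longer flush to zero during
// backprop; the optimizer divides by the same constant before the update.
int main() {
  const float loss_scale = 128.0f;

  // Pretend these are tiny upstream gradients that would underflow in FP16
  // (FP16 subnormals bottom out around 6e-8).
  std::vector<float> upstream = {6e-8f, 1e-7f, 3e-8f};

  // Backward side: gradients arrive multiplied by the loss scale.
  std::vector<float> scaled;
  for (float g : upstream)
    scaled.push_back(g * loss_scale);

  // Optimizer side: divide by the same scale before applying the update.
  for (float g : scaled)
    std::printf("restored gradient: %e\n", g / loss_scale);
  return 0;
}
```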
@@ -353,10 +363,15 @@ sharedConstTensors NetworkGraph::forwarding(
   bool training,
   std::function<void(std::shared_ptr<LayerNode>, bool)> forwarding_op,
   std::function<bool(void *userdata)> stop_cb, void *userdata) {

+  for (auto w : clip_weights) {
+    w->applyMaster();
+  }
+
   for (auto iter = cbegin(); iter != cend() && !stop_cb(userdata); iter++) {
     auto &ln = *iter;
     PROFILE_TIME_START(profile_keys.at(ln->getType()));
-    forwarding_op(*iter, training);
+    forwarding_op(ln, training);
     PROFILE_TIME_END(profile_keys.at(ln->getType()));
   }

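The applyMaster() call added before the forward loop presumably copies each weight's full-precision master copy into the lower-precision tensor the layers actually compute with. A simplified standalone illustration of that master-weight pattern (the MixedWeight struct below is hypothetical, not nntrainer's Weight class):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical mixed-precision weight: the optimizer updates the full
// precision master copy, and the compute copy (FP16 in practice, float here
// for brevity) is refreshed from the master before each forward pass.
struct MixedWeight {
  std::vector<float> master;  // full-precision master parameters
  std::vector<float> compute; // low-precision copy used by the layers

  void applyMaster() {
    compute = master; // in practice this would also downcast to FP16
  }
};

int main() {
  MixedWeight w{{0.1f, 0.2f}, {0.0f, 0.0f}};
  w.master[0] += 0.01f; // an optimizer step touches only the master copy
  w.applyMaster();      // sync the compute copy before forwarding
  std::printf("compute[0] = %f\n", w.compute[0]);
  return 0;
}
```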
@@ -397,7 +412,7 @@ void NetworkGraph::backwarding(
   int iteration,
   std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op,
   std::function<void(Weight &, int)> &apply_grad_clip_op,
-  std::function<bool(void *userdata)> stop_cb, void *userdata) const {
+  std::function<bool(void *userdata)> stop_cb, void *userdata) {
   /**
    * last layer backwarding is run out of this loop
    */
@@ -426,6 +441,60 @@ void NetworkGraph::backwarding(
   if (clip_weights.empty())
     return;

+  /**
+   * Mixed precision training needs gradient clipping and a loss scale,
+   * because all weights are updated with the clipping option. The loss
+   * scale also helps to avoid unexpected training results.
+   */
+  auto update_loss_scale = [&](float scale) {
+    ml_logd("set loss scale = %f", scale);
+    for (auto iter = cbegin(); iter != cend(); iter++) {
+      auto &ln = *iter;
+      ln->setLossScale(scale);
+    }
+    loss_scale = scale;
+  };
+
+  auto check_weights = [](std::vector<Weight *> &weights) {
+    bool valid = true;
+    for (auto &w : weights) {
+      auto grad = w->getGradient();
+      if (grad.checkDataValidation(false) == false) {
+        grad.setZero();
+        valid = false;
+      }
+    }
+    return valid;
+  };
+
+  // Check that the first layer's derivative is valid.
+  // The loss scale is adjusted between 1.0f ~ 256.0f.
+  // @todo provide a max scale property
+  auto &ln = *(cbegin() + 1);
+  if (loss_scale != 0.0f && !ln->getRunContext().validateDerivatives()) {
Review comment on the line above: For performance, we should check for NaN in the derivatives right after they are computed, so we can save latency by handling it from that point within the same iteration; then we would not need to run forwarding again from the start. How about implementing it after Line 431?
+    // Do not apply the training result if the gradient data is invalid.
+    float scale = loss_scale > 1.5f ? loss_scale - 0.5f : 1.0f;
+    ml_logd(
+      "Derivative validation failed. Skip applying gradient. loss_scale(%f)",
+      scale);
+    check_weights(clip_weights);
+    update_loss_scale(scale);
+    return;
+  } else {
+    for (unsigned int idx = 0; idx < clip_weights.size(); idx++) {
+      auto const &w = clip_weights[idx];
+      w->applyScaler(loss_scale);
+
+      if (!check_weights(clip_weights)) {
+        float scale = loss_scale > 1.5f ? loss_scale - 0.5f : 1.0f;
+        ml_loge("gradient validation failed. skip update. loss_scale(%f)",
+                scale);
+        update_loss_scale(scale);
+        return;
+      }
+    }
+  }
+
   /** calculate the global norm */
   Tensor global_norm_t(
     TensorDim({1u, 1u, 1u, (unsigned int)clip_weights.size()}));

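The check_weights() and applyScaler() pair above zeroes out gradients containing non-finite values and, presumably, divides the valid ones by the current loss scale before the clipped update. A standalone sketch of that overflow check, assuming gradients are plain float buffers rather than nntrainer Tensors:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Zero the buffer and return false if any gradient element is NaN or Inf,
// mirroring the idea behind check_weights() in the diff; otherwise divide by
// the loss scale (the assumed behaviour of applyScaler()).
static bool check_and_unscale(std::vector<float> &grad, float loss_scale) {
  for (float g : grad) {
    if (!std::isfinite(g)) {
      std::fill(grad.begin(), grad.end(), 0.0f);
      return false; // caller skips the update and lowers the loss scale
    }
  }
  for (float &g : grad)
    g /= loss_scale; // undo the loss scaling before the optimizer step
  return true;
}

int main() {
  std::vector<float> grad = {1.0f, INFINITY, 2.0f};
  if (!check_and_unscale(grad, 128.0f))
    std::printf("overflow detected, update skipped\n");
  return 0;
}
```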
@@ -434,6 +503,7 @@ void NetworkGraph::backwarding(
     auto const &w = clip_weights[idx];
     global_norm_data[idx] = w->getGradientNorm();
   }

   float global_norm = global_norm_t.l2norm();
   /** apply the gradient with the above global norm */
   for (auto w : clip_weights) {
@@ -443,6 +513,12 @@ void NetworkGraph::backwarding(
   for (auto w : clip_weights) {
     apply_grad_clip_op(*w, iteration);
   }

+  // update loss scale
+  if (loss_scale != 0.0f) {
+    float scale = loss_scale + 2.0f;
+    update_loss_scale(scale);
+  }
 }

 LayerNode *NetworkGraph::computeBackwardEnd() {
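Taken together, the hunks above implement an additive dynamic loss scaling schedule: shrink the scale by 0.5 when gradient or derivative validation fails, grow it by 2.0 after a successful step. A compact sketch of that schedule; the 256.0 upper clamp follows the comment in the diff, while only the 1.0 lower clamp is explicit in the shown hunks:

```cpp
#include <cstdio>

// Additive dynamic loss scaling as used in the diff: shrink by 0.5 on a
// failed (non-finite) step, grow by 2.0 on a successful one. The 256.0
// upper clamp is assumed from the diff's comment.
static float adjust_loss_scale(float scale, bool step_ok) {
  if (step_ok)
    scale += 2.0f;
  else
    scale = (scale > 1.5f) ? scale - 0.5f : 1.0f;
  return scale > 256.0f ? 256.0f : scale;
}

int main() {
  float scale = 128.0f;
  scale = adjust_loss_scale(scale, /*step_ok=*/false); // overflow -> 127.5
  scale = adjust_loss_scale(scale, /*step_ok=*/true);  // success  -> 129.5
  std::printf("loss scale = %f\n", scale);
  return 0;
}
```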
@@ -605,6 +681,14 @@ NetworkGraph::canExecuteInPlace(const std::shared_ptr<LayerNode> &lnode) {
     (lnode->getType() == LayerNormalizationLayer::type);
   };

+  /**
+   * If the layer's input and output type is not FP32, then it cannot be
+   * in-place. We assume that the input is always FP32.
+   */
+  if (lnode->getInputConnections().empty() &&
+      !istrequal(getTensorType()[2], "FP32"))
+    return InPlace::NONE;
+
   /**
    * @note Conditions to decide if this layer node can be in-place:
    * 1. if the layer is a no-op, then it can operate in-place as it is not
@@ -686,15 +770,6 @@ NetworkGraph::canExecuteInPlace(const std::shared_ptr<LayerNode> &lnode) {
     return InPlace::RESTRICTING;
   }

-  /**
-   * if the layer's input and output type is not FP32, then it cannot be
-   * inplace. We assume that the input is always FP32.
-   */
-  if (lnode->getInputConnections().empty()) {
-    if (!istrequal(getTensorType()[2], "FP32"))
-      return InPlace::NONE;
-  }
-
   return InPlace::NONE;
 }

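Both hunks above appear to relocate the same rule: an input layer whose activation tensor type (getTensorType()[2], presumably) is not FP32 must be ruled out before the other in-place rules can classify it. A minimal stand-in showing why the ordering matters (the enum values and helper are illustrative, not the real canExecuteInPlace()):

```cpp
#include <string>

// Illustrative stand-ins, not nntrainer's actual enum or member function.
enum class InPlace { NONE, RESTRICTING, NON_RESTRICTING };

// An input layer (no incoming connections) whose activation tensor type is
// not FP32 is excluded before any of the later in-place rules can apply,
// which is the point of moving the check to the top of the function.
static InPlace can_execute_in_place(bool has_input_connections,
                                    const std::string &activation_type) {
  if (!has_input_connections && activation_type != "FP32")
    return InPlace::NONE; // early exit, as in the relocated check
  // ... the remaining in-place rules of canExecuteInPlace() would follow ...
  return InPlace::RESTRICTING;
}

int main() {
  return can_execute_in_place(false, "FP16") == InPlace::NONE ? 0 : 1;
}
```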
@@ -876,7 +951,11 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
   lnode->configureRunContext(
     // TODO: update weights spec for trainable based on layer trainable prop
     tensor_manager->requestWeights(gnode, init_context.getWeightsSpec(),
-                                   lnode->getTrainable(), shared_weight_names),
+                                   lnode->getTrainable(), shared_weight_names,
+                                   init_context.getActivationDataType() !=
+                                       init_context.getWeightDataType()
+                                     ? init_context.getActivationDataType()
+                                     : TensorDim::DataType::NONE),
     inputs, outputs,
     tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(),
                                    lnode->getTrainable(), shared_tensor_names));
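The extra argument to requestWeights() appears to request the trainable tensor in the activation data type (e.g. FP16) whenever that differs from the weight data type, keeping the original full-precision copy as the master; TensorDim::DataType::NONE would then signal that no separate copy is needed. A hedged sketch of just that decision, with a placeholder enum rather than nntrainer's TensorDim::DataType:

```cpp
#include <cstdio>

// Placeholder for TensorDim::DataType; illustrative only.
enum class DataType { NONE, FP16, FP32 };

// If activations run in a lower precision than the (master) weights, request
// an extra trainable copy in the activation precision; otherwise return NONE
// so no separate copy is created. Mirrors the ternary passed to
// requestWeights() in the diff, under the assumption described above.
static DataType compute_copy_type(DataType activation, DataType weight) {
  return activation != weight ? activation : DataType::NONE;
}

int main() {
  DataType t = compute_copy_type(DataType::FP16, DataType::FP32);
  std::printf("need separate compute copy: %s\n",
              t == DataType::NONE ? "no" : "yes");
  return 0;
}
```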
@@ -1551,13 +1630,25 @@ void NetworkGraph::flushCacheExcept(unsigned int order) {
 void NetworkGraph::requestOptimizerVariable(
   std::function<std::vector<TensorDim>(const TensorDim &)> cb,
   bool request_only_trainable) {
+  bool need_master = !istrequal(getTensorType()[1], getTensorType()[2]);
   for (auto const &w : tensor_manager->getWeights()) {
     if (w->isGradientLastAccess() && w->hasGradient()) {
       const TensorDim &dim = w->getDim();
       std::vector<TensorDim> dims = cb(dim);
       w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables(
         dims, w->getName(), TensorLifespan::MAX_LIFESPAN,
         w->isGradientClipByGlobalNorm(), Tensor::Initializer::ZEROS));
+      if (need_master) {
+        for (auto &dim : dims)
+          dim.setDataType(
+            str_converter<enum_class_prop_tag, nntrainer::TensorDataTypeInfo>::
+              from_string(getTensorType()[1]));
+        w->setOptimizerMasterVariables(
+          tensor_manager->requestWeightOptimizerVariables(
+            dims, w->getName(), TensorLifespan::MAX_LIFESPAN,
+            w->isGradientClipByGlobalNorm(), Tensor::Initializer::ZEROS,
+            need_master));
+      }
     }
   }
 }
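Here the optimizer variables additionally get a master copy in the model weight type (getTensorType()[1], presumably FP32) whenever it differs from the activation type, so optimizer state such as Adam's moments is kept in full precision. A simplified sketch of that branching, with hypothetical request structures standing in for the tensor manager API:

```cpp
#include <string>
#include <vector>

// Hypothetical request record; nntrainer's tensor manager is not reproduced.
struct OptVarRequest {
  std::vector<int> dims;
  std::string dtype;
};

// Request optimizer variables in the compute precision, and additionally in
// the full-precision weight type ("master") when the two types differ.
static std::vector<OptVarRequest>
request_opt_vars(const std::vector<int> &dims, const std::string &weight_type,
                 const std::string &activation_type) {
  std::vector<OptVarRequest> reqs;
  reqs.push_back({dims, activation_type}); // regular optimizer variables
  if (weight_type != activation_type)
    reqs.push_back({dims, weight_type});   // FP32 master copies
  return reqs;
}

int main() {
  auto reqs = request_opt_vars({4, 4}, "FP32", "FP16");
  return reqs.size() == 2 ? 0 : 1; // master variables were requested
}
```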
Review comment: I wonder whether the gradient clipping property also has to be enabled in order to use mixed precision training. I suspect this PR does not consider the case where both mixed precision and gradient clipping are enabled.