From 57ed97ac19c053b5d64ec00a73f421c18fe7a649 Mon Sep 17 00:00:00 2001 From: Jiho Chu Date: Wed, 6 Mar 2024 09:58:18 +0900 Subject: [PATCH 1/8] [Property] Add loss scale property It add loss scale property as model common property. Signed-off-by: Jiho Chu --- nntrainer/models/model_common_properties.cpp | 2 + nntrainer/models/model_common_properties.h | 11 ++++ nntrainer/models/neuralnet.cpp | 8 ++- nntrainer/models/neuralnet.h | 67 ++++++++++---------- 4 files changed, 52 insertions(+), 36 deletions(-) diff --git a/nntrainer/models/model_common_properties.cpp b/nntrainer/models/model_common_properties.cpp index a1f560c49a..984cad662a 100644 --- a/nntrainer/models/model_common_properties.cpp +++ b/nntrainer/models/model_common_properties.cpp @@ -39,4 +39,6 @@ MemorySwapLookahead::MemorySwapLookahead(const unsigned int &value) { ModelTensorDataType::ModelTensorDataType(ModelTensorDataTypeInfo::Enum value) { set(value); } +LossScale::LossScale(float value) { set(value); } + } // namespace nntrainer::props diff --git a/nntrainer/models/model_common_properties.h b/nntrainer/models/model_common_properties.h index 791f9ed5d3..3776afefca 100644 --- a/nntrainer/models/model_common_properties.h +++ b/nntrainer/models/model_common_properties.h @@ -211,6 +211,17 @@ class ModelTensorDataType final : public EnumProperty { ModelTensorDataTypeInfo::Enum::W32A32); }; +/** + * @brief LossScale property, loss is scaled by this value + * + */ +class LossScale : public Property { +public: + LossScale(float value = 0.0f); + static constexpr const char *key = "loss_scale"; /**< unique key to access */ + using prop_tag = float_prop_tag; /**< property type */ +}; + } // namespace nntrainer::props #endif diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp index 348b3f48b1..bee6dd7a4b 100644 --- a/nntrainer/models/neuralnet.cpp +++ b/nntrainer/models/neuralnet.cpp @@ -70,7 +70,7 @@ NeuralNetwork::NeuralNetwork() : props::Epochs(), props::TrainingBatchSize(), props::SavePath(), props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(), props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead(), - props::TensorFormat(), props::ModelTensorDataType()), + props::TensorFormat(), props::ModelTensorDataType(), props::LossScale()), load_path(std::string()), epoch_idx(0), iter(0), @@ -88,7 +88,7 @@ NeuralNetwork::NeuralNetwork(AppContext app_context_) : props::Epochs(), props::TrainingBatchSize(), props::SavePath(), props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(), props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead(), - props::TensorFormat(), props::ModelTensorDataType()), + props::TensorFormat(), props::ModelTensorDataType(), props::LossScale()), load_path(std::string()), epoch_idx(0), iter(0), @@ -179,8 +179,9 @@ int NeuralNetwork::compile() { const std::string tensor_type = to_string(std::get(model_flex_props)); + const float loss_scale = std::get(model_flex_props); model_graph = NetworkGraph(memory_swap, memory_swap_path, lookahead, - tensor_format, tensor_type); + tensor_format, tensor_type, loss_scale); model_graph.setMemoryOptimizations( std::get(model_flex_props)); @@ -1017,6 +1018,7 @@ int NeuralNetwork::train_run( auto train_for_iteration = [this, stop_cb, stop_user_data](RunStats &stat, DataBuffer &buffer) { + ml_loge("train for iteration"); forwarding(true, stop_cb, stop_user_data); backwarding(iter++, stop_cb, stop_user_data); diff --git a/nntrainer/models/neuralnet.h b/nntrainer/models/neuralnet.h index 457b7d1e97..a2923ae8a7 100644 --- a/nntrainer/models/neuralnet.h +++ b/nntrainer/models/neuralnet.h @@ -221,10 +221,11 @@ class NeuralNetwork : public ml::train::Model { /** * @brief Forward Propagation of the neural network */ - sharedConstTensors forwarding(bool training = true, - std::function stop_cb = - [](void *user_data) { return false; }, - void *user_data = nullptr); + sharedConstTensors forwarding( + bool training = true, + std::function stop_cb = + [](void *user_data) { return false; }, + void *user_data = nullptr); /** * @brief Forward Propagation of the neural network @@ -239,12 +240,11 @@ class NeuralNetwork : public ml::train::Model { /** * @brief Incremental forward Propagation of the neural network */ - sharedConstTensors - incremental_forwarding(unsigned int from, unsigned int to, - bool training = true, - std::function stop_cb = - [](void *user_data) { return false; }, - void *user_data = nullptr); + sharedConstTensors incremental_forwarding( + unsigned int from, unsigned int to, bool training = true, + std::function stop_cb = + [](void *user_data) { return false; }, + void *user_data = nullptr); /** * @brief Incremental forward Propagation of the neural network @@ -261,10 +261,11 @@ class NeuralNetwork : public ml::train::Model { * @brief Backward Propagation of the neural network * @param[in] iteration Iteration Number for the optimizer */ - void backwarding(int iteration, - std::function stop_cb = - [](void *user_data) { return false; }, - void *user_data = nullptr); + void backwarding( + int iteration, + std::function stop_cb = + [](void *user_data) { return false; }, + void *user_data = nullptr); /** * @copydoc Model::save(const std::string &file_path, ml::train::ModelFormat @@ -329,13 +330,14 @@ class NeuralNetwork : public ml::train::Model { * @retval #ML_ERROR_NONE Successful. * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. */ - int train(const std::vector &values = {}, - std::function stop_cb = - [](void *stop_user_data) { return false; }, - void *stop_user_data = nullptr, - std::function epoch_complete_cb = - [](void *epoch_user_data) { return false; }, - void *epoch_user_data = nullptr) override; + int train( + const std::vector &values = {}, + std::function stop_cb = + [](void *stop_user_data) { return false; }, + void *stop_user_data = nullptr, + std::function epoch_complete_cb = + [](void *epoch_user_data) { return false; }, + void *epoch_user_data = nullptr) override; /** * @brief Run NeuralNetwork inference @@ -622,12 +624,11 @@ s * @retval shared_ptr const std::string file_path) override; private: - using FlexiblePropTypes = - std::tuple; + using FlexiblePropTypes = std::tuple< + props::Epochs, props::TrainingBatchSize, props::SavePath, + props::ContinueTrain, props::SaveBestPath, props::MemoryOptimization, + props::MemorySwap, props::MemorySwapPath, props::MemorySwapLookahead, + props::TensorFormat, props::ModelTensorDataType, props::LossScale>; using RigidPropTypes = std::tuple, std::vector, props::ClipGradByGlobalNorm>; @@ -709,12 +710,12 @@ s * @retval shared_ptr * @retval #ML_ERROR_NONE Successful. * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. */ - int train_run(std::function stop_cb = - [](void *) { return false; }, - void *user_data = nullptr, - std::function epoch_complete_cb = - [](void *) { return false; }, - void *data = nullptr); + int train_run( + std::function stop_cb = [](void *) { return false; }, + void *user_data = nullptr, + std::function epoch_complete_cb = + [](void *) { return false; }, + void *data = nullptr); /** * @brief Swap function for the class From 357a6013a0387e4955385156d8737a5fe6104ef7 Mon Sep 17 00:00:00 2001 From: Jiho Chu Date: Wed, 6 Mar 2024 10:00:01 +0900 Subject: [PATCH 2/8] [Graph] Use loss scale property It checks derivative validation after backwarding, and apply gradient if derivative validation success. Signed-off-by: Jiho Chu --- nntrainer/graph/network_graph.cpp | 100 ++++++++++++++++++++++++++---- nntrainer/graph/network_graph.h | 13 ++-- 2 files changed, 97 insertions(+), 16 deletions(-) diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp index 68f5dc6c72..23eaa3afff 100644 --- a/nntrainer/graph/network_graph.cpp +++ b/nntrainer/graph/network_graph.cpp @@ -6,6 +6,7 @@ * @date 19 Oct 2020 * @see https://github.com/nnstreamer/nntrainer * @author Jijoong Moon + * @author Jiho Chu * @bug No known bugs except for NYI items * @brief This is Network Graph Class for Neural Network * @@ -85,6 +86,14 @@ int NetworkGraph::compile(const std::string &loss_type) { status = checkCompiledGraph(); NN_RETURN_STATUS(); + /* @note It can be integrated with addLossLayer method + * if it removes adding loss layer to the model directly. + */ + for (auto iter = cbegin(); iter != cend(); iter++) { + auto &ln = *iter; + ln->setLossScale(loss_scale); + } + compiled = true; return status; @@ -353,10 +362,15 @@ sharedConstTensors NetworkGraph::forwarding( bool training, std::function, bool)> forwarding_op, std::function stop_cb, void *userdata) { + + for (auto w : clip_weights) { + w->applyMaster(); + } + for (auto iter = cbegin(); iter != cend() && !stop_cb(userdata); iter++) { auto &ln = *iter; PROFILE_TIME_START(profile_keys.at(ln->getType())); - forwarding_op(*iter, training); + forwarding_op(ln, training); PROFILE_TIME_END(profile_keys.at(ln->getType())); } @@ -397,7 +411,7 @@ void NetworkGraph::backwarding( int iteration, std::function, int)> &backwarding_op, std::function &apply_grad_clip_op, - std::function stop_cb, void *userdata) const { + std::function stop_cb, void *userdata) { /** * last layer backwarding is run out of this loop */ @@ -426,6 +440,46 @@ void NetworkGraph::backwarding( if (clip_weights.empty()) return; + /** + * mixed precision trainging needs gradient clipping and loss scale, + * cause all weights are updated with clipping option. + * also, loss scale makes to avoid unexpected training result. + */ + auto update_loss_scale = [&](float scale) { + ml_logd("set loss scale = %f", scale); + for (auto iter = cbegin(); iter != cend(); iter++) { + auto &ln = *iter; + ln->setLossScale(scale); + } + loss_scale = scale; + }; + + // check first layer's derivative is valid + // loss scale is adjusted between 1.0f ~ 256.0f + // @TODO provide max scale property + auto &ln = *(cbegin() + 1); + if (loss_scale != 0.0f && !ln->getRunContext().validateDerivatives()) { + // It will not apply train results if data is invalid + float scale = loss_scale > 1.5f ? loss_scale - 0.5f : 1.0f; + ml_logd( + "Derivative validation failed. Skip applying gradient. loss_scale(%f)", + scale); + update_loss_scale(scale); + return; + } else { + for (unsigned int idx = 0; idx < clip_weights.size(); idx++) { + auto const &w = clip_weights[idx]; + w->applyScaler(loss_scale); + if (w->getGradient().checkDataValidation(false) == false) { + float scale = loss_scale > 1.5f ? loss_scale - 0.5f : 1.0f; + ml_loge("gradient validation failed. skip update. loss_scale(%f)", + scale); + update_loss_scale(scale); + return; + } + } + } + /** calculate the global norm */ Tensor global_norm_t( TensorDim({1u, 1u, 1u, (unsigned int)clip_weights.size()})); @@ -434,6 +488,7 @@ void NetworkGraph::backwarding( auto const &w = clip_weights[idx]; global_norm_data[idx] = w->getGradientNorm(); } + float global_norm = global_norm_t.l2norm(); /** apply the gradient with the above global norm */ for (auto w : clip_weights) { @@ -443,6 +498,12 @@ void NetworkGraph::backwarding( for (auto w : clip_weights) { apply_grad_clip_op(*w, iteration); } + + // update loss scale + if (loss_scale != 0.0f) { + float scale = loss_scale + 2.0f; + update_loss_scale(scale); + } } LayerNode *NetworkGraph::computeBackwardEnd() { @@ -605,6 +666,14 @@ NetworkGraph::canExecuteInPlace(const std::shared_ptr &lnode) { (lnode->getType() == LayerNormalizationLayer::type); }; + /** + * if the layer's input and output type is not FP32, then it cannot be + * inplace. We assume that the input is always FP32. + */ + if (lnode->getInputConnections().empty() && + !istrequal(getTensorType()[2], "FP32")) + return InPlace::NONE; + /** * @note Conditions to decide if this layer node can be in-place: * 1. if the layer is a no-op, then it can operate in-place as it is not @@ -686,15 +755,6 @@ NetworkGraph::canExecuteInPlace(const std::shared_ptr &lnode) { return InPlace::RESTRICTING; } - /** - * if the layer's input and output type is not FP32, then it cannot be - * inplace. We assume that the input is always FP32. - */ - if (lnode->getInputConnections().empty()) { - if (!istrequal(getTensorType()[2], "FP32")) - return InPlace::NONE; - } - return InPlace::NONE; } @@ -876,7 +936,11 @@ NetworkGraph::finalizeContext(const std::shared_ptr &lnode, lnode->configureRunContext( // TODO: update weights spec for trainable based on layer trainable prop tensor_manager->requestWeights(gnode, init_context.getWeightsSpec(), - lnode->getTrainable(), shared_weight_names), + lnode->getTrainable(), shared_weight_names, + init_context.getActivationDataType() != + init_context.getWeightDataType() + ? init_context.getActivationDataType() + : TensorDim::DataType::NONE), inputs, outputs, tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(), lnode->getTrainable(), shared_tensor_names)); @@ -1551,6 +1615,7 @@ void NetworkGraph::flushCacheExcept(unsigned int order) { void NetworkGraph::requestOptimizerVariable( std::function(const TensorDim &)> cb, bool request_only_trainable) { + bool need_master = !istrequal(getTensorType()[1], getTensorType()[2]); for (auto const &w : tensor_manager->getWeights()) { if (w->isGradientLastAccess() && w->hasGradient()) { const TensorDim &dim = w->getDim(); @@ -1558,6 +1623,17 @@ void NetworkGraph::requestOptimizerVariable( w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables( dims, w->getName(), TensorLifespan::MAX_LIFESPAN, w->isGradientClipByGlobalNorm(), Tensor::Initializer::ZEROS)); + if (need_master) { + for (auto &dim : dims) + dim.setDataType( + str_converter:: + from_string(getTensorType()[1])); + w->setOptimizerMasterVariables( + tensor_manager->requestWeightOptimizerVariables( + dims, w->getName(), TensorLifespan::MAX_LIFESPAN, + w->isGradientClipByGlobalNorm(), Tensor::Initializer::ZEROS, + need_master)); + } } } } diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h index 5c9adf0363..c209d3a65a 100644 --- a/nntrainer/graph/network_graph.h +++ b/nntrainer/graph/network_graph.h @@ -51,7 +51,8 @@ class NetworkGraph { optimize_memory(true), exec_mode(ExecutionMode::TRAIN), tensor_format("NCHW"), - tensor_dtype(split("FP32-FP32", getRegex("\\-"))) {} + tensor_dtype(split("FP32-FP32", getRegex("\\-"))), + loss_scale(0.0f) {} /** * @brief Constructor of NeuralNetwork Graph Class @@ -61,7 +62,8 @@ class NetworkGraph { NetworkGraph(bool enable_swap, const std::string &swap_path = "", unsigned int lookahead = 0, const std::string &tensor_format_ = "NCHW", - const std::string &tensor_dtype_ = "FP32-FP32") : + const std::string &tensor_dtype_ = "FP32-FP32", + const float scale = 0.0f) : tensor_manager(std::make_shared(enable_swap, swap_path, lookahead, tensor_format_, tensor_dtype_)), graph(), @@ -73,7 +75,8 @@ class NetworkGraph { optimize_memory(true), exec_mode(ExecutionMode::TRAIN), tensor_format(tensor_format_), - tensor_dtype(split(tensor_dtype_, getRegex("\\-"))) {} + tensor_dtype(split(tensor_dtype_, getRegex("\\-"))), + loss_scale(scale) {} /** * @brief Destructor of the NeuralNetwork Graph class @@ -212,7 +215,7 @@ class NetworkGraph { std::function &apply_grad_clip_op, std::function stop_cb = [](void *user_data) { return false; }, - void *user_data = nullptr) const; + void *user_data = nullptr); /** * @brief get begin iterator for the graph @@ -482,6 +485,8 @@ class NetworkGraph { std::vector clip_weights; /**< weights with global norm based clipping enabled */ + float loss_scale; /**< loss scale factor for the graph */ + /** * @brief topological sort * @param[in] ith index of LayerNode From 2fd9207e3c51f9e6cff337aec36b585da5e6c4db Mon Sep 17 00:00:00 2001 From: Jiho Chu Date: Wed, 6 Mar 2024 10:04:59 +0900 Subject: [PATCH 3/8] [Tensor] Add several methods for mixed precision clone method with tensor type is added for creating tensor with differenct datatype. And, some convenient methods for loss scale is added. Signed-off-by: Jiho Chu --- api/ccapi/include/tensor_dim.h | 7 ++-- nntrainer/tensor/manager.cpp | 70 +++++++++++++++++++++++--------- nntrainer/tensor/manager.h | 13 ++++-- nntrainer/tensor/memory_pool.cpp | 2 + nntrainer/tensor/tensor.cpp | 17 ++++++++ nntrainer/tensor/tensor.h | 55 +++++++++++++++++++++++++ nntrainer/tensor/weight.cpp | 17 +++++++- nntrainer/tensor/weight.h | 64 ++++++++++++++++++++++++++--- 8 files changed, 213 insertions(+), 32 deletions(-) diff --git a/api/ccapi/include/tensor_dim.h b/api/ccapi/include/tensor_dim.h index 64523618c1..7cded4806f 100644 --- a/api/ccapi/include/tensor_dim.h +++ b/api/ccapi/include/tensor_dim.h @@ -55,7 +55,8 @@ class TensorDim { QINT4, /** quantized int 4*/ QINT8, /** quantized int 8*/ FP16, /** half precision */ - FP32 /** single precision */ + FP32, /** single precision */ + NONE, /** not specified */ }; /** @@ -97,9 +98,7 @@ class TensorDim { */ TensorType(Format fm, DataType d_type, StorageOrder order = StorageOrder::ROW_MAJOR) : - format(fm), - data_type(d_type), - storage_order(order){}; + format(fm), data_type(d_type), storage_order(order){}; }; /** diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp index 4178330ebd..73d4a2bc30 100644 --- a/nntrainer/tensor/manager.cpp +++ b/nntrainer/tensor/manager.cpp @@ -9,10 +9,12 @@ * @see https://github.com/nnstreamer/nntrainer * @author Parichay Kapoor * @author Jihoon Lee + * @author Jiho Chu * @bug No known bugs except for NYI items * */ +#include "dataset.h" #ifdef __ANDROID__ #include #endif @@ -52,10 +54,7 @@ namespace nntrainer { MMapedMemory::MMapedMemory(size_t size, bool allocate_fd_) : - fd(-1), - buf(nullptr), - buf_size(0), - allocate_fd(allocate_fd_) { + fd(-1), buf(nullptr), buf_size(0), allocate_fd(allocate_fd_) { #ifndef __ANDROID__ if (allocate_fd) { @@ -148,13 +147,20 @@ void Manager::reinitialize() { } void Manager::allocateWeights(unsigned int max_exec_order_) { + if (!weight_master_pool.isAllocated()) { + finalizeTensorPool(weight_master_pool, 0, max_exec_order_); + weight_master_pool.allocate(); + } if (!weight_pool.isAllocated()) { finalizeTensorPool(weight_pool, 0, max_exec_order_); weight_pool.allocate(); } } -void Manager::deallocateWeights() { weight_pool.deallocate(); } +void Manager::deallocateWeights() { + weight_pool.deallocate(); + weight_master_pool.deallocate(); +} static Tensor *requestTensor_(const TensorSpecV2 &spec, const GraphNode::ExecutionOrder &exec_order, @@ -366,7 +372,8 @@ void Manager::initializeTensorsTrain(unsigned int max_exec_order_) { */ std::vector Manager::requestWeights( const GraphNode &node, const std::vector &weights_spec, - bool trainable, const std::vector &shared_names) { + bool trainable, const std::vector &shared_names, + TensorDim::DataType act_type) { const auto [forwarding_order, calcGradient_order, calcDerivative_order, applyGradient_order] = node.getExecutionOrder(); @@ -416,14 +423,23 @@ std::vector Manager::requestWeights( // var_exec_order.push_back(TensorPool::PERSIST_END_ORDER); } - Tensor *var = nullptr, *grad = nullptr; + Tensor *var = nullptr, *grad = nullptr, *var_m = nullptr; bool is_dependent = !shared_names.empty(); + TensorDim dim_a = dim; if (is_dependent) { /// shared_name is used and the orignal name is discarded const auto &shared_name = shared_names.at(i); /** case when shared names are given */ - var = weight_pool.requestOrExtend(shared_name, dim, var_exec_order, - var_ls, t_initializer); + if (act_type == TensorDim::DataType::NONE) { + var = weight_pool.requestOrExtend(shared_name, dim_a, var_exec_order, + var_ls, t_initializer); + } else { + dim_a.setDataType(act_type); + var = weight_pool.requestOrExtend(shared_name, dim_a, var_exec_order, + var_ls, t_initializer); + var_m = weight_master_pool.requestOrExtend( + shared_name, dim, var_exec_order, var_ls, t_initializer); + } if (trainable && need_gradient) { /** We cannot use the tensor schedulding for weight gradient if the @@ -431,13 +447,21 @@ std::vector Manager::requestWeights( * for each layer anymore and it is hard to overwritten. */ grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix, - dim, grad_exec_order, grad_ls, + dim_a, grad_exec_order, grad_ls, Tensor::Initializer::ZEROS); } } else { /** case requesting fresh weights */ - var = - weight_pool.request(name, dim, var_exec_order, var_ls, t_initializer); + if (act_type == TensorDim::DataType::NONE) { + var = weight_pool.request(name, dim_a, var_exec_order, var_ls, + t_initializer); + } else { + dim_a.setDataType(act_type); + var = weight_pool.request(name, dim_a, var_exec_order, var_ls, + t_initializer); + var_m = weight_master_pool.request(name, dim, var_exec_order, var_ls, + t_initializer); + } if (trainable && need_gradient) { /** is_wgrad is the index which is true when it is the gradient tensor @@ -447,14 +471,15 @@ std::vector Manager::requestWeights( bool is_wgrad = true; if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm)) is_wgrad = false; - grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim, + grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim_a, grad_exec_order, grad_ls, Tensor::Initializer::ZEROS, is_wgrad); } } - weights_v2.emplace_back(std::make_unique( - var, grad, w_reg, w_reg_const, decay, is_dependent, clip_by_global_norm)); + weights_v2.emplace_back( + std::make_unique(var, grad, w_reg, w_reg_const, decay, + is_dependent, clip_by_global_norm, 3, var_m)); } std::transform(weights_v2.begin() + current_size, weights_v2.end(), @@ -671,7 +696,7 @@ bool Manager::isSecondLastAccess(const std::string &name, std::vector Manager::requestWeightOptimizerVariables( const std::vector &dims, const std::string &name, const TensorLifespan &lifespan, bool is_grad_clip, - Tensor::Initializer initializer) { + Tensor::Initializer initializer, bool is_master) { std::vector ret; ret.reserve(dims.size()); @@ -686,9 +711,16 @@ std::vector Manager::requestWeightOptimizerVariables( /// @note this is assuming weight optimizer variables is treated as weight, if /// not, there is room to optimize below behavior - for (unsigned int idx = 0; idx < dims.size(); idx++) - ret.push_back(weight_pool.request(name + ":opt" + std::to_string(idx), - dims[idx], exec, lifespan, initializer)); + for (unsigned int idx = 0; idx < dims.size(); idx++) { + if (is_master) + ret.push_back( + weight_master_pool.request(name + ":opt" + std::to_string(idx), + dims[idx], exec, lifespan, initializer)); + else + ret.push_back(weight_pool.request(name + ":opt" + std::to_string(idx), + dims[idx], exec, lifespan, + initializer)); + } return ret; } diff --git a/nntrainer/tensor/manager.h b/nntrainer/tensor/manager.h index ab1c018153..2e910f93c2 100644 --- a/nntrainer/tensor/manager.h +++ b/nntrainer/tensor/manager.h @@ -145,6 +145,8 @@ class Manager { unsigned int lookahead = 0, const std::string tensor_format_ = "NCHW", const std::string tensor_dtype_ = "FP32-FP32") : weight_pool(enable_swap, swap_path, "weight_pool"), + /* @todo weight master does not support cache pool yet */ + weight_master_pool(enable_swap, swap_path, "weight_master_pool"), tensor_pool(enable_swap, swap_path, "tensor_pool"), enable_optimizations(true), swap_lookahead(lookahead), @@ -191,14 +193,16 @@ class Manager { * @param weights_spec Specification for the weights * @param trainable make the weight trainable if true * @param shared_names name to refer to when the weights are borrowed from the - * original source. if not shared pass empty vector + * @param act_type activation data type if data type is different from weight, + * otherwise NONE. original source. if not shared pass empty vector * * @return created weights list */ std::vector requestWeights(const GraphNode &node, const std::vector &weights_spec, bool trainable, - const std::vector &shared_names); + const std::vector &shared_names, + TensorDim::DataType act_type = TensorDim::DataType::NONE); /** * @brief Create tensors with the given spec @@ -225,7 +229,8 @@ class Manager { std::vector requestWeightOptimizerVariables( const std::vector &dims, const std::string &name, const TensorLifespan &lifespan, bool is_grad_clip, - Tensor::Initializer initializer = Tensor::Initializer::NONE); + Tensor::Initializer initializer = Tensor::Initializer::NONE, + bool is_master = false); /** * @brief Create tensors with the given spec @@ -509,6 +514,8 @@ class Manager { tensor_book; /**< reference to tensor book kept */ TensorPool weight_pool; /**< tensor pool to request tensors */ + TensorPool + weight_master_pool; /**< weight pool for store master float 32 weights */ TensorPool tensor_pool; /**< tensor pool to request tensors */ std::map> async_task_eos; diff --git a/nntrainer/tensor/memory_pool.cpp b/nntrainer/tensor/memory_pool.cpp index 782dce1fb8..55ce027f43 100644 --- a/nntrainer/tensor/memory_pool.cpp +++ b/nntrainer/tensor/memory_pool.cpp @@ -107,6 +107,8 @@ void MemoryPool::allocate() { msg.append(std::to_string(seq++)); PROFILE_MEM_ALLOC(mem_pool, pool_size, msg); #endif + + ml_loge("allocate memory: %zu", pool_size); } /** diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp index 4f1e8e0721..d01cd26378 100644 --- a/nntrainer/tensor/tensor.cpp +++ b/nntrainer/tensor/tensor.cpp @@ -1029,6 +1029,7 @@ Tensor &Tensor::add(Tensor const &m, Tensor &output, float const alpha) const { ele_add(e.buffer_size, buf, m_buf, out_buf, alpha, 0, e.strides[3], strides[3]); }; + apply_broadcast(m, f, output); #else throw std::invalid_argument("Error: enable-fp16 is not enabled"); @@ -3010,6 +3011,13 @@ void Tensor::copyData(const Tensor &from) { throw std::runtime_error("Cannot copy non-contiguous tensor"); } + if (size() == 0) { + TensorDim dim = from.getDim(); + dim.setDataType(getDataType()); + Tensor t = Tensor(dim, true); + swap(t, *this); + } + if (size() != from.size()) throw std::invalid_argument("Size of tensor to copy must match"); @@ -3065,6 +3073,15 @@ Tensor Tensor::clone() const { return t; } +Tensor Tensor::clone(ml::train::TensorDim::DataType type) const { + TensorDim dim = getDim(); + dim.setDataType(type); + Tensor t(dim, true); + t.copyData(*this); + t.name = name; + return t; +} + void Tensor::reshape(const TensorDim &d) { NNTR_THROW_IF(!contiguous, std::invalid_argument) diff --git a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h index 211334da40..bbf6370ae0 100644 --- a/nntrainer/tensor/tensor.h +++ b/nntrainer/tensor/tensor.h @@ -1680,6 +1680,12 @@ class Tensor { */ Tensor clone() const; + /** + * @brief Convient wrapper for inplace copy of @a this. + * @retval Copied version of this + */ + Tensor clone(ml::train::TensorDim::DataType type) const; + /** * @brief Save the Tensor into file * @param[in] file output file stream @@ -1887,6 +1893,7 @@ class Tensor { const std::array getStrides() const noexcept { return strides; } + /** * @brief Get linear index given the n-d index */ @@ -1923,6 +1930,54 @@ class Tensor { return continuous; } + /** + * @brief Check if data is valid number + */ + bool checkDataValidation(bool print = false) const { + bool ret = true; + + // It only support FP16 and FP32 + if (getDataType() != Tdatatype::FP32 && getDataType() != Tdatatype::FP16) + return true; + + std::string data; + for (unsigned int i = 0, len = size(); i < len; ++i) { + if (print) { + if (getDataType() == Tdatatype::FP32) { + data.append(std::to_string(static_cast(getData()[i])) + + ", "); + } +#ifdef ENABLE_FP16 + else if (getDataType() == Tdatatype::FP16) { + data.append(std::to_string(static_cast(getData<_FP16>()[i])) + + ", "); + } +#endif + } + + if (getDataType() == Tdatatype::FP32) { + if (!std::isfinite(getData()[i])) { + ret = false; + break; + } + } +#ifdef ENABLE_FP16 + else if (getDataType() == Tdatatype::FP16) { + if (!std::isfinite(static_cast(getData<_FP16>()[i]))) { + ret = false; + break; + } + } +#endif + } + + if (print && !ret) { + ml_loge("Tensor[%s]: %s", getName().c_str(), data.c_str()); + } + + return ret; + } + /** * @brief Get name of the tensor * diff --git a/nntrainer/tensor/weight.cpp b/nntrainer/tensor/weight.cpp index 44f1f015b1..19d54584d1 100644 --- a/nntrainer/tensor/weight.cpp +++ b/nntrainer/tensor/weight.cpp @@ -27,11 +27,26 @@ Weight::Weight(const TensorDim &dim, const Tensor::Initializer init, regularizer_constant(reg_const), decay(decay_const), clip_by_global_norm(max_norm), - output_axis(axis) { + output_axis(axis), + var_master(nullptr) { if (init == Tensor::Initializer::NONE) throw std::invalid_argument("Weight initializer cannot be none"); if (regularizer == WeightRegularizer::UNKNOWN) throw std::invalid_argument("Weight regularizer unknown"); } +void Weight::applyGradient(double lr) { + if (var_master.get()) { + Tensor grad_ = grad->clone(var_master->getDataType()); + var_master->add_i(grad_, -lr); + } else { + var->add_i(*grad.get(), -lr); + } +} + +void Weight::applyMaster() { + if (var_master.get()) + var->copyData(*var_master); +} + } // namespace nntrainer diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h index bd1651bd15..663fea138e 100644 --- a/nntrainer/tensor/weight.h +++ b/nntrainer/tensor/weight.h @@ -6,6 +6,7 @@ * @date 22 September 2020 * @see https://github.com/nnstreamer/nntrainer * @author Parichay Kapoor + * @author Jiho Chu * @bug No known bugs except for NYI items * @brief This is Weight Class for Neural Network * @@ -105,7 +106,8 @@ class Weight : public Var_Grad { regularizer_constant(1.0f), decay(0.0f), clip_by_global_norm(0.0f), - output_axis(output_axis_) {} + output_axis(output_axis_), + var_master(nullptr) {} /** * @brief Construct a new Weight object @@ -114,17 +116,19 @@ class Weight : public Var_Grad { * @param g ptr to already created gradient tensor * @param reg Regularizer for the weight * @param reg_const Constant multiplier for regularizer + * @param v_m ptr to already created variable master tensor */ explicit Weight(Tensor *v, Tensor *g, const WeightRegularizer reg, const float reg_const, const float decay, bool is_dependent = false, const float max_norm = 0.0f, - unsigned int output_axis_ = 3) : + unsigned int output_axis_ = 3, Tensor *v_m = nullptr) : Var_Grad(v, g, is_dependent), regularizer(reg), regularizer_constant(reg_const), decay(decay), clip_by_global_norm(max_norm), - output_axis(output_axis_) {} + output_axis(output_axis_), + var_master(std::shared_ptr(v_m, [](void *) {})) {} /** * @brief Swap for weight @@ -141,7 +145,9 @@ class Weight : public Var_Grad { swap(lhs.decay, rhs.decay); swap(lhs.clip_by_global_norm, rhs.clip_by_global_norm); swap(lhs.output_axis, rhs.output_axis); + swap(lhs.var_master, rhs.var_master); swap(lhs.opt_vars, rhs.opt_vars); + swap(lhs.opt_master_vars, rhs.opt_master_vars); } /** @@ -194,14 +200,27 @@ class Weight : public Var_Grad { */ void clearOptimizerVariables() { opt_vars.clear(); } + /** + * @brief Clear optimizer variables + */ + void clearOptimizerMasterVariables() { opt_master_vars.clear(); } + /** * @brief Add optimizer variables - * @param dim Optimizer variable dimension + * @param tensors Optimizer variable */ void setOptimizerVariables(std::vector tensors) { opt_vars = tensors; } + /** + * @brief Add optimizer master variables + * @param tensors Optimizer master variable + */ + void setOptimizerMasterVariables(std::vector tensors) { + opt_master_vars = tensors; + } + /** * @brief Get optimizer variable reference * @param idx Index of the optimizer variable to get @@ -209,12 +228,27 @@ class Weight : public Var_Grad { */ Tensor &getOptimizerVariableRef(unsigned int idx) { return *opt_vars[idx]; } + /** + * @brief Get optimizer variable reference + * @param idx Index of the optimizer variable to get + * @retval Reference of the optimizer variable + */ + Tensor &getOptimizerMasterVariableRef(unsigned int idx) { + return *opt_master_vars[idx]; + } + /** * @brief Get number of optimizer variable * @retval number of optimizer variable */ int getNumOptVariable() { return opt_vars.size(); } + /** + * @brief Get number of optimizer variable + * @retval number of optimizer variable + */ + int getNumOptMasterVariable() { return opt_master_vars.size(); } + /** * @brief Get axis of Weight * @retval axis of Wegiht @@ -245,6 +279,11 @@ class Weight : public Var_Grad { return 0; } + /** + * @brief Apply scaler for gradient of the weight + */ + void applyScaler(float scale = 1.0f) { grad->divide_i(scale); } + /** * @brief Calculate gradient from the regularization of the weight */ @@ -264,7 +303,19 @@ class Weight : public Var_Grad { /** * @brief Apply the gradient to the weight */ - void applyGradient(double lr) { var->add_i(*grad.get(), -lr); } + void applyGradient(double lr); + + /** + * @brief Apply the master weight to the weight + */ + void applyMaster(); + + /** + * @brief Get the variable tensor + * + * @return Tensor Variable tensor + */ + Tensor *getVariableMasterRef() const { return var_master.get(); } /** * @brief Check if the gradient is supposed to be clipped by global norm with @@ -308,7 +359,10 @@ class Weight : public Var_Grad { float decay; /**< constant factor for the weight decay */ float clip_by_global_norm; /**< constant factor to clip gradient by L2 norm */ unsigned int output_axis; + std::shared_ptr + var_master; /**< variable master tensor for mixed tensor types */ std::vector opt_vars; /**< optimizer variables */ + std::vector opt_master_vars; /**< optimizer master variables */ /** * @brief Apply the weight decay to the weight From 4f0447cc5da58891d71b1a8bd1b3d943d93f6ed8 Mon Sep 17 00:00:00 2001 From: Jiho Chu Date: Wed, 6 Mar 2024 11:12:02 +0900 Subject: [PATCH 4/8] [Test] Add conv2d test for fp16 It adds tests for conv2d fp16 test. Signed-off-by: Jiho Chu --- nntrainer/optimizers/adam.cpp | 57 ++++-- nntrainer/optimizers/optimizer_context.cpp | 15 ++ nntrainer/optimizers/optimizer_context.h | 15 ++ .../layers/unittest_layers_convolution2d.cpp | 182 ++++++++++++++++++ 4 files changed, 257 insertions(+), 12 deletions(-) diff --git a/nntrainer/optimizers/adam.cpp b/nntrainer/optimizers/adam.cpp index 18c0a0fcc1..d41e51640f 100644 --- a/nntrainer/optimizers/adam.cpp +++ b/nntrainer/optimizers/adam.cpp @@ -79,26 +79,59 @@ void Adam::applyGradient(RunOptimizerContext &context) { Tensor &wm = context.getOptimizerVariable(AdamParams::wm); Tensor &wv = context.getOptimizerVariable(AdamParams::wv); - wm.multiply_i(beta1); - wm.add_i(x_grad, 1.0f - beta1); + if (context.getNumOptMasterVariable() != 0) { + Tensor &wm_m = context.getOptimizerMasterVariable(AdamParams::wm); + Tensor &wv_m = context.getOptimizerMasterVariable(AdamParams::wv); + Tensor x_grad_ = x_grad.clone(wm_m.getDataType()); - wv.multiply_i(beta2); - wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2); + wm_m.multiply_i(beta1); + wm_m.add_i(x_grad_, 1.0f - beta1); + + wv_m.multiply_i(beta2); + wv_m.add_i(x_grad_.multiply(x_grad_), 1.0f - beta2); + + wm.copyData(wm_m); + wv.copyData(wv_m); + } else { + wm.multiply_i(beta1); + wm.add_i(x_grad, 1.0f - beta1); + + wv.multiply_i(beta2); + wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2); + } if (torch_ref) { - Tensor denom = wv.apply(sqrtFloat); - denom.divide_i(sqrtFloat(biasCorrection2)); - denom.add_i(epsilon); - wm.divide(denom, x_grad); + if (x_grad.getDataType() == ml::train::TensorDim::DataType::FP32) { + Tensor denom = wv.apply(sqrtFloat); + denom.divide_i(sqrtFloat(biasCorrection2)); + denom.add_i(epsilon); + wm.divide(denom, x_grad); +#ifdef ENABLE_FP16 + } else if (x_grad.getDataType() == ml::train::TensorDim::DataType::FP16) { + Tensor denom = wv.apply<_FP16>(sqrtFloat<_FP16>); + denom.divide_i(sqrtFloat(biasCorrection2)); + denom.add_i(epsilon); + wm.divide(denom, x_grad); +#endif + } else { + throw std::runtime_error("Not supported datatype"); + } context.applyGradient(context.getLearningRate() / biasCorrection1); - } else { - std::function sqrtEps = [epsilon](double f) { - return 1 / (sqrtDouble(f) + epsilon); + auto sqrtEps = [epsilon](T f) -> T { + return 1 / (static_cast(sqrtDouble(f)) + static_cast(epsilon)); }; - x_grad = wv.apply(sqrtEps, x_grad); + if (x_grad.getDataType() == ml::train::TensorDim::DataType::FP32) + x_grad = wv.apply(sqrtEps, x_grad); +#ifdef ENABLE_FP16 + else if (x_grad.getDataType() == ml::train::TensorDim::DataType::FP16) + x_grad = wv.apply<_FP16>(sqrtEps, x_grad); +#endif + else + throw std::runtime_error("Not supported datatype"); + x_grad.multiply_i(wm); context.applyGradient(getUpdatedLearningRate(context.getIteration(), context.getLearningRate())); diff --git a/nntrainer/optimizers/optimizer_context.cpp b/nntrainer/optimizers/optimizer_context.cpp index da4cd1f7e9..5c282cdf38 100644 --- a/nntrainer/optimizers/optimizer_context.cpp +++ b/nntrainer/optimizers/optimizer_context.cpp @@ -36,6 +36,21 @@ Tensor &RunOptimizerContext::getOptimizerVariable(unsigned int idx) const { return weight->getOptimizerVariableRef(idx); } +/** + * @brief Get the optimizer variable associated to this weight + */ +Tensor & +RunOptimizerContext::getOptimizerMasterVariable(unsigned int idx) const { + return weight->getOptimizerMasterVariableRef(idx); +} + +/** + * @brief Get number of optimizer master variable + */ +int RunOptimizerContext::getNumOptMasterVariable() { + return weight->getNumOptMasterVariable(); +} + /** * @brief Apply the gradient with the given learning rate */ diff --git a/nntrainer/optimizers/optimizer_context.h b/nntrainer/optimizers/optimizer_context.h index 62f9e0945d..ea4980ba06 100644 --- a/nntrainer/optimizers/optimizer_context.h +++ b/nntrainer/optimizers/optimizer_context.h @@ -61,6 +61,21 @@ class RunOptimizerContext { */ Tensor &getOptimizerVariable(unsigned int idx) const; + /** + * @brief Get the optimizer Master variable associated to this weight + * + * @param idx Identifier of the associated weight + * @return Tensor& Reference to the optimizer variable + */ + Tensor &getOptimizerMasterVariable(unsigned int idx) const; + + /** + * @brief Get number of the optimizer Master variable + * + * @return number of optimizer master variable + */ + int getNumOptMasterVariable(); + /** * @brief Check if run context is set and is ready to use * diff --git a/test/unittest/layers/unittest_layers_convolution2d.cpp b/test/unittest/layers/unittest_layers_convolution2d.cpp index 724c79079b..92d9c593e7 100644 --- a/test/unittest/layers/unittest_layers_convolution2d.cpp +++ b/test/unittest/layers/unittest_layers_convolution2d.cpp @@ -198,3 +198,185 @@ GTEST_PARAMETER_TEST( conv2d_mb_valid_drop_last, conv2d_sb_no_overlap, conv2d_mb_no_overlap, conv2d_sb_1x1_kernel, conv2d_mb_1x1_kernel, conv2d_sb_dilation, conv2d_mb_dilation, conv2d_sb_same_dilation, conv2d_mb_same_dilation)); + +#ifdef ENABLE_FP16 +auto conv2d_sb_minimum_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=3", "kernel_size=2,2"}, "1:1:4:4", + "conv2d_sb_minimum_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_minimum_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=3", "kernel_size=2,2"}, "3:1:4:4", + "conv2d_mb_minimum_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_same_remain_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=2", "kernel_size=3,3", "padding=same"}, "1:1:4:4", + "conv2d_sb_same_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_same_remain_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=2", "kernel_size=3,3", "padding=same"}, "3:1:4:4", + "conv2d_mb_same_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_same_uneven_remain_1_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=same", + }, + "1:3:4:4", "conv2d_sb_same_uneven_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_same_uneven_remain_2_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=0,1,0,1", + }, + "1:3:4:4", "conv2d_sb_same_uneven_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_same_uneven_remain_1_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=same", + }, + "3:3:4:4", "conv2d_mb_same_uneven_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_same_uneven_remain_2_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=0,1,0,1", + }, + "3:3:4:4", "conv2d_mb_same_uneven_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_valid_drop_last_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=valid", + }, + "1:3:7:7", "conv2d_sb_valid_drop_last_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_valid_drop_last_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=valid", + }, + "3:3:7:7", "conv2d_mb_valid_drop_last_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_no_overlap_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=3", "kernel_size=2,2", "stride=3,3"}, "1:2:5:5", + "conv2d_sb_no_overlap_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_no_overlap_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=3", + "kernel_size=2,2", + "stride=3,3", + }, + "3:2:5:5", "conv2d_mb_no_overlap_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_1x1_kernel_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=3", "kernel_size=1,1", "stride=2,2"}, "1:2:5:5", + "conv2d_sb_1x1_kernel_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_1x1_kernel_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=3", + "kernel_size=1,1", + "stride=2,2", + }, + "3:2:5:5", "conv2d_mb_1x1_kernel_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_dilation_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "dilation=2,2", + }, + "1:3:11:11", "conv2d_sb_dilation_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_dilation_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "dilation=2,2", + }, + "3:3:11:11", "conv2d_mb_dilation_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_same_dilation_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "padding=same", + "dilation=2,2", + }, + "1:3:11:11", "conv2d_sb_same_dilation_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_same_dilation_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "padding=same", + "dilation=2,2", + }, + "3:3:11:11", "conv2d_mb_same_dilation_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +GTEST_PARAMETER_TEST( + Convolution2D16, LayerGoldenTest, + ::testing::Values(conv2d_sb_minimum_w16a16, conv2d_mb_minimum_w16a16, + conv2d_sb_same_remain_w16a16, conv2d_mb_same_remain_w16a16, + conv2d_sb_same_uneven_remain_1_w16a16, + conv2d_sb_same_uneven_remain_2_w16a16, + conv2d_mb_same_uneven_remain_1_w16a16, + conv2d_mb_same_uneven_remain_2_w16a16, + conv2d_sb_valid_drop_last_w16a16, + conv2d_mb_valid_drop_last_w16a16, + conv2d_sb_no_overlap_w16a16, conv2d_mb_no_overlap_w16a16, + conv2d_sb_1x1_kernel_w16a16, conv2d_mb_1x1_kernel_w16a16, + conv2d_sb_dilation_w16a16, conv2d_mb_dilation_w16a16, + conv2d_sb_same_dilation_w16a16, + conv2d_mb_same_dilation_w16a16)); +#endif From e0efd10b8dc6c7a4df0d4669d85d9e2cc1c85480 Mon Sep 17 00:00:00 2001 From: Jiho Chu Date: Thu, 7 Mar 2024 16:10:48 +0900 Subject: [PATCH 5/8] [Fix] fix doxygen comments It fixes doygen comments from clang format checker. Signed-off-by: Jiho Chu --- nntrainer/graph/network_graph.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp index 23eaa3afff..e98d6b7b17 100644 --- a/nntrainer/graph/network_graph.cpp +++ b/nntrainer/graph/network_graph.cpp @@ -86,7 +86,8 @@ int NetworkGraph::compile(const std::string &loss_type) { status = checkCompiledGraph(); NN_RETURN_STATUS(); - /* @note It can be integrated with addLossLayer method + /** + * @note It can be integrated with addLossLayer method * if it removes adding loss layer to the model directly. */ for (auto iter = cbegin(); iter != cend(); iter++) { @@ -456,7 +457,7 @@ void NetworkGraph::backwarding( // check first layer's derivative is valid // loss scale is adjusted between 1.0f ~ 256.0f - // @TODO provide max scale property + // @todo provide max scale property auto &ln = *(cbegin() + 1); if (loss_scale != 0.0f && !ln->getRunContext().validateDerivatives()) { // It will not apply train results if data is invalid From 006c8283a2bbc6582575aef0ea9db695da34a8b2 Mon Sep 17 00:00:00 2001 From: Jiho Chu Date: Thu, 7 Mar 2024 20:17:03 +0900 Subject: [PATCH 6/8] [packaging] install loss_layer.h It installs loss_layer header file for custom loss layer. Signed-off-by: Jiho Chu --- debian/nntrainer-dev.install | 1 + packaging/nntrainer.spec | 1 + 2 files changed, 2 insertions(+) diff --git a/debian/nntrainer-dev.install b/debian/nntrainer-dev.install index 4fd55b3774..bd4c344dc4 100644 --- a/debian/nntrainer-dev.install +++ b/debian/nntrainer-dev.install @@ -23,6 +23,7 @@ /usr/include/nntrainer/layer_context.h /usr/include/nntrainer/layer_devel.h /usr/include/nntrainer/layer_impl.h +/usr/include/nntrainer/loss_layer.h # custom layer kits /usr/include/nntrainer/app_context.h # logger diff --git a/packaging/nntrainer.spec b/packaging/nntrainer.spec index 7cf6cd1493..1e1f3029c0 100644 --- a/packaging/nntrainer.spec +++ b/packaging/nntrainer.spec @@ -554,6 +554,7 @@ cp -r result %{buildroot}%{_datadir}/nntrainer/unittest/ %{_includedir}/nntrainer/layer_context.h %{_includedir}/nntrainer/layer_devel.h %{_includedir}/nntrainer/layer_impl.h +%{_includedir}/nntrainer/loss_layer.h # custom layer kits %{_includedir}/nntrainer/app_context.h # optimizer headers From 1ef1f72b47f967adede3ce31bb6ea44e700c52f7 Mon Sep 17 00:00:00 2001 From: Jiho Chu Date: Thu, 14 Mar 2024 19:31:50 +0900 Subject: [PATCH 7/8] [Layers] Modify layers for data type It is assumed that activations and weight are fully compotaible, so it's unnecessary to be converted to. input layer and loss layres are different, cause input data and label data is assumed to be always float 32 type now. Signed-off-by: Jiho Chu --- nntrainer/layers/bn_layer.cpp | 12 +- nntrainer/layers/conv2d_layer.cpp | 160 +++++---- nntrainer/layers/input_layer.cpp | 9 +- nntrainer/layers/layer_context.cpp | 51 +++ nntrainer/layers/layer_context.h | 38 +++ nntrainer/layers/layer_devel.h | 5 + nntrainer/layers/layer_node.cpp | 27 +- nntrainer/layers/layer_node.h | 5 + .../loss/cross_entropy_sigmoid_loss_layer.cpp | 3 + .../loss/cross_entropy_softmax_loss_layer.cpp | 36 +- nntrainer/layers/loss/loss_layer.cpp | 5 +- nntrainer/layers/loss/loss_layer.h | 21 ++ nntrainer/layers/loss/meson.build | 4 +- nntrainer/layers/loss/mse_loss_layer.cpp | 73 ++++- nntrainer/layers/lstm.cpp | 42 ++- nntrainer/layers/lstm.h | 1 - nntrainer/layers/pooling2d_layer.cpp | 310 +++++++++++------- nntrainer/layers/reshape_layer.cpp | 1 + 18 files changed, 584 insertions(+), 219 deletions(-) diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp index 1723ac677f..e978b1ef59 100644 --- a/nntrainer/layers/bn_layer.cpp +++ b/nntrainer/layers/bn_layer.cpp @@ -111,6 +111,12 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) { context.requestWeight(dim, bnparams_beta, WeightRegularizer::NONE, 1.0f, bias_decay, "beta", true); + /** + * @note declare weigth dimention with activation datatype + */ + TensorDim w_dim = dim; + w_dim.setDataType(in_dim.getDataType()); + /** * caches the deviation -> input - avg(input) * @todo check if avoiding this storage and adding dependency on input (no @@ -121,7 +127,7 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) { TensorLifespan::ITERATION_LIFESPAN); /** caches the inverse standard deviation */ wt_idx[BNParams::invstd] = - context.requestTensor(dim, "invstd", Tensor::Initializer::NONE, false, + context.requestTensor(w_dim, "invstd", Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN); /** * Temporary tensor to store the full sized tensors in order to allow batch @@ -136,13 +142,13 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) { * caches variance + epsilon as well. */ wt_idx[BNParams::cvar] = - context.requestTensor(dim, "cvar", Tensor::Initializer::NONE, false, + context.requestTensor(w_dim, "cvar", Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN); /** * Temporary tensor to store the reduced tensors along the axes_to_reduce. */ wt_idx[BNParams::t_reduced] = - context.requestTensor(dim, "tensor_reduced", Tensor::Initializer::NONE, + context.requestTensor(w_dim, "tensor_reduced", Tensor::Initializer::NONE, false, TensorLifespan::FORWARD_DERIV_LIFESPAN); } diff --git a/nntrainer/layers/conv2d_layer.cpp b/nntrainer/layers/conv2d_layer.cpp index c059ae9caf..5d9dbc1e19 100644 --- a/nntrainer/layers/conv2d_layer.cpp +++ b/nntrainer/layers/conv2d_layer.cpp @@ -38,7 +38,8 @@ namespace { static TensorDim calcCol2ImOutputDim(const TensorDim &out, const TensorDim &kdim) { - return TensorDim({kdim.getFeatureLen(), out.width() * out.height()}); + return TensorDim({kdim.getFeatureLen(), out.width() * out.height()}, + out.getTensorType()); } /** @@ -56,7 +57,10 @@ static void col2im(const Tensor &col_matrix, const TensorDim &kdim, const std::array &mstride, const std::array &dilation, Tensor &image) { - auto [pt, pb, pl, pr] = padding; + auto pt = padding[0]; + auto pb = padding[1]; + auto pl = padding[2]; + auto pr = padding[3]; unsigned k_height = kdim.height(); unsigned k_width = kdim.width(); @@ -84,32 +88,48 @@ static void col2im(const Tensor &col_matrix, const TensorDim &kdim, int h_stride_end = im_eff_height - eff_k_height - pt; int w_stride_end = im_eff_width - eff_k_width - pl; - unsigned col_w = 0; - for (int hs = -pt; hs <= h_stride_end; hs += hstride) { - for (int ws = -pl; ws <= w_stride_end; ws += wstride) { - unsigned col_h = 0; - int patch_height_end = hs + eff_k_height; - int patch_width_end = ws + eff_k_width; - for (unsigned c = 0; c < im_channel; c++) { - for (int h = hs; h < patch_height_end; h += hdilation) { - if (h < 0 || im_height <= h) { - col_h += k_width; - continue; - } - for (int w = ws; w < patch_width_end; w += wdilation) { - if (w < 0 || im_width <= w) { - col_h++; + auto apply_data = [&](T *val) { + unsigned col_w = 0; + for (int hs = -pt; hs <= h_stride_end; hs += hstride) { + for (int ws = -pl; ws <= w_stride_end; ws += wstride) { + unsigned col_h = 0; + int patch_height_end = hs + eff_k_height; + int patch_width_end = ws + eff_k_width; + for (unsigned c = 0; c < im_channel; c++) { + for (int h = hs; h < patch_height_end; h += hdilation) { + if (h < 0 || im_height <= h) { + col_h += k_width; continue; } - - float *val = image.getAddress(0, c, h, w); - *val += col_matrix.getValue(0, 0, col_h, col_w); - col_h++; + for (int w = ws; w < patch_width_end; w += wdilation) { + if (w < 0 || im_width <= w) { + col_h++; + continue; + } + + val = image.getAddress(0, c, h, w); + *val += col_matrix.getValue(0, 0, col_h, col_w); + col_h++; + } } } + col_w++; } - col_w++; } + }; + + if (image.getDataType() == nntrainer::Tdatatype::FP32) { + float val; + apply_data(&val); + } +#ifdef ENABLE_FP16 + else if (image.getDataType() == nntrainer::Tdatatype::FP16) { + _FP16 val; + apply_data(&val); + } +#endif + else { + throw std::runtime_error("Not supported datatype"); } } @@ -179,7 +199,10 @@ static void im2col(const Tensor &in, const TensorDim &kdim, // } */ - auto [pt, pb, pl, pr] = padding; + auto pt = padding[0]; + auto pb = padding[1]; + auto pl = padding[2]; + auto pr = padding[3]; unsigned int channel = in.channel(); int in_height = in.height(); @@ -198,46 +221,62 @@ static void im2col(const Tensor &in, const TensorDim &kdim, unsigned int out_width = (width - eff_k_width) / mstride[1] + 1; out.reshape( - TensorDim({out_height * out_width, in.channel() * k_height * k_width})); - float *out_data = out.getData(); - - int h_stride_end = height - eff_k_height - pt; - int w_stride_end = width - eff_k_width - pl; - - /// get a patch, size of kernel - /// hs is height_strided, ws is width_strided - unsigned int owidth = out.width(); - unsigned int base_im_w = 0; - for (int hs = -pt; hs <= h_stride_end; hs += mstride[0]) { - unsigned int base_im_h = 0; - int patch_height_end = eff_k_height + hs; - /// map the patch to a single line looping through channel - for (unsigned int c = 0; c < channel; ++c) { - for (int h = hs; h < patch_height_end; h += dilation[0]) { - if (h < 0 || in_height <= h) { - base_im_h += k_width; - continue; - } - - unsigned int im_w = base_im_w; - for (int ws = -pl; ws <= w_stride_end; ws += mstride[1]) { - unsigned int im_h = base_im_h; - int patch_width_end = eff_k_width + ws; + TensorDim({out_height * out_width, in.channel() * k_height * k_width}, + in.getTensorType())); + + auto apply_data = [&](T *out_data) { + int h_stride_end = height - eff_k_height - pt; + int w_stride_end = width - eff_k_width - pl; + + /// get a patch, size of kernel + /// hs is height_strided, ws is width_strided + unsigned int owidth = out.width(); + unsigned int base_im_w = 0; + for (int hs = -pt; hs <= h_stride_end; hs += mstride[0]) { + unsigned int base_im_h = 0; + int patch_height_end = eff_k_height + hs; + /// map the patch to a single line looping through channel + for (unsigned int c = 0; c < channel; ++c) { + for (int h = hs; h < patch_height_end; h += dilation[0]) { + if (h < 0 || in_height <= h) { + base_im_h += k_width; + continue; + } - for (int w = ws; w < patch_width_end; w += dilation[1]) { - if (w < 0 || in_width <= w) { + unsigned int im_w = base_im_w; + for (int ws = -pl; ws <= w_stride_end; ws += mstride[1]) { + unsigned int im_h = base_im_h; + int patch_width_end = eff_k_width + ws; + + for (int w = ws; w < patch_width_end; w += dilation[1]) { + if (w < 0 || in_width <= w) { + im_h++; + continue; + } + out_data[im_w * owidth + im_h] = in.getValue(0, c, h, w); im_h++; - continue; } - out_data[im_w * owidth + im_h] = in.getValue(0, c, h, w); - im_h++; + im_w++; } - im_w++; + base_im_h += k_width; } - base_im_h += k_width; } + base_im_w += out_width; } - base_im_w += out_width; + }; + + if (out.getDataType() == nntrainer::Tdatatype::FP32) { + float *out_data = out.getData(); + apply_data(out_data); + } +#ifdef ENABLE_FP16 + else if (out.getDataType() == nntrainer::Tdatatype::FP16) { + _FP16 *out_data = out.getData<_FP16>(); + apply_data(out_data); + } +#endif + else { + throw std::runtime_error("Not supported datatype"); } } @@ -279,9 +318,11 @@ void Conv2DLayer::finalize(InitLayerContext &context) { auto &dilation = std::get>(conv_props); - TensorDim kernel_dim = - TensorDim(filter_size, in_dim.channel(), kernel_size[0], kernel_size[1]); - TensorDim bias_dim = TensorDim(1, filter_size, 1, 1); + auto in_t_type = in_dim.getTensorType(); + in_t_type.data_type = context.getWeightDataType(); + TensorDim kernel_dim = TensorDim(filter_size, in_dim.channel(), + kernel_size[0], kernel_size[1], in_t_type); + TensorDim bias_dim = TensorDim(1, filter_size, 1, 1, in_t_type); padding = std::get(conv_props) .compute(in_dim, kernel_dim, {stride[0], stride[1]}, @@ -309,6 +350,7 @@ void Conv2DLayer::finalize(InitLayerContext &context) { out_dim.channel(filter_size); out_dim.height((eff_in_height - eff_k_height) / stride[0] + 1); out_dim.width((eff_in_width - eff_k_width) / stride[1] + 1); + out_dim.setTensorType(in_dim.getTensorType()); context.setOutputDimensions({out_dim}); NNTR_THROW_IF(eff_in_height < kernel_size[0] || eff_in_width < kernel_size[1], diff --git a/nntrainer/layers/input_layer.cpp b/nntrainer/layers/input_layer.cpp index eabd40b297..240a51cb9c 100644 --- a/nntrainer/layers/input_layer.cpp +++ b/nntrainer/layers/input_layer.cpp @@ -33,8 +33,7 @@ namespace nntrainer { static constexpr size_t SINGLE_INOUT_IDX = 0; InputLayer::InputLayer() : - Layer(), - input_props(props::Normalization(), props::Standardization()) {} + Layer(), input_props(props::Normalization(), props::Standardization()) {} void InputLayer::setProperty(const std::vector &values) { auto remain_props = loadProperties(values, input_props); @@ -47,7 +46,7 @@ void InputLayer::forwarding(RunLayerContext &context, bool training) { Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); if (!context.executeInPlace()) { Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - hidden_.copy(input_); + hidden_.copyData(input_); } if (std::get(input_props)) @@ -70,6 +69,10 @@ void InputLayer::finalize(InitLayerContext &context) { std::vector output_dims = context.getInputDimensions(); + NNTR_THROW_IF(output_dims.size() != 1, std::invalid_argument); + + output_dims[0].setTensorType( + {context.getFormat(), context.getActivationDataType()}); context.setOutputDimensions(output_dims); } diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp index 04bc576c38..359a306d36 100644 --- a/nntrainer/layers/layer_context.cpp +++ b/nntrainer/layers/layer_context.cpp @@ -153,6 +153,16 @@ Tensor &RunLayerContext::getWeight(unsigned int idx) const { return weights[idx]->getVariableRef(); } +/** + * @brief Get the Weight tensor object + * + * @param idx Identifier of the weight + * @return Tensor& Reference to the weight tensor + */ +Tensor *RunLayerContext::getWeightMaster(unsigned int idx) const { + return weights[idx]->getVariableMasterRef(); +} + /** * @brief Get the Weight Gradient tensor object * @@ -178,6 +188,18 @@ Tensor &RunLayerContext::getWeightOptVar(unsigned int idx, return weights[idx]->getOptimizerVariableRef(jdx); } +/** + * @brief Get the Weight Optimizer Variable tensor object + * + * @param idx Identifier of the weight + * @param jdx Identifier of the optimizer variables + * @return Tensor& Reference to the weight optimizer variable tensor + */ +Tensor &RunLayerContext::getWeightOptMasterVar(unsigned int idx, + unsigned int jdx) const { + return weights[idx]->getOptimizerMasterVariableRef(jdx); +} + /** * @brief Get the Number of Weight Optimizer Variable tensor object * @@ -188,6 +210,16 @@ unsigned int RunLayerContext::getNumWeightOptVar(unsigned int idx) const { return weights[idx]->getNumOptVariable(); } +/** + * @brief Get the Number of Weight Optimizer Variable tensor object + * + * @param idx Identifier of the weight + * @return int Number of the weight optimizer variable + */ +unsigned int RunLayerContext::getNumWeightOptMasterVar(unsigned int idx) const { + return weights[idx]->getNumOptMasterVariable(); +} + /** * @brief Get regularization loss for the weight * @@ -327,6 +359,25 @@ Tensor &RunLayerContext::getOutgoingDerivative(unsigned int idx) { return getInputGrad(idx); } +bool RunLayerContext::validateDerivatives() { + auto num_in = getNumInputs(); + auto num_out = getNumOutputs(); + + for (unsigned int i = 0; i < num_in; ++i) { + auto deriv = getIncomingDerivative(i); + if (deriv.checkDataValidation(false) == false) + return false; + } + + for (unsigned int i = 0; i < num_out; ++i) { + auto deriv = getOutgoingDerivative(i); + if (deriv.checkDataValidation(false) == false) + return false; + } + + return true; +} + /** * @brief Get the Tensor object * diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h index 3278cf0f24..e9bdb22ac1 100644 --- a/nntrainer/layers/layer_context.h +++ b/nntrainer/layers/layer_context.h @@ -443,6 +443,14 @@ class RunLayerContext { */ Tensor &getWeight(unsigned int idx) const; + /** + * @brief Get the Weight master tensor object + * + * @param idx Identifier of the weight + * @return Tensor& Reference to the weight tensor + */ + Tensor *getWeightMaster(unsigned int idx) const; + /** * @brief Get the Weight Gradient tensor object * @@ -461,6 +469,15 @@ class RunLayerContext { */ Tensor &getWeightOptVar(unsigned int idx, unsigned int jdx) const; + /** + * @brief Get the Weight Optimizer Master Variable tensor object + * + * @param idx Identifier of the weight + * @param jdx Identifier of the weight optimizer master variable + * @return Tensor& Reference to the weight optimizer tensor + */ + Tensor &getWeightOptMasterVar(unsigned int idx, unsigned int jdx) const; + /** * @brief Get the Weight name * @@ -571,6 +588,11 @@ class RunLayerContext { */ Tensor &getOutgoingDerivative(unsigned int idx); + /** + * @brief validate input/output derivatives of the layer + */ + bool validateDerivatives(); + /** * @brief Get the Tensor object * @@ -686,6 +708,14 @@ class RunLayerContext { */ unsigned int getNumWeightOptVar(unsigned int idx) const; + /** + * @brief Get the Number of Weight Optimizer Variable tensor object + * + * @param idx Identifier of the weight + * @return unsigned int Number of the weight optimizer variable + */ + unsigned int getNumWeightOptMasterVar(unsigned int idx) const; + /** * @brief Get the number of requested tensors objects * @@ -693,6 +723,14 @@ class RunLayerContext { */ unsigned int getNumTensors() const { return tensors.size(); } + /** + * @brief Set the Weight Optimizer Variable tensor object + * + * @param idx Identifier of the weight + * @param jdx Identifier of the weight optimizer variable + */ + void setWeightOptVars(unsigned int idx, std::vector opts); + /** * @brief Set the batch for the run context * diff --git a/nntrainer/layers/layer_devel.h b/nntrainer/layers/layer_devel.h index 54ce1a0ee9..44a87cc7e9 100644 --- a/nntrainer/layers/layer_devel.h +++ b/nntrainer/layers/layer_devel.h @@ -259,6 +259,11 @@ class Layer { * @return true if supports backwarding, else false */ virtual bool supportBackwarding() const = 0; + + /** + * @brief Set loss scale factor + */ + virtual void setLossScale(float scale) {} }; /// @todo Decide where to put and how to implement(#986) diff --git a/nntrainer/layers/layer_node.cpp b/nntrainer/layers/layer_node.cpp index a7c5f049e4..c5ede8b1c0 100644 --- a/nntrainer/layers/layer_node.cpp +++ b/nntrainer/layers/layer_node.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -464,8 +465,12 @@ void LayerNode::read(std::ifstream &file, bool opt_var) { for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) { if (run_context->isGradientLastAccess(i) && getTrainable()) { /// @note read optimizer variables + auto num_w_opt_m = run_context->getNumWeightOptMasterVar(i); for (unsigned int j = 0; j < run_context->getNumWeightOptVar(i); ++j) { - run_context->getWeightOptVar(i, j).read(file); + if (num_w_opt_m > 0) + run_context->getWeightOptMasterVar(i, j).read(file); + else + run_context->getWeightOptVar(i, j).read(file); } } } @@ -473,7 +478,11 @@ void LayerNode::read(std::ifstream &file, bool opt_var) { for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) { /// @note shared weights are only be read at the first acecss if (run_context->isGradientLastAccess(i)) { - run_context->getWeight(i).read(file); + auto w = run_context->getWeightMaster(i); + if (w) + w->read(file); + else + run_context->getWeight(i).read(file); } } } @@ -488,9 +497,13 @@ void LayerNode::save(std::ofstream &file, bool opt_var) const { if (run_context->isGradientLastAccess(i) && getTrainable()) { // @note save optimizer variables if (run_context->weightHasGradient(i)) { + auto num_w_opt_m = run_context->getNumWeightOptMasterVar(i); for (unsigned int j = 0; j < run_context->getNumWeightOptVar(i); ++j) { - run_context->getWeightOptVar(i, j).save(file); + if (num_w_opt_m > 0) + run_context->getWeightOptMasterVar(i, j).save(file); + else + run_context->getWeightOptVar(i, j).save(file); } } } @@ -499,7 +512,13 @@ void LayerNode::save(std::ofstream &file, bool opt_var) const { // @note shared weights are only be saved at the first access for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) { if (run_context->isGradientLastAccess(i)) { - run_context->getWeight(i).save(file); + if (run_context->getNumWeights()) { + auto w = run_context->getWeightMaster(i); + if (w) + w->save(file); + else + run_context->getWeight(i).save(file); + } } } } diff --git a/nntrainer/layers/layer_node.h b/nntrainer/layers/layer_node.h index c1068b0f56..d87581299f 100644 --- a/nntrainer/layers/layer_node.h +++ b/nntrainer/layers/layer_node.h @@ -897,6 +897,11 @@ class LayerNode final : public ml::train::Layer, public GraphNode { */ bool needsCalcGradient() { return needs_calc_gradient; } + /** + * @brief Set loss scale factor + */ + void setLossScale(float scale) { layer->setLossScale(scale); } + private: /** * @brief Get the Input Layers object diff --git a/nntrainer/layers/loss/cross_entropy_sigmoid_loss_layer.cpp b/nntrainer/layers/loss/cross_entropy_sigmoid_loss_layer.cpp index 60ea113418..feeff2b3d8 100644 --- a/nntrainer/layers/loss/cross_entropy_sigmoid_loss_layer.cpp +++ b/nntrainer/layers/loss/cross_entropy_sigmoid_loss_layer.cpp @@ -61,6 +61,9 @@ void CrossEntropySigmoidLossLayer::calcDerivative(RunLayerContext &context) { Tensor &y = context.getInput(SINGLE_INOUT_IDX); y.apply(ActiFunc::sigmoid, ret_derivative); + + applyLossScale(ret_derivative); + ret_derivative.subtract_i(y2); if (ret_derivative.divide_i(ret_derivative.size()) != ML_ERROR_NONE) { throw std::runtime_error("[CrossEntropySigmoidLossLayer::calcDerivative] " diff --git a/nntrainer/layers/loss/cross_entropy_softmax_loss_layer.cpp b/nntrainer/layers/loss/cross_entropy_softmax_loss_layer.cpp index 53854662ae..c181c60b9a 100644 --- a/nntrainer/layers/loss/cross_entropy_softmax_loss_layer.cpp +++ b/nntrainer/layers/loss/cross_entropy_softmax_loss_layer.cpp @@ -30,9 +30,14 @@ void CrossEntropySoftmaxLossLayer::forwarding(RunLayerContext &context, Tensor &y = context.getInput(SINGLE_INOUT_IDX); // fill the output - auto dataType = y.getDataType(); - if (dataType == ml::train::TensorDim::DataType::FP32) { - hidden_ = y.apply(ActiFunc::softmax, hidden_); + auto out_type = hidden_.getDataType(); + if (out_type == ml::train::TensorDim::DataType::FP32) { + if (y.getDataType() != out_type) { + Tensor y_ = y.clone(out_type); + hidden_ = y_.apply(ActiFunc::softmax, hidden_); + } else { + hidden_ = y.apply(ActiFunc::softmax, hidden_); + } if (context.isLabelAvailable(SINGLE_INOUT_IDX)) { Tensor &y2 = context.getLabel(SINGLE_INOUT_IDX); @@ -43,9 +48,14 @@ void CrossEntropySoftmaxLossLayer::forwarding(RunLayerContext &context, // update the loss value LossLayer::updateLoss(context, l); } - } else if (dataType == ml::train::TensorDim::DataType::FP16) { + } else if (out_type == ml::train::TensorDim::DataType::FP16) { #ifdef ENABLE_FP16 - hidden_ = y.apply(ActiFunc::softmax<_FP16>, hidden_); + if (y.getDataType() != out_type) { + Tensor y_ = y.clone(out_type); + hidden_ = y_.apply(ActiFunc::softmax<_FP16>, hidden_); + } else { + hidden_ = y.apply(ActiFunc::softmax<_FP16>, hidden_); + } if (context.isLabelAvailable(SINGLE_INOUT_IDX)) { Tensor &y2 = context.getLabel(SINGLE_INOUT_IDX); @@ -68,7 +78,8 @@ void CrossEntropySoftmaxLossLayer::calcDerivative(RunLayerContext &context) { Tensor &y = context.getInput(SINGLE_INOUT_IDX); auto dataType = y.getDataType(); - Tensor ret = Tensor("ret", y.getFormat(), y.getDataType()); + + Tensor ret(y.getDim()); if (dataType == ml::train::TensorDim::DataType::FP32) { y.apply(ActiFunc::softmax, ret); } else if (dataType == ml::train::TensorDim::DataType::FP16) { @@ -83,7 +94,18 @@ void CrossEntropySoftmaxLossLayer::calcDerivative(RunLayerContext &context) { /// operation // TODO: verify y and ret_derivative must not be same as loss layer is not // working in-place - ret.subtract(y2, ret_derivative); + if (ret.getDataType() != y2.getDataType()) { + ret.subtract(y2.clone(ret.getDataType()), ret_derivative); + } else { + ret.subtract(y2, ret_derivative); + } + + /** + * loss scale is applied for mixed precision + * every loss layers need to specify this applying code. + */ + applyLossScale(ret_derivative); + if (ret_derivative.divide_i(ret.batch()) != ML_ERROR_NONE) { throw std::runtime_error("[CrossEntropySoftmaxLossLayer::calcDerivative] " "Error when calculating loss"); diff --git a/nntrainer/layers/loss/loss_layer.cpp b/nntrainer/layers/loss/loss_layer.cpp index 40f74717f8..604e35b644 100644 --- a/nntrainer/layers/loss/loss_layer.cpp +++ b/nntrainer/layers/loss/loss_layer.cpp @@ -15,6 +15,9 @@ #include namespace nntrainer { + +LossLayer::LossLayer() : Layer(), loss_scale(0.0f) {} + void LossLayer::finalize(InitLayerContext &context) { std::vector input_dim = context.getInputDimensions(); std::vector output_dim = input_dim; @@ -22,7 +25,7 @@ void LossLayer::finalize(InitLayerContext &context) { d.setDataType( str_converter::from_string("FP32")); - + context.setOutputDimensions(output_dim); } diff --git a/nntrainer/layers/loss/loss_layer.h b/nntrainer/layers/loss/loss_layer.h index 00b520f6e6..c054df9c95 100644 --- a/nntrainer/layers/loss/loss_layer.h +++ b/nntrainer/layers/loss/loss_layer.h @@ -27,6 +27,11 @@ namespace nntrainer { */ class LossLayer : public Layer { public: + /** + * @brief Constructor of Loss Layer + */ + LossLayer(); + /** * @brief Destructor of Loss Layer */ @@ -47,11 +52,19 @@ class LossLayer : public Layer { */ virtual bool supportBackwarding() const override { return true; } + /** + * @brief Set loss scale factor + */ + virtual void setLossScale(float scale) override { loss_scale = scale; } + +private: /** * @copydoc Layer::requireLabel() */ bool requireLabel() const override { return true; } + float loss_scale; /**< loss scale factor */ + protected: /** * @brief update loss @@ -60,6 +73,14 @@ class LossLayer : public Layer { */ void updateLoss(RunLayerContext &context, const Tensor &l); + /** + * @brief apply loss scale + */ + void applyLossScale(Tensor &derivative) { + if (loss_scale != 0.0f) + derivative.multiply_i(loss_scale); + } + Tensor l; /**< loss tensor to store intermediate value to calculate loss value */ }; diff --git a/nntrainer/layers/loss/meson.build b/nntrainer/layers/loss/meson.build index 9fccd0290d..8ec9928101 100644 --- a/nntrainer/layers/loss/meson.build +++ b/nntrainer/layers/loss/meson.build @@ -7,7 +7,9 @@ loss_layer_sources = [ 'constant_derivative_loss_layer.cpp' ] -loss_layer_headers = [] +loss_layer_headers = [ + 'loss_layer.h' +] loss_layer_deps = [] diff --git a/nntrainer/layers/loss/mse_loss_layer.cpp b/nntrainer/layers/loss/mse_loss_layer.cpp index 7f7bd1626f..7350f568a4 100644 --- a/nntrainer/layers/loss/mse_loss_layer.cpp +++ b/nntrainer/layers/loss/mse_loss_layer.cpp @@ -11,6 +11,7 @@ * */ +#include "tensor.h" #include #include @@ -20,24 +21,42 @@ static constexpr size_t SINGLE_INOUT_IDX = 0; void MSELossLayer::forwarding(RunLayerContext &context, bool training) { Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); - Tensor &y = context.getInput(SINGLE_INOUT_IDX); + Tensor &y_ = context.getInput(SINGLE_INOUT_IDX); // hidden_ <- y2 - y; - if (context.isLabelAvailable(SINGLE_INOUT_IDX)) { - Tensor &y2 = context.getLabel(SINGLE_INOUT_IDX); - y2.subtract(y, hidden_); + auto out_type = hidden_.getDataType(); + if (out_type != y_.getDataType()) { + Tensor y = y_.clone(out_type); + if (context.isLabelAvailable(SINGLE_INOUT_IDX)) { + Tensor &y2 = context.getLabel(SINGLE_INOUT_IDX); + y2.subtract(y, hidden_); - /** calculate sum of squares normalized by size */ - float l2norm = hidden_.l2norm(); - l2norm *= l2norm / hidden_.size(); + /** calculate sum of squares normalized by size */ + float l2norm = hidden_.l2norm(); + l2norm *= l2norm / hidden_.size(); - /** wrap in tensor for update loss */ - Tensor l = Tensor(TensorDim(1, 1, 1, 1), &l2norm); - LossLayer::updateLoss(context, l); - } + /** wrap in tensor for update loss */ + Tensor l = Tensor(TensorDim(1, 1, 1, 1), &l2norm); + LossLayer::updateLoss(context, l); + } + // fill the output + hidden_.fill(y); + } else { + if (context.isLabelAvailable(SINGLE_INOUT_IDX)) { + Tensor &y2 = context.getLabel(SINGLE_INOUT_IDX); + y2.subtract(y_, hidden_); + + /** calculate sum of squares normalized by size */ + float l2norm = hidden_.l2norm(); + l2norm *= l2norm / hidden_.size(); - // fill the output - hidden_.fill(y); + /** wrap in tensor for update loss */ + Tensor l = Tensor(TensorDim(1, 1, 1, 1), &l2norm); + LossLayer::updateLoss(context, l); + } + // fill the output + hidden_.fill(y_); + } } void MSELossLayer::calcDerivative(RunLayerContext &context) { @@ -45,9 +64,33 @@ void MSELossLayer::calcDerivative(RunLayerContext &context) { const Tensor &y2 = context.getIncomingDerivative(SINGLE_INOUT_IDX); Tensor &y = context.getInput(SINGLE_INOUT_IDX); - y.subtract(y2, ret_derivative); + const auto &in_type = y.getDataType(); + if (in_type != y2.getDataType()) { + Tensor y2_ = y2.clone(in_type); + y.subtract(y2_, ret_derivative); + } else { + y.subtract(y2, ret_derivative); + } + + applyLossScale(ret_derivative); + float divider = ((float)y.size()) / 2; - if (ret_derivative.divide_i(divider) != ML_ERROR_NONE) { + + /** + * ret_derivative may be eliminated by big divider with fp16 calculation. + * So, it calcuated with larger precision. + */ + int ret; + if (ret_derivative.getDataType() != ml::train::TensorDim::DataType::FP32) { + Tensor ret_derivative_ = + ret_derivative.clone(ml::train::TensorDim::DataType::FP32); + ret = ret_derivative_.divide_i(divider); + ret_derivative.copyData(ret_derivative_); + } else { + ret = ret_derivative.divide_i(divider); + } + + if (ret != ML_ERROR_NONE) { throw std::runtime_error( "[MSELossLayer::calcDerivative] Error when calculating loss"); } diff --git a/nntrainer/layers/lstm.cpp b/nntrainer/layers/lstm.cpp index d5f13a1fc5..be313a0aca 100644 --- a/nntrainer/layers/lstm.cpp +++ b/nntrainer/layers/lstm.cpp @@ -509,21 +509,27 @@ void LSTMLayer::finalize(InitLayerContext &context) { } // hidden_state_dim : [ batch_size, 1, max_timestep, unit ] - const TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit, - weight_tensor_type); + TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit, + weight_tensor_type); + hidden_state_dim.setDataType(context.getActivationDataType()); + wt_idx[LSTMParams::hidden_state] = context.requestTensor( hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); // cell_state_dim : [ batch_size, 1, max_timestep, unit ] - const TensorDim cell_state_dim(batch_size, 1, max_timestep, unit, - weight_tensor_type); + TensorDim cell_state_dim(batch_size, 1, max_timestep, unit, + weight_tensor_type); + cell_state_dim.setDataType(context.getActivationDataType()); + wt_idx[LSTMParams::cell_state] = context.requestTensor( cell_state_dim, "cell_state", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); // ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ] - const TensorDim ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit, - weight_tensor_type); + TensorDim ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit, + weight_tensor_type); + ifgo_dim.setDataType(context.getActivationDataType()); + wt_idx[LSTMParams::ifgo] = context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); @@ -576,21 +582,27 @@ void LSTMLayer::finalize(InitLayerContext &context) { } // reverse_hidden_state_dim : [ batch_size, 1, max_timestep, unit ] - const TensorDim reverse_hidden_state_dim(batch_size, 1, max_timestep, unit, - weight_tensor_type); + TensorDim reverse_hidden_state_dim(batch_size, 1, max_timestep, unit, + weight_tensor_type); + reverse_hidden_state_dim.setDataType(context.getActivationDataType()); + wt_idx[LSTMParams::reverse_hidden_state] = context.requestTensor( reverse_hidden_state_dim, "reverse_hidden_state", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); // reverse_cell_state_dim : [ batch_size, 1, max_timestep, unit ] - const TensorDim reverse_cell_state_dim(batch_size, 1, max_timestep, unit, - weight_tensor_type); + TensorDim reverse_cell_state_dim(batch_size, 1, max_timestep, unit, + weight_tensor_type); + reverse_cell_state_dim.setDataType(context.getActivationDataType()); + wt_idx[LSTMParams::reverse_cell_state] = context.requestTensor( reverse_cell_state_dim, "reverse_cell_state", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); // reverse_ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ] - const TensorDim reverse_ifgo_dim(batch_size, 1, max_timestep, - NUM_GATE * unit, weight_tensor_type); + TensorDim reverse_ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit, + weight_tensor_type); + reverse_ifgo_dim.setDataType(context.getActivationDataType()); + wt_idx[LSTMParams::reverse_ifgo] = context.requestTensor( reverse_ifgo_dim, "reverse_ifgo", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); @@ -598,8 +610,10 @@ void LSTMLayer::finalize(InitLayerContext &context) { if (dropout_rate > epsilon) { // dropout_mask_dim = [ batch, 1, time_iteration, unit ] - const TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit, - weight_tensor_type); + TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit, + weight_tensor_type); + dropout_mask_dim.setDataType(context.getActivationDataType()); + wt_idx[LSTMParams::dropout_mask] = context.requestTensor( dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN); diff --git a/nntrainer/layers/lstm.h b/nntrainer/layers/lstm.h index f35fdf8815..a9b2cac7d7 100644 --- a/nntrainer/layers/lstm.h +++ b/nntrainer/layers/lstm.h @@ -99,7 +99,6 @@ class LSTMLayer : public LSTMCore { inline static const std::string type = "lstm"; -private: static constexpr unsigned int NUM_GATE = 4; /** common properties like Unit, IntegrateBias, HiddenStateActivation and diff --git a/nntrainer/layers/pooling2d_layer.cpp b/nntrainer/layers/pooling2d_layer.cpp index a68e42e8d0..b53ca354f2 100644 --- a/nntrainer/layers/pooling2d_layer.cpp +++ b/nntrainer/layers/pooling2d_layer.cpp @@ -6,6 +6,7 @@ * @date 12 June 2020 * @see https://github.com/nnstreamer/nntrainer * @author Jijoong Moon + * @author Jiho Chu * @bug No known bugs except for NYI items * @brief This is 2 Dimensional Pooling Layer Class for Neural Network * @@ -26,6 +27,13 @@ namespace nntrainer { static constexpr size_t SINGLE_INOUT_IDX = 0; +/** + * @brief help function for Pooling handler + */ +template struct PoolFunc { + typedef std::function Type; +}; + Pooling2DLayer::Pooling2DLayer( const std::array &padding_) : Layer(), @@ -96,6 +104,7 @@ void Pooling2DLayer::finalize(InitLayerContext &context) { out_dim.channel(in_dim.channel()); out_dim.height((eff_in_height - pool_size[0]) / stride[0] + 1); out_dim.width((eff_in_width - pool_size[1]) / stride[1] + 1); + out_dim.setDataType(in_dim.getDataType()); context.setOutputDimensions({out_dim}); /** @@ -111,13 +120,17 @@ void Pooling2DLayer::finalize(InitLayerContext &context) { * // clang-format on */ if (pooling_type == props::PoolingTypeInfo::Enum::global_max) { + auto helper_dim = in_dim; + helper_dim.setDataType(ml::train::TensorDim::DataType::FP32); pool_helper_idx = - context.requestTensor(in_dim, "helper_idx", Tensor::Initializer::NONE, + context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN); - pool_helper_size.resize(in_dim.batch() * in_dim.channel()); + pool_helper_size.resize(helper_dim.batch() * helper_dim.channel()); } else { + auto helper_dim = out_dim; + helper_dim.setDataType(ml::train::TensorDim::DataType::FP32); pool_helper_idx = - context.requestTensor(out_dim, "helper_idx", Tensor::Initializer::NONE, + context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN); } } @@ -172,15 +185,13 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { unsigned int J, K; result.setZero(); - float *result_data = result.getData(); unsigned int out_map_size = deriv.height() * deriv.width(); unsigned int in_map_size = height * width; - switch (pooling_type) { - case props::PoolingTypeInfo::Enum::max: { + auto apply_max = [&](T *result_data) { const int *iter = pool_helper.getData(); - const float *deriv_data = deriv.getData(); + const T *deriv_data = deriv.getData(); for (unsigned int b = 0; b < batch; ++b) { for (unsigned int c = 0; c < channel; ++c) { for (unsigned int i = 0; i < out_map_size; ++i) { @@ -195,9 +206,9 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { result_data += in_map_size; } } - } break; - case props::PoolingTypeInfo::Enum::global_average: - case props::PoolingTypeInfo::Enum::average: { + }; + + auto apply_average = [&](T *result_data) { int height_stride_end = height - p_height + pt; int width_stride_end = width - p_width + pl; const int *iter = pool_helper.getData(); @@ -207,7 +218,7 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { for (int j = -pt; j <= height_stride_end; j += stride[0]) { K = 0; for (int k = -pl; k <= width_stride_end; k += stride[1]) { - float del = deriv.getValue(b, i, J, K) / *iter; + T del = deriv.getValue(b, i, J, K) / *iter; int patch_height_end = std::min(static_cast(j + p_height), height); int patch_width_end = @@ -217,7 +228,7 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { for (int h = start_h; h < patch_height_end; ++h) { for (int w = start_w; w < patch_width_end; ++w) { result.setValue(b, i, h, w, - result.getValue(b, i, h, w) + del); + result.getValue(b, i, h, w) + del); } } iter++; @@ -227,15 +238,16 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { } } } - } break; - case props::PoolingTypeInfo::Enum::global_max: { - const float *deriv_data = deriv.getData(); + }; + + auto apply_global_max = [&](T *result_data) { + const T *deriv_data = deriv.getData(); for (unsigned int b = 0; b < batch; b++) { for (unsigned int c = 0; c < channel; c++) { const int *iter = pool_helper.getData() + pool_helper.getIndex(b, c, 0, 0); unsigned int helper_size = pool_helper_size[b * channel + c]; - float der = *deriv_data / helper_size; + T der = *deriv_data / static_cast(helper_size); for (unsigned int idx = 0; idx < helper_size; idx++) result_data[iter[idx]] += der; @@ -244,7 +256,40 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { result_data += in_map_size; } } - } break; + }; + + switch (pooling_type) { + case props::PoolingTypeInfo::Enum::max: + if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP32) + apply_max(result.getData()); +#ifdef ENABLE_FP16 + else if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP16) + apply_max(result.getData<_FP16>()); +#endif + else + throw std::runtime_error("Not supported datatype"); + break; + case props::PoolingTypeInfo::Enum::global_average: + case props::PoolingTypeInfo::Enum::average: + if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP32) + apply_average(result.getData()); +#ifdef ENABLE_FP16 + else if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP16) + apply_average(result.getData<_FP16>()); +#endif + else + throw std::runtime_error("Not supported datatype"); + break; + case props::PoolingTypeInfo::Enum::global_max: + if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP32) + apply_global_max(result.getData()); +#ifdef ENABLE_FP16 + else if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP16) + apply_global_max(result.getData<_FP16>()); +#endif + else + throw std::runtime_error("Not supported datatype"); + break; default: throw std::runtime_error("Error: Unknown Pooling Type"); } @@ -290,124 +335,167 @@ void Pooling2DLayer::pooling2d(Tensor &in, bool training, Tensor &output, * @param start_w (width index pointing the start of the patch) * @return result value of pooling */ - std::function pool_fn; + PoolFunc::Type pool_fn_fp32; +#ifdef ENABLE_FP16 + PoolFunc<_FP16>::Type pool_fn_fp16; +#endif unsigned int max_idx_count = 0; - switch (pooling_type) { - case props::PoolingTypeInfo::Enum::max: { - pool_fn = [&](const float *in_data, int channel_idx, int start_h, - int start_w) { - int end_h = start_h + patch_height; - int end_w = start_w + patch_width; - - float max_val = std::numeric_limits::lowest(); - - int cur_max_idx = -1; - int eff_end_h = std::min(end_h, in_height); - int eff_end_w = std::min(end_w, in_width); - start_w = std::max(0, start_w); - for (int h = std::max(0, start_h); h < eff_end_h; ++h) { - for (int w = start_w; w < eff_end_w; ++w) { - int cur_idx = h * in_width + w; - float val = in_data[cur_idx]; - if (max_val < val) { - max_val = val; - if (training) { - cur_max_idx = cur_idx; - } + + auto pool_fn_max = [&](const T *in_data, int channel_idx, + int start_h, int start_w) { + int end_h = start_h + patch_height; + int end_w = start_w + patch_width; + + T max_val = std::numeric_limits::lowest(); + + int cur_max_idx = -1; + int eff_end_h = std::min(end_h, in_height); + int eff_end_w = std::min(end_w, in_width); + start_w = std::max(0, start_w); + for (int h = std::max(0, start_h); h < eff_end_h; ++h) { + for (int w = start_w; w < eff_end_w; ++w) { + int cur_idx = h * in_width + w; + T val = in_data[cur_idx]; + if (max_val < val) { + max_val = val; + if (training) { + cur_max_idx = cur_idx; } } } + } - if (training) { - pool_helper.setValueInt(max_idx_count++, cur_max_idx); - } + if (training) { + pool_helper.setValueInt(max_idx_count++, cur_max_idx); + } - return max_val; - }; - break; - } - case props::PoolingTypeInfo::Enum::global_max: { - pool_fn = [&, this](const float *in_data, int channel_idx, int start_h, - int start_w) { - int end_h = start_h + patch_height; - int end_w = start_w + patch_width; - - float max_val = std::numeric_limits::lowest(); - int *helper_data = pool_helper.getData(); - helper_data += channel_idx * in_height * in_width; - - for (int h = start_h; h < end_h; ++h) { - for (int w = start_w; w < end_w; ++w) { - int cur_idx = h * in_width + w; - float val = in_data[cur_idx]; - if (max_val < val) { - max_val = val; - max_idx_count = 0; - } + return max_val; + }; - if (training && max_val == val) { - *(helper_data + max_idx_count++) = cur_idx; - } + auto pool_fn_global_max = [&, this](const T *in_data, + int channel_idx, int start_h, + int start_w) { + int end_h = start_h + patch_height; + int end_w = start_w + patch_width; + + T max_val = std::numeric_limits::lowest(); + int *helper_data = pool_helper.getData(); + helper_data += channel_idx * in_height * in_width; + + for (int h = start_h; h < end_h; ++h) { + for (int w = start_w; w < end_w; ++w) { + int cur_idx = h * in_width + w; + T val = in_data[cur_idx]; + if (max_val < val) { + max_val = val; + max_idx_count = 0; } - } - pool_helper_size[batch_idx * in.channel() + channel_idx] = max_idx_count; - return max_val; - }; - break; - } - case props::PoolingTypeInfo::Enum::global_average: - case props::PoolingTypeInfo::Enum::average: { - pool_fn = [&](const float *in_data, int channel_idx, int start_h, - int start_w) { - int end_h = start_h + patch_height; - int end_w = start_w + patch_width; - float total = 0.0f; - - int eff_end_h = std::min(end_h, in_height); - int eff_end_w = std::min(end_w, in_width); - int eff_start_h = std::max(0, start_h); - int eff_start_w = std::max(0, start_w); - - int cnt = (eff_end_h - eff_start_h) * (eff_end_w - eff_start_w); - for (int h = eff_start_h; h < eff_end_h; ++h) { - for (int w = eff_start_w; w < eff_end_w; ++w) { - float val = in_data[h * in_width + w]; - total += val; + if (training && max_val == val) { + *(helper_data + max_idx_count++) = cur_idx; } } + } - if (training) { - pool_helper.setValueInt(max_idx_count++, cnt); + pool_helper_size[batch_idx * in.channel() + channel_idx] = max_idx_count; + return max_val; + }; + + auto pool_fn_average = [&](const T *in_data, int channel_idx, + int start_h, int start_w) { + int end_h = start_h + patch_height; + int end_w = start_w + patch_width; + T total = static_cast(0.0f); + + int eff_end_h = std::min(end_h, in_height); + int eff_end_w = std::min(end_w, in_width); + int eff_start_h = std::max(0, start_h); + int eff_start_w = std::max(0, start_w); + + int cnt = (eff_end_h - eff_start_h) * (eff_end_w - eff_start_w); + for (int h = eff_start_h; h < eff_end_h; ++h) { + for (int w = eff_start_w; w < eff_end_w; ++w) { + T val = in_data[h * in_width + w]; + total += val; } - return total / cnt; - }; + } + + if (training) { + pool_helper.setValueInt(max_idx_count++, cnt); + } + return total / cnt; + }; + + switch (pooling_type) { + case props::PoolingTypeInfo::Enum::max: + pool_fn_fp32 = pool_fn_max; +#ifdef ENABLE_FP16 + pool_fn_fp16 = pool_fn_max; +#endif + break; + case props::PoolingTypeInfo::Enum::global_max: + pool_fn_fp32 = pool_fn_global_max; +#ifdef ENABLE_FP16 + pool_fn_fp16 = pool_fn_global_max; +#endif + break; + case props::PoolingTypeInfo::Enum::global_average: + case props::PoolingTypeInfo::Enum::average: + pool_fn_fp32 = pool_fn_average; +#ifdef ENABLE_FP16 + pool_fn_fp16 = pool_fn_average; +#endif break; - } case props::PoolingTypeInfo::Enum::unknown: default: throw std::invalid_argument("unknown pooling type given"); break; } - const float *in_data = in.getData(); - float *out_data = output.getData(); - - unsigned int map_size = in_height * in_width; - - int height_stride_end = height - patch_height - pt; - int width_stride_end = width - patch_width - pl; - for (unsigned int i = 0; i < channel; ++i) { - const float *in_data_channel_sliced = in_data + i * map_size; - for (int j = -pt; j <= height_stride_end; j += stride[0]) { - for (int k = -pl; k <= width_stride_end; k += stride[1]) { - float pool_value = pool_fn(in_data_channel_sliced, i, j, k); - *out_data = pool_value; - out_data++; + if (in.getDataType() == ml::train::TensorDim::DataType::FP32) { + const float *in_data = in.getData(); + float *out_data = output.getData(); + + unsigned int map_size = in_height * in_width; + + int height_stride_end = height - patch_height - pt; + int width_stride_end = width - patch_width - pl; + for (unsigned int i = 0; i < channel; ++i) { + const float *in_data_channel_sliced = in_data + i * map_size; + for (int j = -pt; j <= height_stride_end; j += stride[0]) { + for (int k = -pl; k <= width_stride_end; k += stride[1]) { + float pool_value = pool_fn_fp32(in_data_channel_sliced, i, j, k); + *out_data = pool_value; + out_data++; + } + } + } + } +#ifdef ENABLE_FP16 + else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) { + const _FP16 *in_data = in.getData<_FP16>(); + _FP16 *out_data = output.getData<_FP16>(); + + unsigned int map_size = in_height * in_width; + + int height_stride_end = height - patch_height - pt; + int width_stride_end = width - patch_width - pl; + for (unsigned int i = 0; i < channel; ++i) { + const _FP16 *in_data_channel_sliced = in_data + i * map_size; + for (int j = -pt; j <= height_stride_end; j += stride[0]) { + for (int k = -pl; k <= width_stride_end; k += stride[1]) { + _FP16 pool_value = pool_fn_fp16(in_data_channel_sliced, i, j, k); + *out_data = pool_value; + out_data++; + } } } } +#endif + else { + throw std::runtime_error("Not supported datatype"); + } } void Pooling2DLayer::setBatch(RunLayerContext &context, unsigned int batch) { diff --git a/nntrainer/layers/reshape_layer.cpp b/nntrainer/layers/reshape_layer.cpp index 0f82d84f3a..07564b3970 100644 --- a/nntrainer/layers/reshape_layer.cpp +++ b/nntrainer/layers/reshape_layer.cpp @@ -42,6 +42,7 @@ void ReshapeLayer::finalize(InitLayerContext &context) { } out_dim.batch(in_dim.batch()); + out_dim.setDataType(in_dim.getDataType()); context.setOutputDimensions({out_dim}); } From 6e89fe65641918d8f32e3f218cff600eb92f8ea4 Mon Sep 17 00:00:00 2001 From: Jiho Chu Date: Tue, 19 Mar 2024 16:48:06 +0900 Subject: [PATCH 8/8] [Mixed] Reset for invalid values It may get an invalid value for both internal tensor or gradient. This patch checks the validation of the data, and fix for it. Also, sscal api is replace with scopy for setZero, because it produces the invalid value if invalid input value is used. Signed-off-by: Jiho Chu --- nntrainer/graph/network_graph.cpp | 16 +++++++++++++++- nntrainer/layers/bn_layer.cpp | 5 +++++ nntrainer/tensor/tensor.cpp | 9 ++++++--- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp index e98d6b7b17..77f0b2933c 100644 --- a/nntrainer/graph/network_graph.cpp +++ b/nntrainer/graph/network_graph.cpp @@ -455,6 +455,18 @@ void NetworkGraph::backwarding( loss_scale = scale; }; + auto check_weights = [](std::vector &weights) { + bool valid = true; + for (auto &w : weights) { + auto grad = w->getGradient(); + if (grad.checkDataValidation(false) == false) { + grad.setZero(); + valid = false; + } + } + return valid; + }; + // check first layer's derivative is valid // loss scale is adjusted between 1.0f ~ 256.0f // @todo provide max scale property @@ -465,13 +477,15 @@ void NetworkGraph::backwarding( ml_logd( "Derivative validation failed. Skip applying gradient. loss_scale(%f)", scale); + check_weights(clip_weights); update_loss_scale(scale); return; } else { for (unsigned int idx = 0; idx < clip_weights.size(); idx++) { auto const &w = clip_weights[idx]; w->applyScaler(loss_scale); - if (w->getGradient().checkDataValidation(false) == false) { + + if (!check_weights(clip_weights)) { float scale = loss_scale > 1.5f ? loss_scale - 0.5f : 1.0f; ml_loge("gradient validation failed. skip update. loss_scale(%f)", scale); diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp index e978b1ef59..3ca7628a3a 100644 --- a/nntrainer/layers/bn_layer.cpp +++ b/nntrainer/layers/bn_layer.cpp @@ -182,6 +182,11 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context, Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]); if (training) { + t_reduced.setZero(); + deviation.setZero(); + invstd.setZero(); + cvar.setZero(); + input_.average(axes_to_reduce, t_reduced); input_.subtract(t_reduced, deviation); diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp index d01cd26378..01cae9edef 100644 --- a/nntrainer/tensor/tensor.cpp +++ b/nntrainer/tensor/tensor.cpp @@ -3329,10 +3329,13 @@ void Tensor::setZero() { apply_i([](float val) -> float { return 0; }); } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) { #ifdef ENABLE_FP16 - if (contiguous) - sscal(size(), 0, getData<_FP16>(), 1); - else + if (contiguous) { + _FP16 zero = (_FP16)0.0f; + scopy(size(), &zero, 0, getData<_FP16>(), 1, + ml::train::TensorDim::DataType::FP16); + } else { apply_i<_FP16>([](_FP16 val) -> _FP16 { return 0; }); + } #else throw std::invalid_argument("Error: enable-fp16 is not enabled"); #endif