From cd17a66b8eec45a2e91e543a884722763b5e608e Mon Sep 17 00:00:00 2001 From: "jijoong.moon" Date: Mon, 11 Nov 2024 10:35:10 +0900 Subject: [PATCH] [ FSU ] Enabls Asynchronos FSU for forwarding This PR enables asynchronos mode for FSU (flash storage utilization) for better performance. It splits the load and unload tensors which make difficult to handle. Also fix the inference execution order when it is in INFERENCE mode and change the trainable option to false when it calls the request weights and tensors. Add the new function to load and unload tensors as well as check load complete. It also considers weight pool and tensor pool differenetly according to the ExecutionMode. It is not use FSU mode for tensor pool for the INFERENCE Mode. Resolves: **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: jijoong.moon --- Applications/SimpleFC/README.md | 30 ----- Applications/SimpleFC/jni/main.cpp | 80 ++++++------ Applications/SimpleFC/jni/meson.build | 12 +- meson.build | 1 - nntrainer/graph/network_graph.cpp | 32 ++++- nntrainer/graph/network_graph.h | 36 +++++- nntrainer/layers/layer_node.cpp | 7 +- nntrainer/models/model_common_properties.cpp | 2 - nntrainer/models/model_common_properties.h | 18 --- nntrainer/models/neuralnet.cpp | 100 ++++++++++++--- nntrainer/models/neuralnet.h | 6 +- nntrainer/tensor/cache_elem.cpp | 1 - nntrainer/tensor/cache_loader.cpp | 59 +++++++-- nntrainer/tensor/cache_loader.h | 26 +++- nntrainer/tensor/cache_pool.cpp | 2 +- nntrainer/tensor/manager.cpp | 128 ++++++++++++++++++- nntrainer/tensor/manager.h | 67 +++++++++- nntrainer/tensor/task_executor.cpp | 15 +-- nntrainer/tensor/task_executor.h | 1 + nntrainer/tensor/tensor_pool.cpp | 12 +- nntrainer/tensor/tensor_pool.h | 9 ++ 21 files changed, 487 insertions(+), 157 deletions(-) delete mode 100644 Applications/SimpleFC/README.md diff --git a/Applications/SimpleFC/README.md b/Applications/SimpleFC/README.md deleted file mode 100644 index f195a8c764..0000000000 --- a/Applications/SimpleFC/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# Resnet with cifar100 - -This application contains a Resnet18 model and a trainer with cifar100. - -Reference. [Kaiming He. 2015](https://arxiv.org/abs/1512.03385) -Reference. [Learning Multiple Layers of Features from Tiny Images, Alex Krizhevsky, 2009](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf) - -## How to run a train epoch - -### To simply run with a fake data. - -Once you compile, with `meson`, you can run with `meson test app_resnet18`. -Please file an issue if you have a problem running the example. - -```bash -$ meson ${build_dir} -Denable-test=true -Denable-long-test=true -$ meson test app_resnet18 -v -C ${build_dir} -``` - -### To run with a real data. - -```bash -$ meson ${build_dir} build -$ ${project_dir}/Applications/Resnet/res/prepare_dataset.sh ${dataset_download_dir} # this is to download raw data of cifar100 -$ OPENBLAS_NUM_THREADS=1 ${build_dir}/Applications/Resnet/jni/nntrainer_resnet18 \ - ${dataset_download_dir}/cifar-100-binary \ - ${batch_size} \ - ${data_split} \ - ${epoch} -``` diff --git a/Applications/SimpleFC/jni/main.cpp b/Applications/SimpleFC/jni/main.cpp index ca29eb6a46..6ae0b4e4cb 100644 --- a/Applications/SimpleFC/jni/main.cpp +++ b/Applications/SimpleFC/jni/main.cpp @@ -1,13 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 /** - * Copyright (C) 2020 Jihoon Lee + * Copyright (C) 2024 Jijoong Moon * * @file main.cpp - * @date 24 Jun 2021 - * @todo move resnet model creating to separate sourcefile - * @brief task runner for the resnet + * @date 10 Dec 2024 + * @brief Test Application for Asynch FSU * @see https://github.com/nnstreamer/nntrainer - * @author Jihoon Lee + * @author Jijoong Moon * @bug No known bugs except for NYI items */ #include @@ -76,9 +75,9 @@ static std::string withKey(const std::string &key, } /** - * @brief Create resnet 18 + * @brief Create network * - * @return vector of layers that contain full graph of resnet18 + * @return vector of layers that contain full graph of asynch */ std::vector createGraph() { using ml::train::createLayer; @@ -86,22 +85,22 @@ std::vector createGraph() { std::vector layers; layers.push_back(createLayer( - "input", {withKey("name", "input0"), withKey("input_shape", "1:1:32")})); + "input", {withKey("name", "input0"), withKey("input_shape", "1:1:320")})); - layers.push_back( - createLayer("fully_connected", - {withKey("unit", 10)})); + layers.push_back(createLayer("fully_connected", + {withKey("unit", 100), + withKey("weight_initializer", "xavier_uniform"), + withKey("bias_initializer", "zeros")})); - layers.push_back( - createLayer("fully_connected", - {withKey("unit", 10)})); + layers.push_back(createLayer("fully_connected", + {withKey("unit", 100), + withKey("weight_initializer", "xavier_uniform"), + withKey("bias_initializer", "zeros")})); return layers; } -/// @todo update createResnet18 to be more generic ModelHandle create() { - /// @todo support "LOSS : cross" for TF_Lite Exporter ModelHandle model = ml::train::createModel(ml::train::ModelType::NEURAL_NET, {withKey("loss", "mse")}); @@ -126,17 +125,19 @@ int validData_cb(float **input, float **label, bool *last, void *user_data) { return 0; } -/// @todo maybe make num_class also a parameter void createAndRun(unsigned int epochs, unsigned int batch_size, UserDataType &train_user_data, UserDataType &valid_user_data) { // setup model ModelHandle model = create(); - model->setProperty({withKey("batch_size", batch_size), - withKey("epochs", epochs), - withKey("save_path", "model_full.bin"), - withKey("memory_swap","true")}); + model->setProperty( + {withKey("batch_size", batch_size), withKey("epochs", epochs), + // withKey("save_path", "model_full.bin")}); + // withKey("save_path", "model_full.bin"), withKey("memory_swap", + // "true")}); + withKey("memory_swap", "true"), withKey("memory_swap_lookahead", "1"), + withKey("model_tensor_type", "FP16-FP16")}); auto optimizer = ml::train::createOptimizer("sgd", {"learning_rate=0.001"}); model->setOptimizer(std::move(optimizer)); @@ -156,28 +157,24 @@ void createAndRun(unsigned int epochs, unsigned int batch_size, auto dataset_valid = ml::train::createDataset( ml::train::DatasetType::GENERATOR, validData_cb, valid_user_data.get()); - // model->setDataset(ml::train::DatasetModeType::MODE_TRAIN, - // std::move(dataset_train)); - // model->setDataset(ml::train::DatasetModeType::MODE_VALID, - // std::move(dataset_valid)); - - // if (transfer_learning) - // model->load(pretrained_bin_path); - // model->train(); + // to test asynch fsu, we do need save the model weight data in file + model->save("simplefc_weight_fp16_fp16_100.bin", + ml::train::ModelFormat::MODEL_FORMAT_BIN); + model->load("./simplefc_weight_fp16_fp16_100.bin"); model->summarize(std::cout, ML_TRAIN_SUMMARY_MODEL); - uint feature_size = 32; + uint feature_size = 320; - float input [32]; - float label [1]; + float input[320]; + float label[1]; - for(uint j=0;j in; - std::vector l; - std::vector answer; + std::vector in; + std::vector l; + std::vector answer; in.push_back(input); l.push_back(label); @@ -187,8 +184,7 @@ void createAndRun(unsigned int epochs, unsigned int batch_size, in.clear(); l.clear(); - std::cout << "done"< @@ -196,10 +192,10 @@ createFakeDataGenerator(unsigned int batch_size, unsigned int simulated_data_size, unsigned int data_split) { UserDataType train_data(new nntrainer::util::RandomDataLoader( - {{batch_size, 1, 1, 32}}, {{batch_size, 1, 1, 10}}, + {{batch_size, 1, 1, 320}}, {{batch_size, 1, 1, 100}}, simulated_data_size / data_split)); UserDataType valid_data(new nntrainer::util::RandomDataLoader( - {{batch_size, 1, 1, 32}}, {{batch_size, 1, 1, 10}}, + {{batch_size, 1, 1, 320}}, {{batch_size, 1, 1, 100}}, simulated_data_size / data_split)); return {std::move(train_data), std::move(valid_data)}; @@ -231,9 +227,9 @@ int main(int argc, char *argv[]) { std::string data_dir = "fake"; uint batch_size = 1; - uint data_split =1; + uint data_split = 1; uint epoch = 1; - + std::array user_datas; try { diff --git a/Applications/SimpleFC/jni/meson.build b/Applications/SimpleFC/jni/meson.build index 2033e858c8..6743f03809 100644 --- a/Applications/SimpleFC/jni/meson.build +++ b/Applications/SimpleFC/jni/meson.build @@ -1,22 +1,22 @@ -resnet_sources = [ +app_sources = [ 'main.cpp', cifar_path / 'cifar_dataloader.cpp' ] -resnet_dependencies = [app_utils_dep, +app_dependencies = [app_utils_dep, iniparser_dep, nntrainer_dep, nntrainer_ccapi_dep ] if get_option('enable-test') - resnet_dependencies += [gtest_dep] + app_dependencies += [gtest_dep] endif e = executable('nntrainer_simplefc', - resnet_sources, + app_sources, include_directories: [include_directories('.'), cifar_include_dir], - dependencies: resnet_dependencies, + dependencies: app_dependencies, install: get_option('install-app'), install_dir: application_install_dir ) @@ -24,5 +24,5 @@ e = executable('nntrainer_simplefc', if get_option('enable-long-test') testenv = environment() testenv.set('OPENBLAS_NUM_THREADS', '4') - test('app_resnet18', e, args: ['fake', '1', '128', '1'], env: testenv, timeout: 300) + test('app_asynch_fsu', e, args: ['fake', '1', '128', '1'], env: testenv, timeout: 300) endif diff --git a/meson.build b/meson.build index 150be5c4ac..815a47f576 100644 --- a/meson.build +++ b/meson.build @@ -90,7 +90,6 @@ if get_option('enable-fp16') extra_defines += '-DENABLE_FP16=1' extra_defines += '-DUSE__FP16=1' extra_defines += '-DUSE_NEON=1' - extra_defines += '-DUSE_MMAP=1' elif arch == 'aarch64' ## About FP16 in GCC (from GCC-9.1 manual) # https://gcc.gnu.org/onlinedocs/gcc-9.1.0/gcc/Half-Precision.html diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp index 479d7296da..9d568c324b 100644 --- a/nntrainer/graph/network_graph.cpp +++ b/nntrainer/graph/network_graph.cpp @@ -229,6 +229,14 @@ int NetworkGraph::checkCompiledGraph() { void NetworkGraph::markNodesForBackwarding() { /** accumulate all the nodes which must support backwarding */ std::unordered_set must_support_backwarding; + if (exec_mode == ExecutionMode::INFERENCE) { + for (auto iter = cbegin(); iter != cend(); iter++) { + auto lnode = (*iter); + lnode->needsCalcGradient(false); + lnode->needsCalcDerivative(false); + } + return; + } /** * if a node is trainable, then all the nodes ahead of it must support @@ -867,14 +875,16 @@ NetworkGraph::finalizeContext(const std::shared_ptr &lnode, } lnode->setDataType(init_context.getWeightDataType(), init_context.getActivationDataType()); - + bool trainable = lnode->getTrainable(); + if (exec_mode == ExecutionMode::INFERENCE) + trainable = false; lnode->configureRunContext( // TODO: update weights spec for trainable based on layer trainable prop tensor_manager->requestWeights(gnode, init_context.getWeightsSpec(), - lnode->getTrainable(), shared_weight_names), + trainable, shared_weight_names), inputs, outputs, tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(), - lnode->getTrainable(), shared_tensor_names), + trainable, shared_tensor_names), init_context.getLossScale()); return outputs; @@ -1552,6 +1562,22 @@ void NetworkGraph::flushCacheExcept(unsigned int order) { tensor_manager->flushCacheExcept(order); } +void NetworkGraph::LoadTensors(unsigned int order) { + tensor_manager->LoadTensors(order); +} + +bool NetworkGraph::checkLoadComplete(unsigned int order) { + return tensor_manager->checkLoadComplete(order); +} + +bool NetworkGraph::checkUnloadComplete(unsigned int order) { + return tensor_manager->checkUnloadComplete(order); +} + +void NetworkGraph::UnloadTensors(unsigned int order) { + tensor_manager->UnloadTensors(order); +} + void NetworkGraph::requestOptimizerVariable( std::function(const TensorDim &)> cb, bool request_only_trainable) { diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h index 05aeae9193..078268d7b2 100644 --- a/nntrainer/graph/network_graph.h +++ b/nntrainer/graph/network_graph.h @@ -370,8 +370,12 @@ class NetworkGraph { * @brief Allocate memory for all the managed weights */ void allocateWeights(bool init = true) { - tensor_manager->allocateWeights( - std::get<3>(backward_iter_end->getExecutionOrder()), init); + unsigned int max_exec_order = + std::get<3>(backward_iter_end->getExecutionOrder()); + + if (exec_mode == ExecutionMode::INFERENCE) + max_exec_order = std::get<0>(forward_iter_end->getExecutionOrder()); + tensor_manager->allocateWeights(max_exec_order, init); } /** @@ -447,6 +451,34 @@ class NetworkGraph { */ void flushCacheExcept(const unsigned int order); + /** + * @brief Load data of order to the device + * + * @param order execution order + */ + void LoadTensors(const unsigned int order); + + /** + * @brief check data of order is loaded + * + * @param order execution order + */ + bool checkLoadComplete(const unsigned int order); + + /** + * @brief check data of order is Unloaded + * + * @param order execution order + */ + bool checkUnloadComplete(const unsigned int order); + + /** + * @brief Load data of order to the device + * + * @param order execution order + */ + void UnloadTensors(const unsigned int order); + #ifdef ENABLE_TEST /** * @brief Get layer node's tenexecution orders diff --git a/nntrainer/layers/layer_node.cpp b/nntrainer/layers/layer_node.cpp index eed9398094..9c6c290703 100644 --- a/nntrainer/layers/layer_node.cpp +++ b/nntrainer/layers/layer_node.cpp @@ -489,7 +489,8 @@ void LayerNode::read(std::ifstream &file, bool opt_var, for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) { /// @note shared weights are only be read at the first acecss - if (run_context->isGradientLastAccess(i)) { + // if (run_context->isGradientLastAccess(i)) { + if (run_context->isGradientFirstAccess(i)) { if (layer->getType() == BatchNormalizationLayer::type) { if ((mode == ml::train::ExecutionMode::TRAIN) && (this->getWeightDataType() != TensorDim::DataType::FP32)) { @@ -526,7 +527,7 @@ void LayerNode::save(std::ofstream &file, bool opt_var, if (opt_var) { for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) { - if (run_context->isGradientLastAccess(i) && getTrainable()) { + if (run_context->isGradientFirstAccess(i) && getTrainable()) { // @note save optimizer variables if (run_context->weightHasGradient(i)) { for (unsigned int j = 0; j < run_context->getNumWeightOptVar(i); @@ -539,7 +540,7 @@ void LayerNode::save(std::ofstream &file, bool opt_var, } else { // @note shared weights are only be saved at the first access for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) { - if (run_context->isGradientLastAccess(i)) { + if (run_context->isGradientFirstAccess(i)) { /** @note For batch normalization layer, we do need full precision for * training and the data type of weight is full precision. But for diff --git a/nntrainer/models/model_common_properties.cpp b/nntrainer/models/model_common_properties.cpp index b9589c67f6..9fc78d5602 100644 --- a/nntrainer/models/model_common_properties.cpp +++ b/nntrainer/models/model_common_properties.cpp @@ -34,8 +34,6 @@ MemorySwap::MemorySwap(bool value) { set(value); } MemorySwapPath::MemorySwapPath(const std::string &value) { set(value); } -MemorySwapMode::MemorySwapMode(const std::string &value) { set(value); } - MemorySwapLookahead::MemorySwapLookahead(const unsigned int &value) { set(value); } diff --git a/nntrainer/models/model_common_properties.h b/nntrainer/models/model_common_properties.h index e66ecb9844..3bd8d9078b 100644 --- a/nntrainer/models/model_common_properties.h +++ b/nntrainer/models/model_common_properties.h @@ -179,24 +179,6 @@ class MemorySwapLookahead : public Property { MemorySwapLookahead(const unsigned int &value = 0); }; -/** - * @brief cache file path property - * - */ -class MemorySwapMode : public Property { -public: - static constexpr const char *key = - "memory_swap_mode"; /**< unique key to access */ - using prop_tag = str_prop_tag; /**< property type */ - - /** - * @brief Constructor - * - * @param value value to set, defaults to current directory - */ - MemorySwapMode(const std::string &value = "train"); -}; - /** * @brief Enumeration of Data Type for model & layer */ diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp index ce88593a70..bff18c2ddd 100644 --- a/nntrainer/models/neuralnet.cpp +++ b/nntrainer/models/neuralnet.cpp @@ -67,12 +67,11 @@ namespace nntrainer { NeuralNetwork::NeuralNetwork() : model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm(), props::LossScale()), - model_flex_props(props::Epochs(), props::TrainingBatchSize(), - props::SavePath(), props::ContinueTrain(), - props::SaveBestPath(), props::MemoryOptimization(), - props::MemorySwap(), props::MemorySwapPath(), - props::MemorySwapLookahead(), props::TensorFormat(), - props::ModelTensorDataType(), props::MemorySwapMode()), + model_flex_props( + props::Epochs(), props::TrainingBatchSize(), props::SavePath(), + props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(), + props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead(), + props::TensorFormat(), props::ModelTensorDataType()), load_path(std::string()), epoch_idx(0), iter(0), @@ -88,12 +87,11 @@ NeuralNetwork::NeuralNetwork() : NeuralNetwork::NeuralNetwork(AppContext app_context_) : model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm(), props::LossScale()), - model_flex_props(props::Epochs(), props::TrainingBatchSize(), - props::SavePath(), props::ContinueTrain(), - props::SaveBestPath(), props::MemoryOptimization(), - props::MemorySwap(), props::MemorySwapPath(), - props::MemorySwapLookahead(), props::TensorFormat(), - props::ModelTensorDataType(), props::MemorySwapMode()), + model_flex_props( + props::Epochs(), props::TrainingBatchSize(), props::SavePath(), + props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(), + props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead(), + props::TensorFormat(), props::ModelTensorDataType()), load_path(std::string()), epoch_idx(0), iter(0), @@ -102,6 +100,7 @@ NeuralNetwork::NeuralNetwork(AppContext app_context_) : initialized(false), compiled(false), loadedFromConfig(false), + exec_mode(ExecutionMode::TRAIN), app_context(app_context_) {} int NeuralNetwork::loadFromConfig(const std::string &config) { @@ -214,6 +213,18 @@ int NeuralNetwork::compile(ExecutionMode mode) { int NeuralNetwork::initialize(ExecutionMode mode) { int status = ML_ERROR_NONE; + if (mode != exec_mode) { + if (mode == ExecutionMode::INFERENCE) { + ml_logd("Execution mode mismatch : train mode @compile & inference mode " + "@ initialize"); + exec_mode = mode; + } else { + NNTR_THROW_IF(exec_mode == ExecutionMode::TRAIN, std::invalid_argument) + << "Execution mode mismatch : trying to train with compiled for " + "infence"; + } + } + if (initialized) { ml_loge("Error: Initializing the model again"); return ML_ERROR_NOT_SUPPORTED; @@ -244,7 +255,7 @@ int NeuralNetwork::initialize(ExecutionMode mode) { } status = model_graph.initialize( - mode, input_conn, + exec_mode, input_conn, std::vector(label_layers.begin(), label_layers.end())); NN_RETURN_STATUS(); @@ -267,6 +278,8 @@ int NeuralNetwork::initialize(ExecutionMode mode) { // Allocate weights model_graph.allocateWeights(exec_mode != ExecutionMode::INFERENCE); + // enable this to save initialized weights for INFERENCE + // model_graph.allocateWeights(true); initialized = true; @@ -329,6 +342,7 @@ NeuralNetwork::~NeuralNetwork() { */ sharedConstTensors NeuralNetwork::forwarding( bool training, std::function stop_cb, void *userdata) { + std::function, bool)> forwarding_op = [this, stop_cb, userdata](std::shared_ptr node, bool training) -> void { @@ -336,9 +350,65 @@ sharedConstTensors NeuralNetwork::forwarding( PROFILE_MEM_ANNOTATE("Forwarding for layer: " + node->getName()); auto f = std::get<0>(node->getExecutionOrder()); - model_graph.flushCacheExcept(f); - node->forwarding(training); + // temperally remain. when we evaluate all for asynch mode, we weill remove + if (exec_mode == ExecutionMode::TRAIN) { + model_graph.flushCacheExcept(f); + node->forwarding(training); + } else { + + /** + currently, it supports FSU asynch mode for inference. The prcedure of + FSU is below, + + Prerequests : This function is called node by node at the forwarding + function in network graph. + + Step 1. If the execution order is the first (f==0) then, it will try to + load tensors which used at layer 0. + + Step 2. It check whether these tensors from Step 1, then do the + forwarding of the first node. + + Step 3. Then check the look a head which says how many layer weights need + to be loaded before running to hide overehad due to FSU, + + Step 4. Try to get the tesors by asking tensors for layers which is done + by thread pool + + Step 5. Try to release the weights which has execution order less then f. + + Step n. repeat next layer starting with checking the tenosrs are loaded, + and if it is loaded, then run forwarding. Every time it finishes + the forwarding, ask load tensors for next n layers. + **/ + + if (f == 0) + model_graph.LoadTensors(f); + + if (model_graph.checkLoadComplete(f)) { + node->forwarding(training); + ml_logd("Forwarding is done %d : %s", f, node->getName().c_str()); + + unsigned int lookahead = + std::get(model_flex_props); + + if (lookahead != 0) { + if ((f) % (lookahead + 1) == lookahead - 1) { + std::cout << "request load tensor : " << f + lookahead + 1 + << std::endl; + ml_logd("request load tensor for %d", f + 1); + model_graph.LoadTensors((f / (lookahead + 1) + 1) * + (lookahead + 1)); + } + } else { + model_graph.LoadTensors(f); + } + + if (f != 0) + model_graph.UnloadTensors(f); + } + } }; return model_graph.forwarding(training, forwarding_op, stop_cb, userdata); diff --git a/nntrainer/models/neuralnet.h b/nntrainer/models/neuralnet.h index d964c60ebd..b79378138c 100644 --- a/nntrainer/models/neuralnet.h +++ b/nntrainer/models/neuralnet.h @@ -629,7 +629,7 @@ s * @retval shared_ptr props::Epochs, props::TrainingBatchSize, props::SavePath, props::ContinueTrain, props::SaveBestPath, props::MemoryOptimization, props::MemorySwap, props::MemorySwapPath, props::MemorySwapLookahead, - props::TensorFormat, props::ModelTensorDataType, props::MemorySwapMode>; + props::TensorFormat, props::ModelTensorDataType>; using RigidPropTypes = std::tuple, std::vector, props::ClipGradByGlobalNorm, @@ -674,6 +674,8 @@ s * @retval shared_ptr RunStats training; /** training statistics of the model */ RunStats testing; /** testing statistics of the model */ + ExecutionMode exec_mode; /** execution mode : train : inference */ + AppContext app_context; /** Configurations bound to current app */ NetworkGraph model_graph; /** Network Model Graph */ @@ -682,8 +684,6 @@ s * @retval shared_ptr DynamicTrainingOptimization dynamic_training_opt; /**< Dynamic fine-tuning optimization mode. supported modes are "max" and "norm" */ - ExecutionMode exec_mode; - /** * @brief save model in ini * diff --git a/nntrainer/tensor/cache_elem.cpp b/nntrainer/tensor/cache_elem.cpp index c7849d70af..070b621307 100644 --- a/nntrainer/tensor/cache_elem.cpp +++ b/nntrainer/tensor/cache_elem.cpp @@ -60,7 +60,6 @@ void CacheElem::swapIn(Options opt) { mem_data->setAddr((void *)buf); mem_data->setValid(true); active = true; - #ifdef PROFILE std::string msg("CacheElem("); msg += device->getDevicePath() + ") #" + std::to_string(id); diff --git a/nntrainer/tensor/cache_loader.cpp b/nntrainer/tensor/cache_loader.cpp index b97094a8fa..64df838984 100644 --- a/nntrainer/tensor/cache_loader.cpp +++ b/nntrainer/tensor/cache_loader.cpp @@ -27,23 +27,29 @@ namespace nntrainer { CacheLoader::CacheLoader(std::shared_ptr cache_pool) : pool(cache_pool), - task_executor(nullptr) {} + load_task_executor(nullptr), + unload_task_executor(nullptr) {} CacheLoader::~CacheLoader() { - if (task_executor) - delete task_executor; + if (load_task_executor) + delete load_task_executor; + if (unload_task_executor) + delete unload_task_executor; } void CacheLoader::init() { - if (task_executor) - return; - task_executor = new TaskExecutor(pool->getName()); + if (load_task_executor == nullptr) + load_task_executor = new TaskExecutor(pool->getName()); + if (unload_task_executor == nullptr) + unload_task_executor = new TaskExecutor(pool->getName()); } void CacheLoader::finish() { - delete task_executor; - task_executor = nullptr; + delete load_task_executor; + load_task_executor = nullptr; + delete unload_task_executor; + unload_task_executor = nullptr; } void CacheLoader::load(unsigned int order) { pool->loadExec(order); } @@ -56,7 +62,7 @@ int CacheLoader::loadAsync(unsigned int order, int CacheLoader::loadAsync(unsigned int order, TaskExecutor::CompleteCallback complete, long timeout_ms) { - if (!task_executor) { + if (!load_task_executor) { ml_loge("init is needed"); return ML_ERROR_INVALID_PARAMETER; } @@ -64,7 +70,7 @@ int CacheLoader::loadAsync(unsigned int order, Task::Work work = [&](std::atomic_bool &running, void *data) { unsigned int exe_order = (unsigned int)(std::uintptr_t)data; - pool->flushExcept({exe_order - 1, exe_order}); + // pool->flushExcept({exe_order - 1, exe_order}); pool->loadExec(exe_order); return ML_ERROR_NONE; @@ -74,12 +80,41 @@ int CacheLoader::loadAsync(unsigned int order, std::make_shared>(work, (void *)(std::uintptr_t)order); task->setTimeout(timeout_ms); - return task_executor->run(task, complete); + return load_task_executor->run(task, complete); +} + +int CacheLoader::flushAsync(unsigned int order, + TaskExecutor::CompleteCallback complete) { + return flushAsync(order, complete, LONG_MAX); +} + +int CacheLoader::flushAsync(unsigned int order, + TaskExecutor::CompleteCallback complete, + long timeout_ms) { + if (!unload_task_executor) { + ml_loge("init is needed"); + return ML_ERROR_INVALID_PARAMETER; + } + + Task::Work work = [&](std::atomic_bool &running, void *data) { + unsigned int exe_order = (unsigned int)(std::uintptr_t)data; + + // pool->flushExcept({exe_order - 1, exe_order}); + pool->flushExcept(exe_order); + + return ML_ERROR_NONE; + }; + + auto task = + std::make_shared>(work, (void *)(std::uintptr_t)order); + task->setTimeout(timeout_ms); + + return unload_task_executor->run(task, complete); } int CacheLoader::cancelAsync(int id) { try { - task_executor->cancel(id); + load_task_executor->cancel(id); } catch (const std::exception &e) { ml_loge("CacheLoader(%s): failed to cancel(%d): %s", pool->getName().c_str(), id, e.what()); diff --git a/nntrainer/tensor/cache_loader.h b/nntrainer/tensor/cache_loader.h index 155214c0d1..0b0dbbe98d 100644 --- a/nntrainer/tensor/cache_loader.h +++ b/nntrainer/tensor/cache_loader.h @@ -84,6 +84,29 @@ class CacheLoader { TaskExecutor::CompleteCallback callback, long timeout_ms); + /** + * @brief Load cache data asynchronously with execution order + * + * @param order execution order + * @param complete complete callback + * @return async task id + */ + virtual int flushAsync(unsigned int order, + TaskExecutor::CompleteCallback callback); + + /** + * @brief Load cache data asynchronously with execution order + * + * @param order execution order + * @param complete complete callback + * @param timeout timeout time in ms + * @return async task id + * @note timeout_ms does not work now. + */ + virtual int flushAsync(unsigned int order, + TaskExecutor::CompleteCallback callback, + long timeout_ms); + /** * @brief Cancel async task * @@ -94,7 +117,8 @@ class CacheLoader { private: std::shared_ptr pool; /**< cache pool */ - TaskExecutor *task_executor; /**< task executor */ + TaskExecutor *load_task_executor; /**< task executor */ + TaskExecutor *unload_task_executor; /**< task executor */ }; } // namespace nntrainer diff --git a/nntrainer/tensor/cache_pool.cpp b/nntrainer/tensor/cache_pool.cpp index 8f2d4e1f2b..519dda1b6b 100644 --- a/nntrainer/tensor/cache_pool.cpp +++ b/nntrainer/tensor/cache_pool.cpp @@ -215,7 +215,7 @@ void CachePool::flushExcept(unsigned int order) { auto id = elem->getId(); auto exe_order = exe_orders.at(id - 1); auto found = std::find(exe_order.begin(), exe_order.end(), order); - if (found == exe_order.end()) { + if (found != exe_order.end()) { /** * We assumes that flushExcept will be called in front of each execution * order, and the order is incremental. So, we can conclude that, if the diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp index e454e51119..8500db83d3 100644 --- a/nntrainer/tensor/manager.cpp +++ b/nntrainer/tensor/manager.cpp @@ -145,6 +145,7 @@ void Manager::reinitialize() { } void Manager::allocateWeights(unsigned int max_exec_order_, bool init) { + max_exec_order = max_exec_order_; if (!weight_pool.isAllocated()) { finalizeTensorPool(weight_pool, 0, max_exec_order_); weight_pool.allocate(init); @@ -375,11 +376,16 @@ std::vector Manager::requestWeights( * However, current implementation of loss needs the gradient computation. * and therefore, if we remove the calcDerivative order, then tests fails. */ - - TensorLifespan var_ls = - (enable_swap && (exec_mode == ExecutionMode::INFERENCE)) - ? TensorLifespan::FORWARD_INFER_LIFESPAN - : TensorLifespan::MAX_LIFESPAN; + TensorLifespan var_ls; + if (exec_mode != ExecutionMode::INFERENCE) { + var_ls = TensorLifespan::MAX_LIFESPAN; + } else { + if (enable_swap) { + var_ls = TensorLifespan::FORWARD_INFER_LIFESPAN; + } else { + var_ls = TensorLifespan::FORWARD_FUNC_LIFESPAN; + } + } TensorLifespan grad_ls = TensorLifespan::BACKWARD_FUNC_LIFESPAN; @@ -394,6 +400,8 @@ std::vector Manager::requestWeights( std::vector var_exec_order; for (auto order : default_var_exec_order) { var_exec_order.push_back(order); + if (exec_mode == ExecutionMode::INFERENCE) + break; } // auto var_exec_order = default_var_exec_order; std::vector grad_exec_order; @@ -730,6 +738,116 @@ void Manager::flushCache() { } } +bool Manager::checkLoadComplete(unsigned int order) { + if (async_load_tensor.count(order) == 1) { + auto &tasks = async_load_tensor[order]; + std::unique_lock lock(completed_load_mutex); + if (exec_mode == ExecutionMode::TRAIN) { + auto w_fut = completed_load_tensor[std::get<0>(tasks)].get_future(); + auto t_fut = completed_load_tensor[std::get<1>(tasks)].get_future(); + lock.unlock(); + if (std::get<0>(tasks) != 0) + w_fut.wait(); + if (std::get<1>(tasks) != 0) + t_fut.wait(); + } else { + auto w_fut = completed_load_tensor[std::get<0>(tasks)].get_future(); + lock.unlock(); + if (std::get<0>(tasks) != 0) + w_fut.wait(); + } + async_load_tensor.erase(order); + ml_logd("wait and completed %d", order); + } else { + ml_logd("without wait completed %d", order); + } + return true; +} + +bool Manager::checkUnloadComplete(unsigned int order) { + if (async_unload_tensor.count(order)) { + auto &tasks = async_unload_tensor[order]; + std::unique_lock lock(completed_unload_mutex); + if (exec_mode == ExecutionMode::TRAIN) { + auto w_fut = completed_unload_tensor[std::get<0>(tasks)].get_future(); + auto t_fut = completed_unload_tensor[std::get<1>(tasks)].get_future(); + lock.unlock(); + if (std::get<0>(tasks) != 0) + w_fut.wait(); + if (std::get<1>(tasks) != 0) + t_fut.wait(); + } else { + auto w_fut = completed_unload_tensor[std::get<0>(tasks)].get_future(); + lock.unlock(); + if (std::get<0>(tasks) != 0) + w_fut.wait(); + } + async_unload_tensor.erase(order); + } + return true; +} + +void Manager::LoadTensors(unsigned int order) { + auto loadTensorsAsync = [&](TensorPool &pool, unsigned int order) { + return pool.loadCacheExecAsync( + order, [&](int id, TaskExecutor::CompleteStatus status) { + std::scoped_lock lock(completed_load_mutex); + completed_load_tensor[id].set_value(true); + }); + }; + + auto enqueTasks = [&](unsigned int o) { + if (async_load_tensor.count(o)) { + ml_logd("Task loadTensors (%d) is in progress", o); + return; + } + auto load_weight = loadTensorsAsync(weight_pool, o); + ml_logd("load weigth is requested in LoadTensors with order - %d", o); + int load_tensor = 0; + if (exec_mode != ml::train::ExecutionMode::INFERENCE) { + load_tensor = loadTensorsAsync(tensor_pool, o); + ml_logd("load tensor is requested in LoadTensors with order - %d", o); + } + NNTR_THROW_IF(load_weight < 0 || load_tensor < 0, std::runtime_error) + << "Fail to launch task"; + async_load_tensor[o] = std::make_tuple(load_weight, load_tensor); + }; + + for (unsigned int i = order; i < order + swap_lookahead + 1; ++i) { + if (i <= max_exec_order) + enqueTasks(i); + } +} + +void Manager::UnloadTensors(unsigned int order) { + auto unloadTensorsAsync = [&](TensorPool &pool, unsigned int order) { + return pool.flushCacheExecAsync( + order, [&](int id, TaskExecutor::CompleteStatus status) { + std::scoped_lock lock(completed_unload_mutex); + completed_unload_tensor[id].set_value(true); + }); + }; + + auto enqueTasks = [&](unsigned int o) { + if (async_unload_tensor.count(o)) { + ml_logd("Task unloadTensors (%d) is in progress", o); + return; + } + auto unload_weight = unloadTensorsAsync(weight_pool, o); + ml_logd("unload weight is requested in UnLoadTensors with order - %d", o); + int unload_tensor = 0; + if (exec_mode != ml::train::ExecutionMode::INFERENCE) { + unload_tensor = unloadTensorsAsync(tensor_pool, o); + ml_logd("unload tensor is requested in UnLoadTensors with order - %d", o); + } + NNTR_THROW_IF(unload_weight < 0 || unload_tensor < 0, std::runtime_error) + << "Faile to launch task"; + async_unload_tensor[o] = std::make_tuple(unload_weight, unload_tensor); + }; + + enqueTasks(order); +} + void Manager::flushCacheExcept(unsigned int order) { auto loadAsync = [&](TensorPool &pool, unsigned int order) { return pool.loadCacheExecAsync( diff --git a/nntrainer/tensor/manager.h b/nntrainer/tensor/manager.h index 281b6ebe4e..1e2308efb3 100644 --- a/nntrainer/tensor/manager.h +++ b/nntrainer/tensor/manager.h @@ -132,6 +132,7 @@ class Manager { * @brief Constructor of Manager */ Manager() : + enable_swap(false), enable_optimizations(true), swap_lookahead(0), tensor_format("NCHW"), @@ -148,12 +149,12 @@ class Manager { weight_pool(enable_swap_, swap_path, "weight_pool"), tensor_pool(enable_swap_ && (exec_mode_ == ExecutionMode::TRAIN), swap_path, "tensor_pool"), + enable_swap(enable_swap_), enable_optimizations(true), swap_lookahead(lookahead), tensor_format(tensor_format_), tensor_dtype(split(tensor_dtype_, getRegex("\\-"))), - exec_mode(exec_mode_), - enable_swap(enable_swap_) {} + exec_mode(exec_mode_) {} /** * @brief Construct a new Manager object (deleted) @@ -486,6 +487,39 @@ class Manager { */ void flushCacheExcept(unsigned int order); + /** + * @brief load cache data for the execution order + * + * @param order execution order + * @note preloading loads execution order data asynchronously, + * for lookahead size. + */ + void LoadTensors(unsigned int order); + + /** + * @brief check completion of load data for the execution order + * + * @param order execution order + * @note preloading tensors for execution order. + */ + bool checkLoadComplete(unsigned int order); + + /** + * @brief check completion of unload data for the execution order + * + * @param order execution order + * @note preloading tensors for execution order. + */ + bool checkUnloadComplete(unsigned int order); + + /** + * @brief flush load data for the execution order + * + * @param order execution order + * @note flush tensors for execution order. + */ + void UnloadTensors(unsigned int order); + /** * @brief reinitialize manager */ @@ -520,14 +554,41 @@ class Manager { TensorPool weight_pool; /**< tensor pool to request tensors */ TensorPool tensor_pool; /**< tensor pool to request tensors */ + /** async load task */ + std::map async_task_weight_load; + + /** async unload task */ + std::map async_task_weight_unload; + + /**< async tasks > + */ + std::map> async_task_eos; /**< async tasks > */ + std::map> async_load_tensor; + + std::map complete_load_tensor; + + std::map> async_unload_tensor; + std::map> completed; + + std::map> completed_load_tensor; + + std::map> completed_unload_tensor; + /**< async tasks completion */ std::mutex completed_mutex; /**< mutex for async tasks completion */ + std::mutex completed_load_mutex; /**< mutex for async tasks completion */ + + std::mutex completed_unload_mutex; /**< mutex for async tasks completion */ + + bool enable_swap; /**< to enable swap */ + bool enable_optimizations; /**< to enable memory optimizations */ unsigned int swap_lookahead; /** lookahead for memory swap */ @@ -538,7 +599,7 @@ class Manager { ExecutionMode exec_mode; - bool enable_swap; + unsigned int max_exec_order; /** * @brief Finalize the given tensor pool diff --git a/nntrainer/tensor/task_executor.cpp b/nntrainer/tensor/task_executor.cpp index 0f68e4c8e3..404f4baf85 100644 --- a/nntrainer/tensor/task_executor.cpp +++ b/nntrainer/tensor/task_executor.cpp @@ -30,17 +30,16 @@ namespace nntrainer { std::atomic_int32_t TaskExecutor::ids(1); TaskExecutor::TaskExecutor(const std::string &n) : - name(n), run_thread(true), wait_complete(false) { + name(n), run_thread(true), wait_complete(false), stop_all(false) { task_thread = std::thread([&]() { ml_logd("Task Thread(%s): start thread", name.c_str()); while (run_thread) { - std::unique_lock lk(task_mutex); - if (!task_cv.wait_for(lk, std::chrono::milliseconds(500), - [&] { return !task_queue.empty(); })) - continue; + std::unique_lock lk(task_mutex); + task_cv.wait(lk, [&] { return !task_queue.empty() || stop_all; }); + if (stop_all && task_queue.empty()) + return; auto &task_info = task_queue.front(); - lk.unlock(); const auto &id = std::get(task_info); const auto &callback = std::get(task_info); @@ -48,7 +47,6 @@ TaskExecutor::TaskExecutor(const std::string &n) : auto status = worker(task_info); callback(id, status); - lk.lock(); task_queue.pop_front(); lk.unlock(); } @@ -58,7 +56,8 @@ TaskExecutor::TaskExecutor(const std::string &n) : TaskExecutor::~TaskExecutor() { run_thread = false; - + stop_all = true; + task_cv.notify_all(); task_thread.join(); } diff --git a/nntrainer/tensor/task_executor.h b/nntrainer/tensor/task_executor.h index 4fc20c9c1b..35f9fd9c14 100644 --- a/nntrainer/tensor/task_executor.h +++ b/nntrainer/tensor/task_executor.h @@ -153,6 +153,7 @@ class TaskExecutor { std::string name; bool run_thread; bool wait_complete; + bool stop_all; std::list> task_queue; diff --git a/nntrainer/tensor/tensor_pool.cpp b/nntrainer/tensor/tensor_pool.cpp index 27f22d8a0c..d9ee957aed 100644 --- a/nntrainer/tensor/tensor_pool.cpp +++ b/nntrainer/tensor/tensor_pool.cpp @@ -435,6 +435,8 @@ bool TensorPool::isTensorLongTerm(const TensorLifespan &lifespan) { switch (lifespan) { case TensorLifespan::EPOCH_LIFESPAN: [[fallthrough]]; + case TensorLifespan::FORWARD_INFER_LIFESPAN: + [[fallthrough]]; case TensorLifespan::MAX_LIFESPAN: return true; case TensorLifespan::FORWARD_FUNC_LIFESPAN: @@ -470,7 +472,15 @@ int TensorPool::loadCacheExecAsync( if (dynamic_cast(mem_pool.get())) return cache_loader->loadAsync(order, complete_callback); else - return -1; + return 0; +} + +int TensorPool::flushCacheExecAsync( + unsigned int order, TaskExecutor::CompleteCallback complete_callback) { + if (dynamic_cast(mem_pool.get())) + return cache_loader->flushAsync(order, complete_callback); + else + return 0; } void TensorPool::loadCacheCancel(int id) { diff --git a/nntrainer/tensor/tensor_pool.h b/nntrainer/tensor/tensor_pool.h index 1d2addb52e..45c79c17b5 100644 --- a/nntrainer/tensor/tensor_pool.h +++ b/nntrainer/tensor/tensor_pool.h @@ -288,6 +288,15 @@ class TensorPool { int loadCacheExecAsync(unsigned int order, TaskExecutor::CompleteCallback complete_callback); + /** + * @brief load cache data by execution order + * + * @param order execution order + * @return async task id + */ + int flushCacheExecAsync(unsigned int order, + TaskExecutor::CompleteCallback complete_callback); + /** * @brief load cache data by execution order *