diff --git a/Applications/SimpleFC/README.md b/Applications/SimpleFC/README.md deleted file mode 100644 index f195a8c764..0000000000 --- a/Applications/SimpleFC/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# Resnet with cifar100 - -This application contains a Resnet18 model and a trainer with cifar100. - -Reference. [Kaiming He. 2015](https://arxiv.org/abs/1512.03385) -Reference. [Learning Multiple Layers of Features from Tiny Images, Alex Krizhevsky, 2009](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf) - -## How to run a train epoch - -### To simply run with a fake data. - -Once you compile, with `meson`, you can run with `meson test app_resnet18`. -Please file an issue if you have a problem running the example. - -```bash -$ meson ${build_dir} -Denable-test=true -Denable-long-test=true -$ meson test app_resnet18 -v -C ${build_dir} -``` - -### To run with a real data. - -```bash -$ meson ${build_dir} build -$ ${project_dir}/Applications/Resnet/res/prepare_dataset.sh ${dataset_download_dir} # this is to download raw data of cifar100 -$ OPENBLAS_NUM_THREADS=1 ${build_dir}/Applications/Resnet/jni/nntrainer_resnet18 \ - ${dataset_download_dir}/cifar-100-binary \ - ${batch_size} \ - ${data_split} \ - ${epoch} -``` diff --git a/Applications/SimpleFC/jni/main.cpp b/Applications/SimpleFC/jni/main.cpp index ca29eb6a46..6ae0b4e4cb 100644 --- a/Applications/SimpleFC/jni/main.cpp +++ b/Applications/SimpleFC/jni/main.cpp @@ -1,13 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 /** - * Copyright (C) 2020 Jihoon Lee + * Copyright (C) 2024 Jijoong Moon * * @file main.cpp - * @date 24 Jun 2021 - * @todo move resnet model creating to separate sourcefile - * @brief task runner for the resnet + * @date 10 Dec 2024 + * @brief Test Application for Asynch FSU * @see https://github.com/nnstreamer/nntrainer - * @author Jihoon Lee + * @author Jijoong Moon * @bug No known bugs except for NYI items */ #include @@ -76,9 +75,9 @@ static std::string withKey(const std::string &key, } /** - * @brief Create resnet 18 + * @brief Create network * - * @return vector of layers that contain full graph of resnet18 + * @return vector of layers that contain full graph of asynch */ std::vector createGraph() { using ml::train::createLayer; @@ -86,22 +85,22 @@ std::vector createGraph() { std::vector layers; layers.push_back(createLayer( - "input", {withKey("name", "input0"), withKey("input_shape", "1:1:32")})); + "input", {withKey("name", "input0"), withKey("input_shape", "1:1:320")})); - layers.push_back( - createLayer("fully_connected", - {withKey("unit", 10)})); + layers.push_back(createLayer("fully_connected", + {withKey("unit", 100), + withKey("weight_initializer", "xavier_uniform"), + withKey("bias_initializer", "zeros")})); - layers.push_back( - createLayer("fully_connected", - {withKey("unit", 10)})); + layers.push_back(createLayer("fully_connected", + {withKey("unit", 100), + withKey("weight_initializer", "xavier_uniform"), + withKey("bias_initializer", "zeros")})); return layers; } -/// @todo update createResnet18 to be more generic ModelHandle create() { - /// @todo support "LOSS : cross" for TF_Lite Exporter ModelHandle model = ml::train::createModel(ml::train::ModelType::NEURAL_NET, {withKey("loss", "mse")}); @@ -126,17 +125,19 @@ int validData_cb(float **input, float **label, bool *last, void *user_data) { return 0; } -/// @todo maybe make num_class also a parameter void createAndRun(unsigned int epochs, unsigned int batch_size, UserDataType &train_user_data, UserDataType &valid_user_data) { // setup 
model ModelHandle model = create(); - model->setProperty({withKey("batch_size", batch_size), - withKey("epochs", epochs), - withKey("save_path", "model_full.bin"), - withKey("memory_swap","true")}); + model->setProperty( + {withKey("batch_size", batch_size), withKey("epochs", epochs), + // withKey("save_path", "model_full.bin")}); + // withKey("save_path", "model_full.bin"), withKey("memory_swap", + // "true")}); + withKey("memory_swap", "true"), withKey("memory_swap_lookahead", "1"), + withKey("model_tensor_type", "FP16-FP16")}); auto optimizer = ml::train::createOptimizer("sgd", {"learning_rate=0.001"}); model->setOptimizer(std::move(optimizer)); @@ -156,28 +157,24 @@ void createAndRun(unsigned int epochs, unsigned int batch_size, auto dataset_valid = ml::train::createDataset( ml::train::DatasetType::GENERATOR, validData_cb, valid_user_data.get()); - // model->setDataset(ml::train::DatasetModeType::MODE_TRAIN, - // std::move(dataset_train)); - // model->setDataset(ml::train::DatasetModeType::MODE_VALID, - // std::move(dataset_valid)); - - // if (transfer_learning) - // model->load(pretrained_bin_path); - // model->train(); + // to test asynch fsu, we do need save the model weight data in file + model->save("simplefc_weight_fp16_fp16_100.bin", + ml::train::ModelFormat::MODEL_FORMAT_BIN); + model->load("./simplefc_weight_fp16_fp16_100.bin"); model->summarize(std::cout, ML_TRAIN_SUMMARY_MODEL); - uint feature_size = 32; + uint feature_size = 320; - float input [32]; - float label [1]; + float input[320]; + float label[1]; - for(uint j=0;j in; - std::vector l; - std::vector answer; + std::vector in; + std::vector l; + std::vector answer; in.push_back(input); l.push_back(label); @@ -187,8 +184,7 @@ void createAndRun(unsigned int epochs, unsigned int batch_size, in.clear(); l.clear(); - std::cout << "done"< @@ -196,10 +192,10 @@ createFakeDataGenerator(unsigned int batch_size, unsigned int simulated_data_size, unsigned int data_split) { UserDataType train_data(new nntrainer::util::RandomDataLoader( - {{batch_size, 1, 1, 32}}, {{batch_size, 1, 1, 10}}, + {{batch_size, 1, 1, 320}}, {{batch_size, 1, 1, 100}}, simulated_data_size / data_split)); UserDataType valid_data(new nntrainer::util::RandomDataLoader( - {{batch_size, 1, 1, 32}}, {{batch_size, 1, 1, 10}}, + {{batch_size, 1, 1, 320}}, {{batch_size, 1, 1, 100}}, simulated_data_size / data_split)); return {std::move(train_data), std::move(valid_data)}; @@ -231,9 +227,9 @@ int main(int argc, char *argv[]) { std::string data_dir = "fake"; uint batch_size = 1; - uint data_split =1; + uint data_split = 1; uint epoch = 1; - + std::array user_datas; try { diff --git a/Applications/SimpleFC/jni/meson.build b/Applications/SimpleFC/jni/meson.build index 2033e858c8..6743f03809 100644 --- a/Applications/SimpleFC/jni/meson.build +++ b/Applications/SimpleFC/jni/meson.build @@ -1,22 +1,22 @@ -resnet_sources = [ +app_sources = [ 'main.cpp', cifar_path / 'cifar_dataloader.cpp' ] -resnet_dependencies = [app_utils_dep, +app_dependencies = [app_utils_dep, iniparser_dep, nntrainer_dep, nntrainer_ccapi_dep ] if get_option('enable-test') - resnet_dependencies += [gtest_dep] + app_dependencies += [gtest_dep] endif e = executable('nntrainer_simplefc', - resnet_sources, + app_sources, include_directories: [include_directories('.'), cifar_include_dir], - dependencies: resnet_dependencies, + dependencies: app_dependencies, install: get_option('install-app'), install_dir: application_install_dir ) @@ -24,5 +24,5 @@ e = executable('nntrainer_simplefc', if 
get_option('enable-long-test') testenv = environment() testenv.set('OPENBLAS_NUM_THREADS', '4') - test('app_resnet18', e, args: ['fake', '1', '128', '1'], env: testenv, timeout: 300) + test('app_asynch_fsu', e, args: ['fake', '1', '128', '1'], env: testenv, timeout: 300) endif diff --git a/meson.build b/meson.build index 150be5c4ac..815a47f576 100644 --- a/meson.build +++ b/meson.build @@ -90,7 +90,6 @@ if get_option('enable-fp16') extra_defines += '-DENABLE_FP16=1' extra_defines += '-DUSE__FP16=1' extra_defines += '-DUSE_NEON=1' - extra_defines += '-DUSE_MMAP=1' elif arch == 'aarch64' ## About FP16 in GCC (from GCC-9.1 manual) # https://gcc.gnu.org/onlinedocs/gcc-9.1.0/gcc/Half-Precision.html diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp index 479d7296da..9d568c324b 100644 --- a/nntrainer/graph/network_graph.cpp +++ b/nntrainer/graph/network_graph.cpp @@ -229,6 +229,14 @@ int NetworkGraph::checkCompiledGraph() { void NetworkGraph::markNodesForBackwarding() { /** accumulate all the nodes which must support backwarding */ std::unordered_set must_support_backwarding; + if (exec_mode == ExecutionMode::INFERENCE) { + for (auto iter = cbegin(); iter != cend(); iter++) { + auto lnode = (*iter); + lnode->needsCalcGradient(false); + lnode->needsCalcDerivative(false); + } + return; + } /** * if a node is trainable, then all the nodes ahead of it must support @@ -867,14 +875,16 @@ NetworkGraph::finalizeContext(const std::shared_ptr &lnode, } lnode->setDataType(init_context.getWeightDataType(), init_context.getActivationDataType()); - + bool trainable = lnode->getTrainable(); + if (exec_mode == ExecutionMode::INFERENCE) + trainable = false; lnode->configureRunContext( // TODO: update weights spec for trainable based on layer trainable prop tensor_manager->requestWeights(gnode, init_context.getWeightsSpec(), - lnode->getTrainable(), shared_weight_names), + trainable, shared_weight_names), inputs, outputs, tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(), - lnode->getTrainable(), shared_tensor_names), + trainable, shared_tensor_names), init_context.getLossScale()); return outputs; @@ -1552,6 +1562,22 @@ void NetworkGraph::flushCacheExcept(unsigned int order) { tensor_manager->flushCacheExcept(order); } +void NetworkGraph::LoadTensors(unsigned int order) { + tensor_manager->LoadTensors(order); +} + +bool NetworkGraph::checkLoadComplete(unsigned int order) { + return tensor_manager->checkLoadComplete(order); +} + +bool NetworkGraph::checkUnloadComplete(unsigned int order) { + return tensor_manager->checkUnloadComplete(order); +} + +void NetworkGraph::UnloadTensors(unsigned int order) { + tensor_manager->UnloadTensors(order); +} + void NetworkGraph::requestOptimizerVariable( std::function(const TensorDim &)> cb, bool request_only_trainable) { diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h index 05aeae9193..078268d7b2 100644 --- a/nntrainer/graph/network_graph.h +++ b/nntrainer/graph/network_graph.h @@ -370,8 +370,12 @@ class NetworkGraph { * @brief Allocate memory for all the managed weights */ void allocateWeights(bool init = true) { - tensor_manager->allocateWeights( - std::get<3>(backward_iter_end->getExecutionOrder()), init); + unsigned int max_exec_order = + std::get<3>(backward_iter_end->getExecutionOrder()); + + if (exec_mode == ExecutionMode::INFERENCE) + max_exec_order = std::get<0>(forward_iter_end->getExecutionOrder()); + tensor_manager->allocateWeights(max_exec_order, init); } /** @@ -447,6 +451,34 @@ class 
NetworkGraph {
    */
   void flushCacheExcept(const unsigned int order);
 
+  /**
+   * @brief Load data of order to the device
+   *
+   * @param order execution order
+   */
+  void LoadTensors(const unsigned int order);
+
+  /**
+   * @brief Check whether the data of order is loaded
+   *
+   * @param order execution order
+   */
+  bool checkLoadComplete(const unsigned int order);
+
+  /**
+   * @brief Check whether the data of order is unloaded
+   *
+   * @param order execution order
+   */
+  bool checkUnloadComplete(const unsigned int order);
+
+  /**
+   * @brief Unload data of order from the device
+   *
+   * @param order execution order
+   */
+  void UnloadTensors(const unsigned int order);
+
 #ifdef ENABLE_TEST
   /**
    * @brief Get layer node's tensor execution orders
diff --git a/nntrainer/layers/layer_node.cpp b/nntrainer/layers/layer_node.cpp
index eed9398094..9c6c290703 100644
--- a/nntrainer/layers/layer_node.cpp
+++ b/nntrainer/layers/layer_node.cpp
@@ -489,7 +489,8 @@ void LayerNode::read(std::ifstream &file, bool opt_var,
   for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) {
     /// @note shared weights are only read at the first access
-    if (run_context->isGradientLastAccess(i)) {
+    // if (run_context->isGradientLastAccess(i)) {
+    if (run_context->isGradientFirstAccess(i)) {
       if (layer->getType() == BatchNormalizationLayer::type) {
         if ((mode == ml::train::ExecutionMode::TRAIN) &&
             (this->getWeightDataType() != TensorDim::DataType::FP32)) {
@@ -526,7 +527,7 @@ void LayerNode::save(std::ofstream &file, bool opt_var,
 
   if (opt_var) {
     for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) {
-      if (run_context->isGradientLastAccess(i) && getTrainable()) {
+      if (run_context->isGradientFirstAccess(i) && getTrainable()) {
         // @note save optimizer variables
         if (run_context->weightHasGradient(i)) {
           for (unsigned int j = 0; j < run_context->getNumWeightOptVar(i);
@@ -539,7 +540,7 @@ void LayerNode::save(std::ofstream &file, bool opt_var,
   } else {
     // @note shared weights are only saved at the first access
     for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) {
-      if (run_context->isGradientLastAccess(i)) {
+      if (run_context->isGradientFirstAccess(i)) {
        /** @note For batch normalization layer, we do need full precision for
         *  training and the data type of weight is full precision.
But for diff --git a/nntrainer/models/model_common_properties.cpp b/nntrainer/models/model_common_properties.cpp index b9589c67f6..9fc78d5602 100644 --- a/nntrainer/models/model_common_properties.cpp +++ b/nntrainer/models/model_common_properties.cpp @@ -34,8 +34,6 @@ MemorySwap::MemorySwap(bool value) { set(value); } MemorySwapPath::MemorySwapPath(const std::string &value) { set(value); } -MemorySwapMode::MemorySwapMode(const std::string &value) { set(value); } - MemorySwapLookahead::MemorySwapLookahead(const unsigned int &value) { set(value); } diff --git a/nntrainer/models/model_common_properties.h b/nntrainer/models/model_common_properties.h index 8b5d539762..0fb5793b0d 100644 --- a/nntrainer/models/model_common_properties.h +++ b/nntrainer/models/model_common_properties.h @@ -179,24 +179,6 @@ class MemorySwapLookahead : public Property { MemorySwapLookahead(const unsigned int &value = 0); }; -/** - * @brief cache file path property - * - */ -class MemorySwapMode : public Property { -public: - static constexpr const char *key = - "memory_swap_mode"; /**< unique key to access */ - using prop_tag = str_prop_tag; /**< property type */ - - /** - * @brief Constructor - * - * @param value value to set, defaults to current directory - */ - MemorySwapMode(const std::string &value = "train"); -}; - /** * @brief Enumeration of Data Type for model & layer */ diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp index ce88593a70..fc75d647c3 100644 --- a/nntrainer/models/neuralnet.cpp +++ b/nntrainer/models/neuralnet.cpp @@ -67,12 +67,11 @@ namespace nntrainer { NeuralNetwork::NeuralNetwork() : model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm(), props::LossScale()), - model_flex_props(props::Epochs(), props::TrainingBatchSize(), - props::SavePath(), props::ContinueTrain(), - props::SaveBestPath(), props::MemoryOptimization(), - props::MemorySwap(), props::MemorySwapPath(), - props::MemorySwapLookahead(), props::TensorFormat(), - props::ModelTensorDataType(), props::MemorySwapMode()), + model_flex_props( + props::Epochs(), props::TrainingBatchSize(), props::SavePath(), + props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(), + props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead(), + props::TensorFormat(), props::ModelTensorDataType()), load_path(std::string()), epoch_idx(0), iter(0), @@ -88,12 +87,11 @@ NeuralNetwork::NeuralNetwork() : NeuralNetwork::NeuralNetwork(AppContext app_context_) : model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm(), props::LossScale()), - model_flex_props(props::Epochs(), props::TrainingBatchSize(), - props::SavePath(), props::ContinueTrain(), - props::SaveBestPath(), props::MemoryOptimization(), - props::MemorySwap(), props::MemorySwapPath(), - props::MemorySwapLookahead(), props::TensorFormat(), - props::ModelTensorDataType(), props::MemorySwapMode()), + model_flex_props( + props::Epochs(), props::TrainingBatchSize(), props::SavePath(), + props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(), + props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead(), + props::TensorFormat(), props::ModelTensorDataType()), load_path(std::string()), epoch_idx(0), iter(0), @@ -102,6 +100,7 @@ NeuralNetwork::NeuralNetwork(AppContext app_context_) : initialized(false), compiled(false), loadedFromConfig(false), + exec_mode(ExecutionMode::TRAIN), app_context(app_context_) {} int NeuralNetwork::loadFromConfig(const std::string &config) { @@ -214,6 +213,19 @@ 
int NeuralNetwork::compile(ExecutionMode mode) {
 
 int NeuralNetwork::initialize(ExecutionMode mode) {
   int status = ML_ERROR_NONE;
 
+  if (mode != exec_mode) {
+    if (mode == ExecutionMode::INFERENCE) {
+      ml_logd("Execution mode mismatch : train mode @compile & inference mode "
+              "@ initialize");
+      exec_mode = mode;
+    } else {
+      NNTR_THROW_IF((exec_mode == ExecutionMode::TRAIN),
+                    std::invalid_argument)
+        << "Execution mode mismatch : trying to train with a model compiled "
+           "for inference";
+    }
+  }
+
   if (initialized) {
     ml_loge("Error: Initializing the model again");
     return ML_ERROR_NOT_SUPPORTED;
@@ -244,7 +256,7 @@ int NeuralNetwork::initialize(ExecutionMode mode) {
   }
 
   status = model_graph.initialize(
-    mode, input_conn,
+    exec_mode, input_conn,
     std::vector<std::string>(label_layers.begin(), label_layers.end()));
   NN_RETURN_STATUS();
 
@@ -267,6 +279,8 @@ int NeuralNetwork::initialize(ExecutionMode mode) {
 
   // Allocate weights
   model_graph.allocateWeights(exec_mode != ExecutionMode::INFERENCE);
+  // enable this to save initialized weights for INFERENCE
+  // model_graph.allocateWeights(true);
 
   initialized = true;
 
@@ -329,6 +343,7 @@ NeuralNetwork::~NeuralNetwork() {
  */
 sharedConstTensors NeuralNetwork::forwarding(
   bool training, std::function<bool(void *userdata)> stop_cb, void *userdata) {
+
   std::function<void(std::shared_ptr<LayerNode>, bool)> forwarding_op =
     [this, stop_cb, userdata](std::shared_ptr<LayerNode> node,
                               bool training) -> void {
@@ -336,9 +351,65 @@ sharedConstTensors NeuralNetwork::forwarding(
       PROFILE_MEM_ANNOTATE("Forwarding for layer: " + node->getName());
 
       auto f = std::get<0>(node->getExecutionOrder());
-      model_graph.flushCacheExcept(f);
-      node->forwarding(training);
+      // temporarily kept; will be removed once the asynch mode is fully
+      // evaluated
+      if (exec_mode == ExecutionMode::TRAIN) {
+        model_graph.flushCacheExcept(f);
+        node->forwarding(training);
+      } else {
+
+        /**
+         Currently, FSU asynch mode is supported for inference only. The
+         procedure of FSU is as follows.
+
+         Prerequisite : this function is called node by node from the
+                        forwarding function of the network graph.
+
+         Step 1. If the execution order is the first one (f == 0), try to load
+                 the tensors used by layer 0.
+
+         Step 2. Check whether the tensors from Step 1 are loaded, then do the
+                 forwarding of the first node.
+
+         Step 3. Check the lookahead, which says how many layers' weights need
+                 to be loaded ahead of time to hide the overhead due to FSU.
+
+         Step 4. Request the tensors for those layers; the loads are done by
+                 the thread pool.
+
+         Step 5. Release the weights whose execution order is less than f.
+
+         Step n. Repeat for the next layer, starting by checking that its
+                 tensors are loaded and, if so, running forwarding. Every time
+                 forwarding finishes, request the tensor load for the next n
+                 layers.
+ **/ + + if (f == 0) + model_graph.LoadTensors(f); + + if (model_graph.checkLoadComplete(f)) { + node->forwarding(training); + ml_logd("Forwarding is done %d : %s", f, node->getName().c_str()); + + unsigned int lookahead = + std::get(model_flex_props); + + if (lookahead != 0) { + if ((f) % (lookahead + 1) == lookahead - 1) { + std::cout << "request load tensor : " << f + lookahead + 1 + << std::endl; + ml_logd("request load tensor for %d", f + 1); + model_graph.LoadTensors((f / (lookahead + 1) + 1) * + (lookahead + 1)); + } + } else { + model_graph.LoadTensors(f); + } + + if (f != 0) + model_graph.UnloadTensors(f); + } + } }; return model_graph.forwarding(training, forwarding_op, stop_cb, userdata); diff --git a/nntrainer/models/neuralnet.h b/nntrainer/models/neuralnet.h index d964c60ebd..b79378138c 100644 --- a/nntrainer/models/neuralnet.h +++ b/nntrainer/models/neuralnet.h @@ -629,7 +629,7 @@ s * @retval shared_ptr props::Epochs, props::TrainingBatchSize, props::SavePath, props::ContinueTrain, props::SaveBestPath, props::MemoryOptimization, props::MemorySwap, props::MemorySwapPath, props::MemorySwapLookahead, - props::TensorFormat, props::ModelTensorDataType, props::MemorySwapMode>; + props::TensorFormat, props::ModelTensorDataType>; using RigidPropTypes = std::tuple, std::vector, props::ClipGradByGlobalNorm, @@ -674,6 +674,8 @@ s * @retval shared_ptr RunStats training; /** training statistics of the model */ RunStats testing; /** testing statistics of the model */ + ExecutionMode exec_mode; /** execution mode : train : inference */ + AppContext app_context; /** Configurations bound to current app */ NetworkGraph model_graph; /** Network Model Graph */ @@ -682,8 +684,6 @@ s * @retval shared_ptr DynamicTrainingOptimization dynamic_training_opt; /**< Dynamic fine-tuning optimization mode. 
supported modes are "max" and "norm" */ - ExecutionMode exec_mode; - /** * @brief save model in ini * diff --git a/nntrainer/tensor/cache_elem.cpp b/nntrainer/tensor/cache_elem.cpp index c7849d70af..070b621307 100644 --- a/nntrainer/tensor/cache_elem.cpp +++ b/nntrainer/tensor/cache_elem.cpp @@ -60,7 +60,6 @@ void CacheElem::swapIn(Options opt) { mem_data->setAddr((void *)buf); mem_data->setValid(true); active = true; - #ifdef PROFILE std::string msg("CacheElem("); msg += device->getDevicePath() + ") #" + std::to_string(id); diff --git a/nntrainer/tensor/cache_loader.cpp b/nntrainer/tensor/cache_loader.cpp index b97094a8fa..64df838984 100644 --- a/nntrainer/tensor/cache_loader.cpp +++ b/nntrainer/tensor/cache_loader.cpp @@ -27,23 +27,29 @@ namespace nntrainer { CacheLoader::CacheLoader(std::shared_ptr cache_pool) : pool(cache_pool), - task_executor(nullptr) {} + load_task_executor(nullptr), + unload_task_executor(nullptr) {} CacheLoader::~CacheLoader() { - if (task_executor) - delete task_executor; + if (load_task_executor) + delete load_task_executor; + if (unload_task_executor) + delete unload_task_executor; } void CacheLoader::init() { - if (task_executor) - return; - task_executor = new TaskExecutor(pool->getName()); + if (load_task_executor == nullptr) + load_task_executor = new TaskExecutor(pool->getName()); + if (unload_task_executor == nullptr) + unload_task_executor = new TaskExecutor(pool->getName()); } void CacheLoader::finish() { - delete task_executor; - task_executor = nullptr; + delete load_task_executor; + load_task_executor = nullptr; + delete unload_task_executor; + unload_task_executor = nullptr; } void CacheLoader::load(unsigned int order) { pool->loadExec(order); } @@ -56,7 +62,7 @@ int CacheLoader::loadAsync(unsigned int order, int CacheLoader::loadAsync(unsigned int order, TaskExecutor::CompleteCallback complete, long timeout_ms) { - if (!task_executor) { + if (!load_task_executor) { ml_loge("init is needed"); return ML_ERROR_INVALID_PARAMETER; } @@ -64,7 +70,7 @@ int CacheLoader::loadAsync(unsigned int order, Task::Work work = [&](std::atomic_bool &running, void *data) { unsigned int exe_order = (unsigned int)(std::uintptr_t)data; - pool->flushExcept({exe_order - 1, exe_order}); + // pool->flushExcept({exe_order - 1, exe_order}); pool->loadExec(exe_order); return ML_ERROR_NONE; @@ -74,12 +80,41 @@ int CacheLoader::loadAsync(unsigned int order, std::make_shared>(work, (void *)(std::uintptr_t)order); task->setTimeout(timeout_ms); - return task_executor->run(task, complete); + return load_task_executor->run(task, complete); +} + +int CacheLoader::flushAsync(unsigned int order, + TaskExecutor::CompleteCallback complete) { + return flushAsync(order, complete, LONG_MAX); +} + +int CacheLoader::flushAsync(unsigned int order, + TaskExecutor::CompleteCallback complete, + long timeout_ms) { + if (!unload_task_executor) { + ml_loge("init is needed"); + return ML_ERROR_INVALID_PARAMETER; + } + + Task::Work work = [&](std::atomic_bool &running, void *data) { + unsigned int exe_order = (unsigned int)(std::uintptr_t)data; + + // pool->flushExcept({exe_order - 1, exe_order}); + pool->flushExcept(exe_order); + + return ML_ERROR_NONE; + }; + + auto task = + std::make_shared>(work, (void *)(std::uintptr_t)order); + task->setTimeout(timeout_ms); + + return unload_task_executor->run(task, complete); } int CacheLoader::cancelAsync(int id) { try { - task_executor->cancel(id); + load_task_executor->cancel(id); } catch (const std::exception &e) { ml_loge("CacheLoader(%s): failed to 
cancel(%d): %s", pool->getName().c_str(), id, e.what()); diff --git a/nntrainer/tensor/cache_loader.h b/nntrainer/tensor/cache_loader.h index 155214c0d1..0b0dbbe98d 100644 --- a/nntrainer/tensor/cache_loader.h +++ b/nntrainer/tensor/cache_loader.h @@ -84,6 +84,29 @@ class CacheLoader { TaskExecutor::CompleteCallback callback, long timeout_ms); + /** + * @brief Load cache data asynchronously with execution order + * + * @param order execution order + * @param complete complete callback + * @return async task id + */ + virtual int flushAsync(unsigned int order, + TaskExecutor::CompleteCallback callback); + + /** + * @brief Load cache data asynchronously with execution order + * + * @param order execution order + * @param complete complete callback + * @param timeout timeout time in ms + * @return async task id + * @note timeout_ms does not work now. + */ + virtual int flushAsync(unsigned int order, + TaskExecutor::CompleteCallback callback, + long timeout_ms); + /** * @brief Cancel async task * @@ -94,7 +117,8 @@ class CacheLoader { private: std::shared_ptr pool; /**< cache pool */ - TaskExecutor *task_executor; /**< task executor */ + TaskExecutor *load_task_executor; /**< task executor */ + TaskExecutor *unload_task_executor; /**< task executor */ }; } // namespace nntrainer diff --git a/nntrainer/tensor/cache_pool.cpp b/nntrainer/tensor/cache_pool.cpp index 8f2d4e1f2b..519dda1b6b 100644 --- a/nntrainer/tensor/cache_pool.cpp +++ b/nntrainer/tensor/cache_pool.cpp @@ -215,7 +215,7 @@ void CachePool::flushExcept(unsigned int order) { auto id = elem->getId(); auto exe_order = exe_orders.at(id - 1); auto found = std::find(exe_order.begin(), exe_order.end(), order); - if (found == exe_order.end()) { + if (found != exe_order.end()) { /** * We assumes that flushExcept will be called in front of each execution * order, and the order is incremental. So, we can conclude that, if the diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp index e454e51119..8500db83d3 100644 --- a/nntrainer/tensor/manager.cpp +++ b/nntrainer/tensor/manager.cpp @@ -145,6 +145,7 @@ void Manager::reinitialize() { } void Manager::allocateWeights(unsigned int max_exec_order_, bool init) { + max_exec_order = max_exec_order_; if (!weight_pool.isAllocated()) { finalizeTensorPool(weight_pool, 0, max_exec_order_); weight_pool.allocate(init); @@ -375,11 +376,16 @@ std::vector Manager::requestWeights( * However, current implementation of loss needs the gradient computation. * and therefore, if we remove the calcDerivative order, then tests fails. */ - - TensorLifespan var_ls = - (enable_swap && (exec_mode == ExecutionMode::INFERENCE)) - ? 
TensorLifespan::FORWARD_INFER_LIFESPAN - : TensorLifespan::MAX_LIFESPAN; + TensorLifespan var_ls; + if (exec_mode != ExecutionMode::INFERENCE) { + var_ls = TensorLifespan::MAX_LIFESPAN; + } else { + if (enable_swap) { + var_ls = TensorLifespan::FORWARD_INFER_LIFESPAN; + } else { + var_ls = TensorLifespan::FORWARD_FUNC_LIFESPAN; + } + } TensorLifespan grad_ls = TensorLifespan::BACKWARD_FUNC_LIFESPAN; @@ -394,6 +400,8 @@ std::vector Manager::requestWeights( std::vector var_exec_order; for (auto order : default_var_exec_order) { var_exec_order.push_back(order); + if (exec_mode == ExecutionMode::INFERENCE) + break; } // auto var_exec_order = default_var_exec_order; std::vector grad_exec_order; @@ -730,6 +738,116 @@ void Manager::flushCache() { } } +bool Manager::checkLoadComplete(unsigned int order) { + if (async_load_tensor.count(order) == 1) { + auto &tasks = async_load_tensor[order]; + std::unique_lock lock(completed_load_mutex); + if (exec_mode == ExecutionMode::TRAIN) { + auto w_fut = completed_load_tensor[std::get<0>(tasks)].get_future(); + auto t_fut = completed_load_tensor[std::get<1>(tasks)].get_future(); + lock.unlock(); + if (std::get<0>(tasks) != 0) + w_fut.wait(); + if (std::get<1>(tasks) != 0) + t_fut.wait(); + } else { + auto w_fut = completed_load_tensor[std::get<0>(tasks)].get_future(); + lock.unlock(); + if (std::get<0>(tasks) != 0) + w_fut.wait(); + } + async_load_tensor.erase(order); + ml_logd("wait and completed %d", order); + } else { + ml_logd("without wait completed %d", order); + } + return true; +} + +bool Manager::checkUnloadComplete(unsigned int order) { + if (async_unload_tensor.count(order)) { + auto &tasks = async_unload_tensor[order]; + std::unique_lock lock(completed_unload_mutex); + if (exec_mode == ExecutionMode::TRAIN) { + auto w_fut = completed_unload_tensor[std::get<0>(tasks)].get_future(); + auto t_fut = completed_unload_tensor[std::get<1>(tasks)].get_future(); + lock.unlock(); + if (std::get<0>(tasks) != 0) + w_fut.wait(); + if (std::get<1>(tasks) != 0) + t_fut.wait(); + } else { + auto w_fut = completed_unload_tensor[std::get<0>(tasks)].get_future(); + lock.unlock(); + if (std::get<0>(tasks) != 0) + w_fut.wait(); + } + async_unload_tensor.erase(order); + } + return true; +} + +void Manager::LoadTensors(unsigned int order) { + auto loadTensorsAsync = [&](TensorPool &pool, unsigned int order) { + return pool.loadCacheExecAsync( + order, [&](int id, TaskExecutor::CompleteStatus status) { + std::scoped_lock lock(completed_load_mutex); + completed_load_tensor[id].set_value(true); + }); + }; + + auto enqueTasks = [&](unsigned int o) { + if (async_load_tensor.count(o)) { + ml_logd("Task loadTensors (%d) is in progress", o); + return; + } + auto load_weight = loadTensorsAsync(weight_pool, o); + ml_logd("load weigth is requested in LoadTensors with order - %d", o); + int load_tensor = 0; + if (exec_mode != ml::train::ExecutionMode::INFERENCE) { + load_tensor = loadTensorsAsync(tensor_pool, o); + ml_logd("load tensor is requested in LoadTensors with order - %d", o); + } + NNTR_THROW_IF(load_weight < 0 || load_tensor < 0, std::runtime_error) + << "Fail to launch task"; + async_load_tensor[o] = std::make_tuple(load_weight, load_tensor); + }; + + for (unsigned int i = order; i < order + swap_lookahead + 1; ++i) { + if (i <= max_exec_order) + enqueTasks(i); + } +} + +void Manager::UnloadTensors(unsigned int order) { + auto unloadTensorsAsync = [&](TensorPool &pool, unsigned int order) { + return pool.flushCacheExecAsync( + order, [&](int id, 
TaskExecutor::CompleteStatus status) { + std::scoped_lock lock(completed_unload_mutex); + completed_unload_tensor[id].set_value(true); + }); + }; + + auto enqueTasks = [&](unsigned int o) { + if (async_unload_tensor.count(o)) { + ml_logd("Task unloadTensors (%d) is in progress", o); + return; + } + auto unload_weight = unloadTensorsAsync(weight_pool, o); + ml_logd("unload weight is requested in UnLoadTensors with order - %d", o); + int unload_tensor = 0; + if (exec_mode != ml::train::ExecutionMode::INFERENCE) { + unload_tensor = unloadTensorsAsync(tensor_pool, o); + ml_logd("unload tensor is requested in UnLoadTensors with order - %d", o); + } + NNTR_THROW_IF(unload_weight < 0 || unload_tensor < 0, std::runtime_error) + << "Faile to launch task"; + async_unload_tensor[o] = std::make_tuple(unload_weight, unload_tensor); + }; + + enqueTasks(order); +} + void Manager::flushCacheExcept(unsigned int order) { auto loadAsync = [&](TensorPool &pool, unsigned int order) { return pool.loadCacheExecAsync( diff --git a/nntrainer/tensor/manager.h b/nntrainer/tensor/manager.h index 281b6ebe4e..1e2308efb3 100644 --- a/nntrainer/tensor/manager.h +++ b/nntrainer/tensor/manager.h @@ -132,6 +132,7 @@ class Manager { * @brief Constructor of Manager */ Manager() : + enable_swap(false), enable_optimizations(true), swap_lookahead(0), tensor_format("NCHW"), @@ -148,12 +149,12 @@ class Manager { weight_pool(enable_swap_, swap_path, "weight_pool"), tensor_pool(enable_swap_ && (exec_mode_ == ExecutionMode::TRAIN), swap_path, "tensor_pool"), + enable_swap(enable_swap_), enable_optimizations(true), swap_lookahead(lookahead), tensor_format(tensor_format_), tensor_dtype(split(tensor_dtype_, getRegex("\\-"))), - exec_mode(exec_mode_), - enable_swap(enable_swap_) {} + exec_mode(exec_mode_) {} /** * @brief Construct a new Manager object (deleted) @@ -486,6 +487,39 @@ class Manager { */ void flushCacheExcept(unsigned int order); + /** + * @brief load cache data for the execution order + * + * @param order execution order + * @note preloading loads execution order data asynchronously, + * for lookahead size. + */ + void LoadTensors(unsigned int order); + + /** + * @brief check completion of load data for the execution order + * + * @param order execution order + * @note preloading tensors for execution order. + */ + bool checkLoadComplete(unsigned int order); + + /** + * @brief check completion of unload data for the execution order + * + * @param order execution order + * @note preloading tensors for execution order. + */ + bool checkUnloadComplete(unsigned int order); + + /** + * @brief flush load data for the execution order + * + * @param order execution order + * @note flush tensors for execution order. 
+ */ + void UnloadTensors(unsigned int order); + /** * @brief reinitialize manager */ @@ -520,14 +554,41 @@ class Manager { TensorPool weight_pool; /**< tensor pool to request tensors */ TensorPool tensor_pool; /**< tensor pool to request tensors */ + /** async load task */ + std::map async_task_weight_load; + + /** async unload task */ + std::map async_task_weight_unload; + + /**< async tasks > + */ + std::map> async_task_eos; /**< async tasks > */ + std::map> async_load_tensor; + + std::map complete_load_tensor; + + std::map> async_unload_tensor; + std::map> completed; + + std::map> completed_load_tensor; + + std::map> completed_unload_tensor; + /**< async tasks completion */ std::mutex completed_mutex; /**< mutex for async tasks completion */ + std::mutex completed_load_mutex; /**< mutex for async tasks completion */ + + std::mutex completed_unload_mutex; /**< mutex for async tasks completion */ + + bool enable_swap; /**< to enable swap */ + bool enable_optimizations; /**< to enable memory optimizations */ unsigned int swap_lookahead; /** lookahead for memory swap */ @@ -538,7 +599,7 @@ class Manager { ExecutionMode exec_mode; - bool enable_swap; + unsigned int max_exec_order; /** * @brief Finalize the given tensor pool diff --git a/nntrainer/tensor/task_executor.cpp b/nntrainer/tensor/task_executor.cpp index 0f68e4c8e3..404f4baf85 100644 --- a/nntrainer/tensor/task_executor.cpp +++ b/nntrainer/tensor/task_executor.cpp @@ -30,17 +30,16 @@ namespace nntrainer { std::atomic_int32_t TaskExecutor::ids(1); TaskExecutor::TaskExecutor(const std::string &n) : - name(n), run_thread(true), wait_complete(false) { + name(n), run_thread(true), wait_complete(false), stop_all(false) { task_thread = std::thread([&]() { ml_logd("Task Thread(%s): start thread", name.c_str()); while (run_thread) { - std::unique_lock lk(task_mutex); - if (!task_cv.wait_for(lk, std::chrono::milliseconds(500), - [&] { return !task_queue.empty(); })) - continue; + std::unique_lock lk(task_mutex); + task_cv.wait(lk, [&] { return !task_queue.empty() || stop_all; }); + if (stop_all && task_queue.empty()) + return; auto &task_info = task_queue.front(); - lk.unlock(); const auto &id = std::get(task_info); const auto &callback = std::get(task_info); @@ -48,7 +47,6 @@ TaskExecutor::TaskExecutor(const std::string &n) : auto status = worker(task_info); callback(id, status); - lk.lock(); task_queue.pop_front(); lk.unlock(); } @@ -58,7 +56,8 @@ TaskExecutor::TaskExecutor(const std::string &n) : TaskExecutor::~TaskExecutor() { run_thread = false; - + stop_all = true; + task_cv.notify_all(); task_thread.join(); } diff --git a/nntrainer/tensor/task_executor.h b/nntrainer/tensor/task_executor.h index 4fc20c9c1b..35f9fd9c14 100644 --- a/nntrainer/tensor/task_executor.h +++ b/nntrainer/tensor/task_executor.h @@ -153,6 +153,7 @@ class TaskExecutor { std::string name; bool run_thread; bool wait_complete; + bool stop_all; std::list> task_queue; diff --git a/nntrainer/tensor/tensor_pool.cpp b/nntrainer/tensor/tensor_pool.cpp index 27f22d8a0c..d9ee957aed 100644 --- a/nntrainer/tensor/tensor_pool.cpp +++ b/nntrainer/tensor/tensor_pool.cpp @@ -435,6 +435,8 @@ bool TensorPool::isTensorLongTerm(const TensorLifespan &lifespan) { switch (lifespan) { case TensorLifespan::EPOCH_LIFESPAN: [[fallthrough]]; + case TensorLifespan::FORWARD_INFER_LIFESPAN: + [[fallthrough]]; case TensorLifespan::MAX_LIFESPAN: return true; case TensorLifespan::FORWARD_FUNC_LIFESPAN: @@ -470,7 +472,15 @@ int TensorPool::loadCacheExecAsync( if (dynamic_cast(mem_pool.get())) return 
cache_loader->loadAsync(order, complete_callback);
   else
-    return -1;
+    return 0;
+}
+
+int TensorPool::flushCacheExecAsync(
+  unsigned int order, TaskExecutor::CompleteCallback complete_callback) {
+  if (dynamic_cast<CachePool *>(mem_pool.get()))
+    return cache_loader->flushAsync(order, complete_callback);
+  else
+    return 0;
 }
 
 void TensorPool::loadCacheCancel(int id) {
diff --git a/nntrainer/tensor/tensor_pool.h b/nntrainer/tensor/tensor_pool.h
index 1d2addb52e..45c79c17b5 100644
--- a/nntrainer/tensor/tensor_pool.h
+++ b/nntrainer/tensor/tensor_pool.h
@@ -288,6 +288,15 @@ class TensorPool {
   int loadCacheExecAsync(unsigned int order,
                          TaskExecutor::CompleteCallback complete_callback);
 
+  /**
+   * @brief flush cache data by execution order
+   *
+   * @param order execution order
+   * @return async task id
+   */
+  int flushCacheExecAsync(unsigned int order,
+                          TaskExecutor::CompleteCallback complete_callback);
+
   /**
    * @brief load cache data by execution order
    *
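
Taken together, the change turns inference forwarding into a load / compute / unload pipeline: `Manager::LoadTensors(order)` schedules asynchronous weight loads up to the configured `memory_swap_lookahead`, `checkLoadComplete(order)` blocks until the current layer's weights are resident, and `UnloadTensors(order)` releases them once the layer has run. The snippet below is a minimal, self-contained sketch of that pattern only; it is not nntrainer code. `WeightStore`, `prefetch`, `wait`, and `evict` are hypothetical stand-ins, and the per-layer sliding window simplifies the batched `(f / (lookahead + 1) + 1) * (lookahead + 1)` request used in the diff.

```cpp
// Hypothetical stand-in for the FSU flow; the real logic lives in Manager /
// NetworkGraph / NeuralNetwork::forwarding.
#include <future>
#include <iostream>
#include <map>

struct WeightStore {
  std::map<unsigned int, std::future<void>> pending; /**< order -> in-flight load */

  /// ask for the weights of `order` asynchronously (cf. LoadTensors)
  void prefetch(unsigned int order, unsigned int max_order) {
    if (order > max_order || pending.count(order))
      return;
    pending[order] = std::async(std::launch::async, [order] {
      std::cout << "load weights for order " << order << "\n"; // stand-in for the FSU read
    });
  }

  /// block until the weights of `order` are resident (cf. checkLoadComplete)
  void wait(unsigned int order) {
    auto it = pending.find(order);
    if (it != pending.end())
      it->second.wait();
  }

  /// release weights that are no longer needed (cf. UnloadTensors)
  void evict(unsigned int order) {
    pending.erase(order); // the future's destructor waits for the load first
    std::cout << "unload weights for order " << order << "\n";
  }
};

int main() {
  const unsigned int num_layers = 8, lookahead = 2;
  WeightStore store;

  store.prefetch(0, num_layers - 1); // LoadTensors(0) before the first layer

  for (unsigned int f = 0; f < num_layers; ++f) {
    store.wait(f);                              // checkLoadComplete(f)
    std::cout << "forward layer " << f << "\n"; // node->forwarding(...)

    // request the next window of weights so that IO overlaps with compute
    for (unsigned int n = f + 1; n <= f + lookahead; ++n)
      store.prefetch(n, num_layers - 1);

    if (f != 0)
      store.evict(f); // this layer's weights are not needed again in inference
  }
  return 0;
}
```

In this sketch, destroying a `std::async` future blocks until the task completes, which is what makes `evict()` a safe stand-in for waiting on the unload; the diff instead tracks completion explicitly through `completed_load_tensor` / `completed_unload_tensor` promises guarded by mutexes.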