Skip to content

Commit

Permalink
[ FSU ] Enabls Asynchronos FSU for forwarding
Browse files Browse the repository at this point in the history
This PR enables asynchronos mode for FSU (flash storage utilization)
for better performance.

It splits the load and unload tensors which make difficult to handle.
Also fix the inference execution order when it is in INFERENCE mode
and change the trainable option to false when it calls the request
weights and tensors.

Add the new function to load and unload tensors as well as check load
complete.

It also considers weight pool and tensor pool differenetly according
to the ExecutionMode. It is not use FSU mode for tensor pool for the
INFERENCE Mode.

Resolves:

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <[email protected]>
  • Loading branch information
jijoongmoon committed Dec 10, 2024
1 parent 9c6723f commit cd17a66
Show file tree
Hide file tree
Showing 21 changed files with 487 additions and 157 deletions.
30 changes: 0 additions & 30 deletions Applications/SimpleFC/README.md

This file was deleted.

80 changes: 38 additions & 42 deletions Applications/SimpleFC/jni/main.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
* Copyright (C) 2024 Jijoong Moon <jijoong.moon@samsung.com>
*
* @file main.cpp
* @date 24 Jun 2021
* @todo move resnet model creating to separate sourcefile
* @brief task runner for the resnet
* @date 10 Dec 2024
* @brief Test Application for Asynch FSU
* @see https://github.com/nnstreamer/nntrainer
* @author Jihoon Lee <jhoon.it.lee@samsung.com>
* @author Jijoong Moon <jijoong.moon@samsung.com>
* @bug No known bugs except for NYI items
*/
#include <array>
Expand Down Expand Up @@ -76,32 +75,32 @@ static std::string withKey(const std::string &key,
}

/**
* @brief Create resnet 18
* @brief Create network
*
* @return vector of layers that contain full graph of resnet18
* @return vector of layers that contain full graph of asynch
*/
std::vector<LayerHandle> createGraph() {
using ml::train::createLayer;

std::vector<LayerHandle> layers;

layers.push_back(createLayer(
"input", {withKey("name", "input0"), withKey("input_shape", "1:1:32")}));
"input", {withKey("name", "input0"), withKey("input_shape", "1:1:320")}));

layers.push_back(
createLayer("fully_connected",
{withKey("unit", 10)}));
layers.push_back(createLayer("fully_connected",
{withKey("unit", 100),
withKey("weight_initializer", "xavier_uniform"),
withKey("bias_initializer", "zeros")}));

layers.push_back(
createLayer("fully_connected",
{withKey("unit", 10)}));
layers.push_back(createLayer("fully_connected",
{withKey("unit", 100),
withKey("weight_initializer", "xavier_uniform"),
withKey("bias_initializer", "zeros")}));

return layers;
}

/// @todo update createResnet18 to be more generic
ModelHandle create() {
/// @todo support "LOSS : cross" for TF_Lite Exporter
ModelHandle model = ml::train::createModel(ml::train::ModelType::NEURAL_NET,
{withKey("loss", "mse")});

Expand All @@ -126,17 +125,19 @@ int validData_cb(float **input, float **label, bool *last, void *user_data) {
return 0;
}

/// @todo maybe make num_class also a parameter
void createAndRun(unsigned int epochs, unsigned int batch_size,
UserDataType &train_user_data,
UserDataType &valid_user_data) {

// setup model
ModelHandle model = create();
model->setProperty({withKey("batch_size", batch_size),
withKey("epochs", epochs),
withKey("save_path", "model_full.bin"),
withKey("memory_swap","true")});
model->setProperty(
{withKey("batch_size", batch_size), withKey("epochs", epochs),
// withKey("save_path", "model_full.bin")});
// withKey("save_path", "model_full.bin"), withKey("memory_swap",
// "true")});
withKey("memory_swap", "true"), withKey("memory_swap_lookahead", "1"),
withKey("model_tensor_type", "FP16-FP16")});

auto optimizer = ml::train::createOptimizer("sgd", {"learning_rate=0.001"});
model->setOptimizer(std::move(optimizer));
Expand All @@ -156,28 +157,24 @@ void createAndRun(unsigned int epochs, unsigned int batch_size,
auto dataset_valid = ml::train::createDataset(
ml::train::DatasetType::GENERATOR, validData_cb, valid_user_data.get());

// model->setDataset(ml::train::DatasetModeType::MODE_TRAIN,
// std::move(dataset_train));
// model->setDataset(ml::train::DatasetModeType::MODE_VALID,
// std::move(dataset_valid));

// if (transfer_learning)
// model->load(pretrained_bin_path);
// model->train();
// to test asynch fsu, we do need save the model weight data in file
model->save("simplefc_weight_fp16_fp16_100.bin",
ml::train::ModelFormat::MODEL_FORMAT_BIN);
model->load("./simplefc_weight_fp16_fp16_100.bin");

model->summarize(std::cout, ML_TRAIN_SUMMARY_MODEL);

uint feature_size = 32;
uint feature_size = 320;

float input [32];
float label [1];
float input[320];
float label[1];

for(uint j=0;j<feature_size;++j)
for (uint j = 0; j < feature_size; ++j)
input[j] = j;

std::vector<float*> in;
std::vector<float*> l;
std::vector<float*> answer;
std::vector<float *> in;
std::vector<float *> l;
std::vector<float *> answer;

in.push_back(input);
l.push_back(label);
Expand All @@ -187,19 +184,18 @@ void createAndRun(unsigned int epochs, unsigned int batch_size,
in.clear();
l.clear();

std::cout << "done"<<std::endl;

std::cout << "done" << std::endl;
}

std::array<UserDataType, 2>
createFakeDataGenerator(unsigned int batch_size,
unsigned int simulated_data_size,
unsigned int data_split) {
UserDataType train_data(new nntrainer::util::RandomDataLoader(
{{batch_size, 1, 1, 32}}, {{batch_size, 1, 1, 10}},
{{batch_size, 1, 1, 320}}, {{batch_size, 1, 1, 100}},
simulated_data_size / data_split));
UserDataType valid_data(new nntrainer::util::RandomDataLoader(
{{batch_size, 1, 1, 32}}, {{batch_size, 1, 1, 10}},
{{batch_size, 1, 1, 320}}, {{batch_size, 1, 1, 100}},
simulated_data_size / data_split));

return {std::move(train_data), std::move(valid_data)};
Expand Down Expand Up @@ -231,9 +227,9 @@ int main(int argc, char *argv[]) {

std::string data_dir = "fake";
uint batch_size = 1;
uint data_split =1;
uint data_split = 1;
uint epoch = 1;

std::array<UserDataType, 2> user_datas;

try {
Expand Down
12 changes: 6 additions & 6 deletions Applications/SimpleFC/jni/meson.build
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
resnet_sources = [
app_sources = [
'main.cpp',
cifar_path / 'cifar_dataloader.cpp'
]

resnet_dependencies = [app_utils_dep,
app_dependencies = [app_utils_dep,
iniparser_dep,
nntrainer_dep,
nntrainer_ccapi_dep
]

if get_option('enable-test')
resnet_dependencies += [gtest_dep]
app_dependencies += [gtest_dep]
endif

e = executable('nntrainer_simplefc',
resnet_sources,
app_sources,
include_directories: [include_directories('.'), cifar_include_dir],
dependencies: resnet_dependencies,
dependencies: app_dependencies,
install: get_option('install-app'),
install_dir: application_install_dir
)

if get_option('enable-long-test')
testenv = environment()
testenv.set('OPENBLAS_NUM_THREADS', '4')
test('app_resnet18', e, args: ['fake', '1', '128', '1'], env: testenv, timeout: 300)
test('app_asynch_fsu', e, args: ['fake', '1', '128', '1'], env: testenv, timeout: 300)
endif
1 change: 0 additions & 1 deletion meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ if get_option('enable-fp16')
extra_defines += '-DENABLE_FP16=1'
extra_defines += '-DUSE__FP16=1'
extra_defines += '-DUSE_NEON=1'
extra_defines += '-DUSE_MMAP=1'
elif arch == 'aarch64'
## About FP16 in GCC (from GCC-9.1 manual)
# https://gcc.gnu.org/onlinedocs/gcc-9.1.0/gcc/Half-Precision.html
Expand Down
32 changes: 29 additions & 3 deletions nntrainer/graph/network_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ int NetworkGraph::checkCompiledGraph() {
void NetworkGraph::markNodesForBackwarding() {
/** accumulate all the nodes which must support backwarding */
std::unordered_set<std::string> must_support_backwarding;
if (exec_mode == ExecutionMode::INFERENCE) {
for (auto iter = cbegin(); iter != cend(); iter++) {
auto lnode = (*iter);
lnode->needsCalcGradient(false);
lnode->needsCalcDerivative(false);
}
return;
}

/**
* if a node is trainable, then all the nodes ahead of it must support
Expand Down Expand Up @@ -867,14 +875,16 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
}
lnode->setDataType(init_context.getWeightDataType(),
init_context.getActivationDataType());

bool trainable = lnode->getTrainable();
if (exec_mode == ExecutionMode::INFERENCE)
trainable = false;
lnode->configureRunContext(
// TODO: update weights spec for trainable based on layer trainable prop
tensor_manager->requestWeights(gnode, init_context.getWeightsSpec(),
lnode->getTrainable(), shared_weight_names),
trainable, shared_weight_names),
inputs, outputs,
tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(),
lnode->getTrainable(), shared_tensor_names),
trainable, shared_tensor_names),
init_context.getLossScale());

return outputs;
Expand Down Expand Up @@ -1552,6 +1562,22 @@ void NetworkGraph::flushCacheExcept(unsigned int order) {
tensor_manager->flushCacheExcept(order);
}

void NetworkGraph::LoadTensors(unsigned int order) {
tensor_manager->LoadTensors(order);
}

bool NetworkGraph::checkLoadComplete(unsigned int order) {
return tensor_manager->checkLoadComplete(order);
}

bool NetworkGraph::checkUnloadComplete(unsigned int order) {
return tensor_manager->checkUnloadComplete(order);
}

void NetworkGraph::UnloadTensors(unsigned int order) {
tensor_manager->UnloadTensors(order);
}

void NetworkGraph::requestOptimizerVariable(
std::function<std::vector<TensorDim>(const TensorDim &)> cb,
bool request_only_trainable) {
Expand Down
36 changes: 34 additions & 2 deletions nntrainer/graph/network_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,12 @@ class NetworkGraph {
* @brief Allocate memory for all the managed weights
*/
void allocateWeights(bool init = true) {
tensor_manager->allocateWeights(
std::get<3>(backward_iter_end->getExecutionOrder()), init);
unsigned int max_exec_order =
std::get<3>(backward_iter_end->getExecutionOrder());

if (exec_mode == ExecutionMode::INFERENCE)
max_exec_order = std::get<0>(forward_iter_end->getExecutionOrder());
tensor_manager->allocateWeights(max_exec_order, init);
}

/**
Expand Down Expand Up @@ -447,6 +451,34 @@ class NetworkGraph {
*/
void flushCacheExcept(const unsigned int order);

/**
* @brief Load data of order to the device
*
* @param order execution order
*/
void LoadTensors(const unsigned int order);

/**
* @brief check data of order is loaded
*
* @param order execution order
*/
bool checkLoadComplete(const unsigned int order);

/**
* @brief check data of order is Unloaded
*
* @param order execution order
*/
bool checkUnloadComplete(const unsigned int order);

/**
* @brief Load data of order to the device
*
* @param order execution order
*/
void UnloadTensors(const unsigned int order);

#ifdef ENABLE_TEST
/**
* @brief Get layer node's tenexecution orders
Expand Down
7 changes: 4 additions & 3 deletions nntrainer/layers/layer_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,8 @@ void LayerNode::read(std::ifstream &file, bool opt_var,

for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) {
/// @note shared weights are only be read at the first acecss
if (run_context->isGradientLastAccess(i)) {
// if (run_context->isGradientLastAccess(i)) {
if (run_context->isGradientFirstAccess(i)) {
if (layer->getType() == BatchNormalizationLayer::type) {
if ((mode == ml::train::ExecutionMode::TRAIN) &&
(this->getWeightDataType() != TensorDim::DataType::FP32)) {
Expand Down Expand Up @@ -526,7 +527,7 @@ void LayerNode::save(std::ofstream &file, bool opt_var,

if (opt_var) {
for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) {
if (run_context->isGradientLastAccess(i) && getTrainable()) {
if (run_context->isGradientFirstAccess(i) && getTrainable()) {
// @note save optimizer variables
if (run_context->weightHasGradient(i)) {
for (unsigned int j = 0; j < run_context->getNumWeightOptVar(i);
Expand All @@ -539,7 +540,7 @@ void LayerNode::save(std::ofstream &file, bool opt_var,
} else {
// @note shared weights are only be saved at the first access
for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) {
if (run_context->isGradientLastAccess(i)) {
if (run_context->isGradientFirstAccess(i)) {

/** @note For batch normalization layer, we do need full precision for
* training and the data type of weight is full precision. But for
Expand Down
2 changes: 0 additions & 2 deletions nntrainer/models/model_common_properties.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ MemorySwap::MemorySwap(bool value) { set(value); }

MemorySwapPath::MemorySwapPath(const std::string &value) { set(value); }

MemorySwapMode::MemorySwapMode(const std::string &value) { set(value); }

MemorySwapLookahead::MemorySwapLookahead(const unsigned int &value) {
set(value);
}
Expand Down
Loading

0 comments on commit cd17a66

Please sign in to comment.