diff --git a/.github/workflows/ubuntu_benchmarks.yml b/.github/workflows/ubuntu_benchmarks.yml
new file mode 100644
index 0000000000..bbad098668
--- /dev/null
+++ b/.github/workflows/ubuntu_benchmarks.yml
@@ -0,0 +1,55 @@
+name: Ubuntu Benchmarks
+
+on:
+  schedule:
+    - cron: '0 2 * * *'
+
+jobs:
+  meson_test:
+
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ ubuntu-20.04, ubuntu-22.04 ]
+
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+      - name: install additional package from PPA for testing
+        run: sudo add-apt-repository -y ppa:nnstreamer/ppa && sudo apt-get update
+      - name: install minimal requirements
+        run: sudo apt-get update && sudo apt-get install -y gcc g++ pkg-config libopenblas-dev libiniparser-dev libjsoncpp-dev libcurl3-dev tensorflow2-lite-dev nnstreamer-dev libglib2.0-dev libgstreamer1.0-dev libgtest-dev ml-api-common-dev flatbuffers-compiler ml-inference-api-dev libunwind-dev libbenchmark-dev
+      - name: install additional packages for features
+        run: sudo apt-get install -y python3-dev python3-numpy python3
+      - name: install build systems
+        run: sudo apt-get install -y meson ninja-build
+      - run: meson setup build/
+        env:
+          CC: gcc
+      - run: |
+          meson setup \
+            --buildtype=plain \
+            --prefix=/usr \
+            --sysconfdir=/etc \
+            --libdir=lib/x86_64-linux-gnu \
+            --bindir=lib/nntrainer/bin \
+            --includedir=include \
+            -Dinstall-app=false \
+            -Dreduce-tolerance=false \
+            -Denable-debug=true \
+            -Dml-api-support=enabled \
+            -Denable-nnstreamer-tensor-filter=enabled \
+            -Denable-nnstreamer-tensor-trainer=enabled \
+            -Denable-nnstreamer-backbone=true \
+            -Dcapi-ml-common-actual=capi-ml-common \
+            -Dcapi-ml-inference-actual=capi-ml-inference \
+            -Denable-capi=enabled \
+            -Denable-benchmarks=true \
+            -Denable-app=false \
+            build_benchmarks
+      - run: ninja -C build_benchmarks
+      - name: run Benchmark_ResNet
+        run: cd ./build_benchmarks/benchmarks/benchmark_application && ./Benchmark_ResNet
diff --git a/benchmarks/benchmark_application/benchmark_resnet.cpp b/benchmarks/benchmark_application/benchmark_resnet.cpp
new file mode 100644
index 0000000000..f9024beaff
--- /dev/null
+++ b/benchmarks/benchmark_application/benchmark_resnet.cpp
@@ -0,0 +1,305 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Donghak Park
+ *
+ * @file benchmark_resnet.cpp
+ * @date 15 Aug 2024
+ * @brief benchmark test for the resnet application
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Donghak Park
+ * @bug No known bugs except for NYI items
+ */
+#include <array>
+#include <cstdint>
+#include <cstdio>
+#include <iostream>
+#include <memory>
+#include <sched.h>
+#include <sstream>
+#include <stdexcept>
+#include <vector>
+
+#include <dataset.h>
+#include <layer.h>
+#include <model.h>
+#include <optimizer.h>
+
+#include "benchmark/benchmark.h"
+#include <fake_data_gen.h>
+
+using LayerHandle = std::shared_ptr<ml::train::Layer>;
+using ModelHandle = std::unique_ptr<ml::train::Model>;
+
+using UserDataType = std::unique_ptr<nntrainer::util::DataLoader>;
+
+/**
+ * @brief read the current operating frequency of the CPU this thread runs on
+ * from sysfs; returns 0 if it cannot be determined
+ */
+uint64_t get_cpu_freq() {
+  unsigned int freq = 0;
+  char cur_cpu_name[512];
+  int cpu = sched_getcpu();
+  snprintf(cur_cpu_name, sizeof(cur_cpu_name),
+           "/sys/devices/system/cpu/cpu%d/cpufreq/scaling_cur_freq", cpu);
+
+  FILE *f = fopen(cur_cpu_name, "r");
+  if (f != nullptr) {
+    if (fscanf(f, "%u", &freq) == 1) {
+      fclose(f);
+      return uint64_t(freq) * 1000;
+    }
+    fclose(f);
+  }
+  return 0;
+}
+
+/** cache loss values post training for test */
+float training_loss = 0.0;
+float validation_loss = 0.0;
+
+/**
+ * @brief make "key=value" from key and value
+ *
+ * @tparam T type of a value
+ * @param key key
+ * @param value value
+ * @return std::string with "key=value"
+ */
+template <typename T>
+static std::string withKey(const std::string &key, const T &value) {
+  std::stringstream ss;
+  ss << key << "=" << value;
+  return ss.str();
+}
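+
+/**
+ * @brief make "key=value1,value2,...valueN" from key and a list of values;
+ * used for list-valued properties such as stride and kernel_size
+ *
+ * @tparam T type of a value
+ * @param key key
+ * @param value list of values
+ * @return std::string with "key=value1,value2,...valueN"
+ */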
+template <typename T>
+static std::string withKey(const std::string &key,
+                           std::initializer_list<T> value) {
+  if (std::empty(value)) {
+    throw std::invalid_argument("empty data cannot be converted");
+  }
+
+  std::stringstream ss;
+  ss << key << "=";
+
+  auto iter = value.begin();
+  for (; iter != value.end() - 1; ++iter) {
+    ss << *iter << ',';
+  }
+  ss << *iter;
+
+  return ss.str();
+}
+
+/**
+ * @brief resnet block
+ *
+ * @param block_name name of the block
+ * @param input_name name of the input
+ * @param filters number of filters
+ * @param kernel_size size of the convolution kernel
+ * @param downsample whether to downsample; halves the spatial size of the
+ * output with stride-2 convolutions
+ * @return std::vector<LayerHandle> vector of layers
+ */
+std::vector<LayerHandle> resnetBlock(const std::string &block_name,
+                                     const std::string &input_name,
+                                     int filters, int kernel_size,
+                                     bool downsample) {
+  using ml::train::createLayer;
+
+  auto scoped_name = [&block_name](const std::string &layer_name) {
+    return block_name + "/" + layer_name;
+  };
+  auto with_name = [&scoped_name](const std::string &layer_name) {
+    return withKey("name", scoped_name(layer_name));
+  };
+
+  auto create_conv = [&with_name, filters](const std::string &name,
+                                           int kernel_size, int stride,
+                                           const std::string &padding,
+                                           const std::string &input_layer) {
+    std::vector<std::string> props{
+      with_name(name),
+      withKey("stride", {stride, stride}),
+      withKey("filters", filters),
+      withKey("kernel_size", {kernel_size, kernel_size}),
+      withKey("padding", padding),
+      withKey("input_layers", input_layer)};
+
+    return createLayer("conv2d", props);
+  };
+
+  LayerHandle a1 = create_conv("a1", 3, downsample ? 2 : 1,
+                               downsample ? "1,1" : "same", input_name);
+  LayerHandle a2 =
+    createLayer("batch_normalization",
+                {with_name("a2"), withKey("activation", "relu"),
+                 withKey("momentum", "0.9"), withKey("epsilon", "0.00001")});
+  LayerHandle a3 = create_conv("a3", 3, 1, "same", scoped_name("a2"));
+
+  /** skip path */
+  LayerHandle b1 = nullptr;
+  if (downsample) {
+    b1 = create_conv("b1", 1, 2, "0,0", input_name);
+  }
+
+  const std::string skip_name = b1 ? scoped_name("b1") : input_name;
+
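+  /** merge the main path (a3) and the skip path with an addition layer,
+   *  then close the block with batch normalization + relu */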
scoped_name("b1") : input_name; + + LayerHandle c1 = createLayer( + "Addition", + {with_name("c1"), withKey("input_layers", {scoped_name("a3"), skip_name})}); + + LayerHandle c2 = + createLayer("batch_normalization", + {withKey("name", block_name), withKey("activation", "relu"), + withKey("momentum", "0.9"), withKey("epsilon", "0.00001"), + withKey("trainable", "false")}); + + if (downsample) { + return {b1, a1, a2, a3, c1, c2}; + } else { + return {a1, a2, a3, c1, c2}; + } +} + +/** + * @brief Create resnet 18 + * + * @return vector of layers that contain full graph of resnet18 + */ +std::vector createResnet18Graph() { + using ml::train::createLayer; + + std::vector layers; + + layers.push_back(createLayer( + "input", {withKey("name", "input0"), withKey("input_shape", "3:32:32")})); + + layers.push_back(createLayer( + "conv2d", {withKey("name", "conv0"), withKey("filters", 64), + withKey("kernel_size", {3, 3}), withKey("stride", {1, 1}), + withKey("padding", "same"), withKey("bias_initializer", "zeros"), + withKey("weight_initializer", "xavier_uniform")})); + + layers.push_back(createLayer( + "batch_normalization", + {withKey("name", "first_bn_relu"), withKey("activation", "relu"), + withKey("momentum", "0.9"), withKey("epsilon", "0.00001")})); + + std::vector> blocks; + + blocks.push_back(resnetBlock("conv1_0", "first_bn_relu", 64, 3, false)); + blocks.push_back(resnetBlock("conv1_1", "conv1_0", 64, 3, false)); + blocks.push_back(resnetBlock("conv2_0", "conv1_1", 128, 3, true)); + blocks.push_back(resnetBlock("conv2_1", "conv2_0", 128, 3, false)); + blocks.push_back(resnetBlock("conv3_0", "conv2_1", 256, 3, true)); + blocks.push_back(resnetBlock("conv3_1", "conv3_0", 256, 3, false)); + blocks.push_back(resnetBlock("conv4_0", "conv3_1", 512, 3, true)); + blocks.push_back(resnetBlock("conv4_1", "conv4_0", 512, 3, false)); + + for (auto &block : blocks) { + layers.insert(layers.end(), block.begin(), block.end()); + } + + layers.push_back(createLayer( + "pooling2d", {withKey("name", "last_p1"), withKey("pooling", "average"), + withKey("pool_size", {4, 4}), withKey("stride", "4,4")})); + + layers.push_back(createLayer("flatten", {withKey("name", "last_f1")})); + layers.push_back( + createLayer("fully_connected", + {withKey("unit", 100), withKey("activation", "softmax")})); + + return layers; +} + +ModelHandle createResnet18(bool pre_trained = false) { + ModelHandle model = ml::train::createModel(ml::train::ModelType::NEURAL_NET, + {withKey("loss", "cross")}); + + for (auto &layer : createResnet18Graph()) { + model->addLayer(layer); + } + + return model; +} + +int trainData_cb(float **input, float **label, bool *last, void *user_data) { + auto data = reinterpret_cast(user_data); + + data->next(input, label, last); + return 0; +} + +int validData_cb(float **input, float **label, bool *last, void *user_data) { + auto data = reinterpret_cast(user_data); + + data->next(input, label, last); + return 0; +} + +void createAndRun(unsigned int epochs, unsigned int batch_size, + UserDataType &train_user_data, + UserDataType &valid_user_data) { + + // setup model + ModelHandle model = createResnet18(); + model->setProperty( + {withKey("batch_size", batch_size), withKey("epochs", epochs)}); + + auto optimizer = ml::train::createOptimizer("adam", {"learning_rate=0.001"}); + + model->setOptimizer(std::move(optimizer)); + model->compile(); + model->initialize(); + + auto dataset_train = ml::train::createDataset( + ml::train::DatasetType::GENERATOR, trainData_cb, train_user_data.get()); + auto dataset_valid = 
+  auto dataset_train = ml::train::createDataset(
+    ml::train::DatasetType::GENERATOR, trainData_cb, train_user_data.get());
+  auto dataset_valid = ml::train::createDataset(
+    ml::train::DatasetType::GENERATOR, validData_cb, valid_user_data.get());
+
+  model->setDataset(ml::train::DatasetModeType::MODE_TRAIN,
+                    std::move(dataset_train));
+  model->setDataset(ml::train::DatasetModeType::MODE_VALID,
+                    std::move(dataset_valid));
+
+  model->train();
+}
+
+std::array<UserDataType, 2>
+createFakeDataGenerator(unsigned int batch_size,
+                        unsigned int simulated_data_size,
+                        unsigned int data_split) {
+  UserDataType train_data(new nntrainer::util::RandomDataLoader(
+    {{batch_size, 3, 32, 32}}, {{batch_size, 1, 1, 100}},
+    simulated_data_size / data_split));
+  UserDataType valid_data(new nntrainer::util::RandomDataLoader(
+    {{batch_size, 3, 32, 32}}, {{batch_size, 1, 1, 100}},
+    simulated_data_size / data_split));
+
+  return {std::move(train_data), std::move(valid_data)};
+}
+
+std::array<UserDataType, 2>
+createRealDataGenerator(const std::string &directory, unsigned int batch_size,
+                        unsigned int data_split) {
+
+  UserDataType train_data(new nntrainer::util::Cifar100DataLoader(
+    directory + "/train.bin", batch_size, data_split));
+  UserDataType valid_data(new nntrainer::util::Cifar100DataLoader(
+    directory + "/test.bin", batch_size, data_split));
+
+  return {std::move(train_data), std::move(valid_data)};
+}
+
+static void Test_ResnetFull(benchmark::State &state) {
+
+  unsigned int batch_size = 1;
+  unsigned int data_split = 128;
+  unsigned int epoch = 10;
+
+  std::cout << "batch_size: " << batch_size << " data_split: " << data_split
+            << " epoch: " << epoch << std::endl;
+
+  std::array<UserDataType, 2> user_datas;
+  user_datas = createFakeDataGenerator(batch_size, 512, data_split);
+  auto &[train_user_data, valid_user_data] = user_datas;
+  auto check_freq = get_cpu_freq();
+  state.counters["check_freq"] = check_freq;
+  for (auto _ : state) {
+    createAndRun(epoch, batch_size, train_user_data, valid_user_data);
+  }
+}
+
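+// register the full training scenario with Google Benchmark;
+// BENCHMARK_MAIN() expands to the benchmark runner's main() entry point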
+BENCHMARK(Test_ResnetFull);
+BENCHMARK_MAIN();
diff --git a/benchmarks/benchmark_application/meson.build b/benchmarks/benchmark_application/meson.build
new file mode 100644
index 0000000000..1f3a386b65
--- /dev/null
+++ b/benchmarks/benchmark_application/meson.build
@@ -0,0 +1,13 @@
+build_root = meson.build_root()
+
+sources = ['benchmark_resnet.cpp',
+           fake_datagen_path / 'fake_data_gen.cpp']
+
+resnet_dependencies = [nntrainer_dep,
+                       nntrainer_ccapi_dep,
+                       benchmark_dep, ]
+
+executable('Benchmark_ResNet',
+           sources,
+           include_directories : [include_directories('.'), fake_datagen_include_dir],
+           dependencies : resnet_dependencies)
diff --git a/benchmarks/fake_data_gen/fake_data_gen.cpp b/benchmarks/fake_data_gen/fake_data_gen.cpp
new file mode 100644
index 0000000000..fea4d3fa57
--- /dev/null
+++ b/benchmarks/fake_data_gen/fake_data_gen.cpp
@@ -0,0 +1,173 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2020 Jihoon Lee
+ *
+ * @file fake_data_gen.cpp
+ * @date 24 Jun 2021
+ * @brief data loaders (random and CIFAR-100) for the benchmarks
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee
+ * @bug No known bugs except for NYI items
+ */
+
+#include "fake_data_gen.h"
+
+#include <algorithm>
+#include <cerrno>
+#include <cstdint>
+#include <cstring>
+#include <iostream>
+#include <numeric>
+
+#include <nntrainer_error.h>
+
+namespace nntrainer::util {
+
+namespace {
+
+/**
+ * @brief fill label to the given memory
+ *
+ * @param data data to fill
+ * @param length size of the data
+ * @param label label
+ */
+void fillLabel(float *data, unsigned int length, unsigned int label) {
+  if (length == 1) {
+    *data = label;
+    return;
+  }
+
+  memset(data, 0, length * sizeof(float));
+  *(data + label) = 1;
+}
+
+/**
+ * @brief update the current iteration
+ * @note this function increases the iteration value; when the epoch has
+ * finished, iteration resets to 0
+ *
+ * @param[in,out] iteration current iteration
+ * @param data_size Data size
+ * @return bool true if iteration has finished
+ */
+bool updateIteration(unsigned int &iteration, unsigned int data_size) {
+  if (iteration++ == data_size) {
+    iteration = 0;
+    return true;
+  }
+  return false;
+};
+
+} // namespace
+
+RandomDataLoader::RandomDataLoader(const std::vector<TensorDim> &input_shapes,
+                                   const std::vector<TensorDim> &output_shapes,
+                                   int data_size_) :
+  iteration(0),
+  data_size(data_size_),
+  input_shapes(input_shapes),
+  output_shapes(output_shapes),
+  input_dist(0, 255),
+  label_dist(0, output_shapes.front().width() - 1) {
+  NNTR_THROW_IF(output_shapes.empty(), std::invalid_argument)
+    << "output_shape size empty not supported";
+  NNTR_THROW_IF(output_shapes.size() > 1, std::invalid_argument)
+    << "output_shape size > 1 is not supported";
+}
+
+void RandomDataLoader::next(float **input, float **label, bool *last) {
+  auto fill_input = [this](float *input, unsigned int length) {
+    for (unsigned int i = 0; i < length; ++i) {
+      *input = input_dist(rng);
+      input++;
+    }
+  };
+
+  auto fill_label = [this](float *label, unsigned int batch,
+                           unsigned int length) {
+    unsigned int generated_label = label_dist(rng);
+    fillLabel(label, length, generated_label);
+    label += length;
+  };
+
+  if (updateIteration(iteration, data_size)) {
+    *last = true;
+    return;
+  }
+
+  float **cur_input_tensor = input;
+  for (unsigned int i = 0; i < input_shapes.size(); ++i) {
+    fill_input(*cur_input_tensor, input_shapes.at(i).getFeatureLen());
+    cur_input_tensor++;
+  }
+
+  float **cur_label_tensor = label;
+  for (unsigned int i = 0; i < output_shapes.size(); ++i) {
+    fill_label(*cur_label_tensor, output_shapes.at(i).batch(),
+               output_shapes.at(i).getFeatureLen());
+    cur_label_tensor++;
+  }
+}
+
+Cifar100DataLoader::Cifar100DataLoader(const std::string &path, int batch_size,
+                                       int splits) :
+  batch(batch_size),
+  current_iteration(0),
+  file(path, std::ios::binary | std::ios::ate) {
+  constexpr char error_msg[] = "failed to create dataloader, reason: ";
+
+  NNTR_THROW_IF(!file.good(), std::invalid_argument)
+    << error_msg << " Cannot open file";
+
+  auto pos = file.tellg();
+  NNTR_THROW_IF((pos % Cifar100DataLoader::SampleSize != 0),
+                std::invalid_argument)
+    << error_msg << " Given file does not align with the format";
+
+  auto data_size = pos / (Cifar100DataLoader::SampleSize * splits);
+  idxes = std::vector<unsigned int>(data_size);
+  std::cout << "path: " << path << '\n';
+  std::cout << "data_size: " << data_size << '\n';
+  std::iota(idxes.begin(), idxes.end(), 0);
+  std::shuffle(idxes.begin(), idxes.end(), rng);
+
+  /// @note this truncates the remaining data of less than the batch size
+  iteration_per_epoch = data_size;
+}
+
+void Cifar100DataLoader::next(float **input, float **label, bool *last) {
+  /// @note below logic assumes a single input and that the fine label is used
+
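+  // one sample is read as: one coarse-label byte (skipped), one fine-label
+  // byte (kept as the class), then ImageSize pixel bytes scaled to [0, 1]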
+  auto fill_one_sample = [this](float *input_, float *label_, int index) {
+    const size_t error_buflen = 102;
+    char error_buf[error_buflen];
+    NNTR_THROW_IF(!file.good(), std::invalid_argument)
+      << "file is not good, reason: "
+      << strerror_r(errno, error_buf, error_buflen);
+    file.seekg(index * Cifar100DataLoader::SampleSize, std::ios_base::beg);
+
+    uint8_t current_label;
+    uint8_t coarse_label; // not needed for this application, so discarded
+    file.read(reinterpret_cast<char *>(&coarse_label), sizeof(uint8_t));
+    file.read(reinterpret_cast<char *>(&current_label), sizeof(uint8_t));
+
+    fillLabel(label_, Cifar100DataLoader::NumClass, current_label);
+
+    for (unsigned int i = 0; i < Cifar100DataLoader::ImageSize; ++i) {
+      uint8_t data;
+      file.read(reinterpret_cast<char *>(&data), sizeof(uint8_t));
+      *input_ = data / 255.f;
+      input_++;
+    }
+  };
+
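+  // consume one sample in shuffled order; once the epoch is exhausted, set
+  // *last, rewind the counter, and reshuffle the indices for the next epoch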
+  fill_one_sample(*input, *label, idxes[current_iteration]);
+  current_iteration++;
+  if (current_iteration < iteration_per_epoch) {
+    *last = false;
+  } else {
+    *last = true;
+    current_iteration = 0;
+    std::shuffle(idxes.begin(), idxes.end(), rng);
+  }
+}
+
+} // namespace nntrainer::util
diff --git a/benchmarks/fake_data_gen/fake_data_gen.h b/benchmarks/fake_data_gen/fake_data_gen.h
new file mode 100644
index 0000000000..10083620f6
--- /dev/null
+++ b/benchmarks/fake_data_gen/fake_data_gen.h
@@ -0,0 +1,124 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2020 Jihoon Lee
+ *
+ * @file fake_data_gen.h
+ * @date 24 Jun 2021
+ * @brief data loaders (random and CIFAR-100) for the benchmarks
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee
+ * @bug No known bugs except for NYI items
+ */
+#include <tensor_dim.h>
+
+#include <fstream>
+#include <random>
+#include <string>
+#include <vector>
+
+namespace nntrainer::util {
+
+using TensorDim = ml::train::TensorDim;
+
+/**
+ * @brief DataLoader interface used to load cifar data
+ */
+class DataLoader {
+public:
+  /**
+   * @brief Destroy the Data Loader object
+   */
+  virtual ~DataLoader() {}
+
+  /**
+   * @brief create an iteration to be fed to the generator callback
+   *
+   * @param[out] input list of inputs that is already allocated by nntrainer,
+   * and this function is obliged to fill
+   * @param[out] label list of labels that is already allocated by nntrainer,
+   * and this function is obliged to fill
+   * @param[out] last optional property to set when the epoch has finished
+   */
+  virtual void next(float **input, float **label, bool *last) = 0;
+
+protected:
+  std::mt19937 rng;
+};
+
+/**
+ * @brief RandomData Generator
+ *
+ */
+class RandomDataLoader final : public DataLoader {
+public:
+  /**
+   * @brief Construct a new Random Data Loader object
+   *
+   * @param input_shapes input_shapes with appropriate batch
+   * @param output_shapes label_shapes with appropriate batch
+   * @param iteration iteration per epoch
+   */
+  RandomDataLoader(const std::vector<TensorDim> &input_shapes,
+                   const std::vector<TensorDim> &output_shapes, int iteration);
+
+  /**
+   * @brief Destroy the Random Data Loader object
+   */
+  ~RandomDataLoader() {}
+
+  /**
+   * @copydoc void DataLoader::next(float **input, float **label, bool *last)
+   */
+  void next(float **input, float **label, bool *last);
+
+private:
+  unsigned int iteration;
+  unsigned int data_size;
+
+  std::vector<TensorDim> input_shapes;
+  std::vector<TensorDim> output_shapes;
+
+  std::uniform_int_distribution<int> input_dist;
+  std::uniform_int_distribution<int> label_dist;
+};
+
+/**
+ * @brief Cifar100DataLoader class
+ */
+class Cifar100DataLoader final : public DataLoader {
+public:
+  /**
+   * @brief Construct a new Cifar100 Data Loader object
+   *
+   * @param path path to read from
+   * @param batch_size batch_size of current model
+   * @param splits split divisor of the file: 1 means using the whole data, 2
+   * means half of the data, 10 means a tenth of the data
+   */
+  Cifar100DataLoader(const std::string &path, int batch_size, int splits);
+
+  /**
+   * @brief Destroy the Cifar100 Data Loader object
+   */
+  ~Cifar100DataLoader() {}
+
+  /**
+   * @copydoc void DataLoader::next(float **input, float **label, bool *last)
+   */
+  void next(float **input, float **label, bool *last);
+
+private:
+  inline static constexpr int ImageSize = 3 * 32 * 32;
+  inline static constexpr int NumClass = 100;
+  inline static constexpr int SampleSize =
+    4 * (3 * 32 * 32 + 2); /**< 1 coarse label, 1 fine label, pixel size */
+
+  unsigned int batch;
+  unsigned int current_iteration;
+  unsigned int iteration_per_epoch;
+
+  std::ifstream file;
+  std::vector<unsigned int> idxes; /**< index information for one epoch */
+};
+
+} // namespace nntrainer::util
diff --git a/benchmarks/fake_data_gen/meson.build b/benchmarks/fake_data_gen/meson.build
new file mode 100644
index 0000000000..945bd74dc6
--- /dev/null
+++ b/benchmarks/fake_data_gen/meson.build
@@ -0,0 +1,2 @@
+fake_datagen_path = meson.current_source_dir()
+fake_datagen_include_dir = include_directories('.')
diff --git a/benchmarks/meson.build b/benchmarks/meson.build
new file mode 100644
index 0000000000..026677d0c6
--- /dev/null
+++ b/benchmarks/meson.build
@@ -0,0 +1,2 @@
+subdir('fake_data_gen')
+subdir('benchmark_application')
diff --git a/meson.build b/meson.build
index 815a47f576..d362d35362 100644
--- a/meson.build
+++ b/meson.build
@@ -314,7 +314,7 @@
 endif
 gmock_dep = dependency('gmock', static: true, main: false, required: false)
 gtest_dep = dependency('gtest', static: true, main: false, required: false)
 gtest_main_dep = dependency('gtest', static: true, main: true, required: false)
-
+benchmark_dep = dependency('benchmark', static: true, main: false, required: false)
 if get_option('enable-test') # and get_option('platform') != 'android'
   extra_defines += '-DENABLE_TEST=1'
@@ -495,3 +495,7 @@ endif
 if get_option('platform') != 'none'
   message('building for ' + get_option('platform'))
 endif
+
+if get_option('enable-benchmarks')
+  subdir('benchmarks')
+endif
diff --git a/meson_options.txt b/meson_options.txt
index 316d8f2e1f..e904d9de35 100644
--- a/meson_options.txt
+++ b/meson_options.txt
@@ -44,6 +44,7 @@
 option('enable-neon', type: 'boolean', value: false)
 option('enable-avx', type: 'boolean', value: true)
 option('enable-opencl', type: 'boolean', value: false)
 option('enable-biqgemm', type: 'boolean', value: false)
+option('enable-benchmarks', type: 'boolean', value: false)
 # ml-api dependency (to enable, install capi-inference from github.com/nnstreamer/api )
 # To inter-operate with nnstreamer and ML-API packages, you need to enable this.