From ea672ddc55a1f9449901a4bfb0640ea667472936 Mon Sep 17 00:00:00 2001 From: "jijoong.moon" Date: Wed, 11 Dec 2024 21:59:14 +0900 Subject: [PATCH] Add ComputeEngine Property for choosing Engine This PR add ComputeEngine Enum Property. Enum elements are "cpu", "gpu", "qnn" for now. The property format is "engine=qnn". **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: jijoong.moon --- api/ccapi/include/common.h | 9 +++ api/ccapi/include/layer.h | 51 +++++-------- api/ccapi/src/factory.cpp | 10 ++- nntrainer/layers/common_properties.h | 24 ++++++- nntrainer/layers/layer_node.cpp | 72 +++++++++++++------ nntrainer/layers/layer_node.h | 24 +++---- nntrainer/utils/node_exporter.cpp | 2 +- .../layers/layers_dependent_common_tests.cpp | 15 ++-- 8 files changed, 119 insertions(+), 88 deletions(-) diff --git a/api/ccapi/include/common.h b/api/ccapi/include/common.h index f1ea0101e0..d1ce4cf25b 100644 --- a/api/ccapi/include/common.h +++ b/api/ccapi/include/common.h @@ -43,6 +43,15 @@ enum class ExecutionMode { VALIDATE /** Validate mode, label is necessary */ }; +/** + * @brief Enumeration of layer compute engine + */ +enum LayerComputeEngine { + CPU, /**< CPU as the compute engine */ + GPU, /**< GPU as the compute engine */ + QNN, /**< QNN as the compute engine */ +}; + /** * @brief Get the version of NNTrainer */ diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index e1a6885c7c..5405d2973d 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -114,14 +114,6 @@ enum LayerType { LAYER_UNKNOWN = ML_TRAIN_LAYER_TYPE_UNKNOWN /**< Unknown */ }; -/** - * @brief Enumeration of layer compute engine - */ -enum LayerComputeEngine { - CPU, /**< CPU as the compute engine */ - GPU, /**< GPU as the compute engine */ -}; - /** * @class Layer Base class for layers * @brief Base class for all layers @@ -261,16 +253,14 @@ class Layer { */ std::unique_ptr createLayer(const LayerType &type, - const std::vector &properties = {}, - const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU); + const std::vector &properties = {}); /** * @brief Factory creator with constructor for layer */ std::unique_ptr createLayer(const std::string &type, - const std::vector &properties = {}, - const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU); + const std::vector &properties = {}); /** * @brief General Layer Factory function to register Layer @@ -343,37 +333,33 @@ DivideLayer(const std::vector &properties = {}) { /** * @brief Helper function to create fully connected layer */ -inline std::unique_ptr FullyConnected( - const std::vector &properties = {}, - const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { - return createLayer(LayerType::LAYER_FC, properties, compute_engine); +inline std::unique_ptr +FullyConnected(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_FC, properties); } /** * @brief Helper function to create Swiglu layer */ inline std::unique_ptr -Swiglu(const std::vector &properties = {}, - const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { - return createLayer(LayerType::LAYER_SWIGLU, properties, compute_engine); +Swiglu(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_SWIGLU, properties); } /** * @brief Helper function to create RMS normalization layer for GPU */ inline std::unique_ptr -RMSNormCl(const std::vector &properties = {}, - const LayerComputeEngine &compute_engine = LayerComputeEngine::GPU) { - return createLayer(LayerType::LAYER_RMSNORM, properties, compute_engine); +RMSNormCl(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_RMSNORM, properties); } /** * @brief Helper function to create Transpose layer */ inline std::unique_ptr -Transpose(const std::vector &properties = {}, - const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { - return createLayer(LayerType::LAYER_TRANSPOSE, properties, compute_engine); +Transpose(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_TRANSPOSE, properties); } /** @@ -428,27 +414,24 @@ Flatten(const std::vector &properties = {}) { * @brief Helper function to create reshape layer */ inline std::unique_ptr -Reshape(const std::vector &properties = {}, - const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { - return createLayer(LayerType::LAYER_RESHAPE, properties, compute_engine); +Reshape(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_RESHAPE, properties); } /** * @brief Helper function to create addition layer */ inline std::unique_ptr -Addition(const std::vector &properties = {}, - const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { - return createLayer(LayerType::LAYER_ADDITION, properties, compute_engine); +Addition(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_ADDITION, properties); } /** * @brief Helper function to create concat layer */ inline std::unique_ptr -Concat(const std::vector &properties = {}, - const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { - return createLayer(LayerType::LAYER_CONCAT, properties, compute_engine); +Concat(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_CONCAT, properties); } /** diff --git a/api/ccapi/src/factory.cpp b/api/ccapi/src/factory.cpp index 5f2b2dd2b9..daad2c7b6f 100644 --- a/api/ccapi/src/factory.cpp +++ b/api/ccapi/src/factory.cpp @@ -31,18 +31,16 @@ namespace ml { namespace train { std::unique_ptr createLayer(const LayerType &type, - const std::vector &properties, - const LayerComputeEngine &compute_engine) { - return nntrainer::createLayerNode(type, properties, compute_engine); + const std::vector &properties) { + return nntrainer::createLayerNode(type, properties); } /** * @brief Factory creator with constructor for layer */ std::unique_ptr createLayer(const std::string &type, - const std::vector &properties, - const LayerComputeEngine &compute_engine) { - return nntrainer::createLayerNode(type, properties, compute_engine); + const std::vector &properties) { + return nntrainer::createLayerNode(type, properties); } std::unique_ptr diff --git a/nntrainer/layers/common_properties.h b/nntrainer/layers/common_properties.h index 838cb6fdd5..482cbedb23 100644 --- a/nntrainer/layers/common_properties.h +++ b/nntrainer/layers/common_properties.h @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -945,12 +946,33 @@ struct ActivationTypeInfo { * @brief Activation Enumeration Information * */ -class Activation final : public EnumProperty { +class Activation final + : public EnumProperty { public: using prop_tag = enum_class_prop_tag; static constexpr const char *key = "activation"; }; +/** + * @brief Enumeration of Run Engine type + */ +struct ComputeEngineTypeInfo { + using Enum = ml::train::LayerComputeEngine; + static constexpr std::initializer_list EnumList = {Enum::CPU, Enum::GPU, + Enum::QNN}; + static constexpr const char *EnumStr[] = {"cpu", "gpu", "qnn"}; +}; + +/** + * @brief ComputeEngine Enumeration Information + * + */ +class ComputeEngine final : public EnumProperty { +public: + using prop_tag = enum_class_prop_tag; + static constexpr const char *key = "engine"; +}; + /** * @brief HiddenStateActivation Enumeration Information * diff --git a/nntrainer/layers/layer_node.cpp b/nntrainer/layers/layer_node.cpp index 9c6c290703..ba223bb3e1 100644 --- a/nntrainer/layers/layer_node.cpp +++ b/nntrainer/layers/layer_node.cpp @@ -130,20 +130,46 @@ class SharedFrom : public Name { */ LayerNode::~LayerNode() = default; +/** + * @brief get the compute engine property from property string vector + * : default is CPU + * @return LayerComputeEngine Enum : CPU, GPU, QNN + * + */ +ml::train::LayerComputeEngine +getComputeEngine(const std::vector &props) { + for (auto &prop : props) { + std::string key, value; + int status = nntrainer::getKeyValue(prop, key, value); + if (nntrainer::istrequal(key, "engine")) { + if (nntrainer::istrequal(value, "qnn")) { + return ml::train::LayerComputeEngine::QNN; + } else if (nntrainer::istrequal(value, "gpu")) { + return ml::train::LayerComputeEngine::GPU; + } + } + } + + return ml::train::LayerComputeEngine::CPU; +} + /** * @brief Layer factory creator with constructor */ std::unique_ptr createLayerNode(const ml::train::LayerType &type, - const std::vector &properties, - const ml::train::LayerComputeEngine &compute_engine) { + const std::vector &properties) { + + if (getComputeEngine(properties) == ml::train::LayerComputeEngine::GPU) { #ifdef ENABLE_OPENCL - if (compute_engine == ml::train::LayerComputeEngine::GPU) { auto &cc = nntrainer::ClContext::Global(); - return createLayerNode(cc.createObject(type), properties, - compute_engine); - } + return createLayerNode(cc.createObject(type), properties); +#else + throw std::invalid_argument( + "opencl layer creation without enable-opencl option"); #endif + } + auto &ac = nntrainer::AppContext::Global(); return createLayerNode(ac.createObject(type), properties); } @@ -153,15 +179,18 @@ createLayerNode(const ml::train::LayerType &type, */ std::unique_ptr createLayerNode(const std::string &type, - const std::vector &properties, - const ml::train::LayerComputeEngine &compute_engine) { + const std::vector &properties) { + + if (getComputeEngine(properties) == ml::train::LayerComputeEngine::GPU) { #ifdef ENABLE_OPENCL - if (compute_engine == ml::train::LayerComputeEngine::GPU) { auto &cc = nntrainer::ClContext::Global(); - return createLayerNode(cc.createObject(type), properties, - compute_engine); - } + return createLayerNode(cc.createObject(type), properties); +#else + throw std::invalid_argument( + "opencl layer creation without enable-opencl option"); #endif + } + auto &ac = nntrainer::AppContext::Global(); return createLayerNode(ac.createObject(type), properties); } @@ -171,16 +200,11 @@ createLayerNode(const std::string &type, */ std::unique_ptr createLayerNode(std::unique_ptr &&layer, - const std::vector &properties, - const ml::train::LayerComputeEngine &compute_engine) { + const std::vector &properties) { auto lnode = std::make_unique(std::move(layer)); lnode->setProperty(properties); - if (compute_engine == ml::train::LayerComputeEngine::GPU) { - lnode->setComputeEngine(compute_engine); - } - return lnode; } @@ -192,10 +216,10 @@ LayerNode::LayerNode(std::unique_ptr &&l) : output_connections(), run_context(nullptr), - layer_node_props( - new PropsType(props::Name(), props::Distribute(), props::Trainable(), {}, - {}, props::SharedFrom(), props::ClipGradByGlobalNorm(), - props::Packed(), props::LossScaleForMixed())), + layer_node_props(new PropsType( + props::Name(), props::Distribute(), props::Trainable(), {}, {}, + props::SharedFrom(), props::ClipGradByGlobalNorm(), props::Packed(), + props::LossScaleForMixed(), props::ComputeEngine())), layer_node_props_realization( new RealizationPropsType(props::Flatten(), props::Activation())), loss(new props::Loss()), @@ -670,6 +694,10 @@ InitLayerContext LayerNode::finalize(const std::vector &input_dims, if (!std::get(*layer_node_props).empty()) loss_scale = std::get(*layer_node_props).get(); + if (!std::get(*layer_node_props).empty()) { + compute_engine = std::get(*layer_node_props).get(); + } + if (!std::get(*layer_node_props).empty()) { bool isPacked = std::get(*layer_node_props); if (!isPacked) { diff --git a/nntrainer/layers/layer_node.h b/nntrainer/layers/layer_node.h index 9ed105d28a..49a2f230b0 100644 --- a/nntrainer/layers/layer_node.h +++ b/nntrainer/layers/layer_node.h @@ -53,6 +53,7 @@ class InputConnection; class ClipGradByGlobalNorm; class Packed; class LossScaleForMixed; +class ComputeEngine; } // namespace props /** @@ -994,11 +995,12 @@ will also contain the properties of the layer. The properties will be copied upon final creation. Editing properties of the layer after init will not the properties in the context/graph unless intended. */ - using PropsType = std::tuple, - std::vector, - props::SharedFrom, props::ClipGradByGlobalNorm, - props::Packed, props::LossScaleForMixed>; + using PropsType = + std::tuple, + std::vector, props::SharedFrom, + props::ClipGradByGlobalNorm, props::Packed, + props::LossScaleForMixed, props::ComputeEngine>; using RealizationPropsType = std::tuple; /** these realization properties results in addition of new layers, hence @@ -1070,9 +1072,7 @@ properties in the context/graph unless intended. */ */ std::unique_ptr createLayerNode(const ml::train::LayerType &type, - const std::vector &properties = {}, - const ml::train::LayerComputeEngine &compute_engine = - ml::train::LayerComputeEngine::CPU); + const std::vector &properties = {}); /** * @brief LayerNode creator with constructor @@ -1082,9 +1082,7 @@ createLayerNode(const ml::train::LayerType &type, */ std::unique_ptr createLayerNode(const std::string &type, - const std::vector &properties = {}, - const ml::train::LayerComputeEngine &compute_engine = - ml::train::LayerComputeEngine::CPU); + const std::vector &properties = {}); /** * @brief LayerNode creator with constructor @@ -1095,9 +1093,7 @@ createLayerNode(const std::string &type, */ std::unique_ptr createLayerNode(std::unique_ptr &&layer, - const std::vector &properties, - const ml::train::LayerComputeEngine &compute_engine = - ml::train::LayerComputeEngine::CPU); + const std::vector &properties); } // namespace nntrainer #endif // __LAYER_NODE_H__ diff --git a/nntrainer/utils/node_exporter.cpp b/nntrainer/utils/node_exporter.cpp index 031d2c2fbf..412ad30ac9 100644 --- a/nntrainer/utils/node_exporter.cpp +++ b/nntrainer/utils/node_exporter.cpp @@ -92,7 +92,7 @@ void Exporter::saveTflResult( std::vector, std::vector, props::SharedFrom, props::ClipGradByGlobalNorm, props::Packed, - props::LossScaleForMixed> &props, + props::LossScaleForMixed, props::ComputeEngine> &props, const LayerNode *self) { createIfNull(tf_node); tf_node->setLayerNode(*self); diff --git a/test/unittest/layers/layers_dependent_common_tests.cpp b/test/unittest/layers/layers_dependent_common_tests.cpp index d3c0666031..0290cb9907 100644 --- a/test/unittest/layers/layers_dependent_common_tests.cpp +++ b/test/unittest/layers/layers_dependent_common_tests.cpp @@ -143,15 +143,13 @@ TEST_P(LayerSemanticsGpu, createFromClContext_pn) { // } TEST_P(LayerSemanticsGpu, setPropertiesInvalid_n) { - auto lnode = - nntrainer::createLayerNode(expected_type, {}, ComputeEngine::GPU); + auto lnode = nntrainer::createLayerNode(expected_type, {"engine=gpu"}); /** must not crash */ EXPECT_THROW(layer->setProperty({"unknown_props=2"}), std::invalid_argument); } TEST_P(LayerSemanticsGpu, finalizeValidateLayerNode_p) { - auto lnode = - nntrainer::createLayerNode(expected_type, {}, ComputeEngine::GPU); + auto lnode = nntrainer::createLayerNode(expected_type, {"engine=gpu"}); std::vector props = {"name=test"}; std::string input_shape = "input_shape=1:1:1"; std::string input_layers = "input_layers=a"; @@ -181,8 +179,7 @@ TEST_P(LayerSemanticsGpu, finalizeValidateLayerNode_p) { } TEST_P(LayerSemanticsGpu, getTypeValidateLayerNode_p) { - auto lnode = - nntrainer::createLayerNode(expected_type, {}, ComputeEngine::GPU); + auto lnode = nntrainer::createLayerNode(expected_type, {"engine=gpu"}); std::string type; EXPECT_NO_THROW(type = lnode->getType()); @@ -190,8 +187,7 @@ TEST_P(LayerSemanticsGpu, getTypeValidateLayerNode_p) { } TEST_P(LayerSemanticsGpu, gettersValidateLayerNode_p) { - auto lnode = - nntrainer::createLayerNode(expected_type, {}, ComputeEngine::GPU); + auto lnode = nntrainer::createLayerNode(expected_type, {"engine=gpu"}); EXPECT_NO_THROW(lnode->supportInPlace()); EXPECT_NO_THROW(lnode->requireLabel()); @@ -199,8 +195,7 @@ TEST_P(LayerSemanticsGpu, gettersValidateLayerNode_p) { } TEST_P(LayerSemanticsGpu, setBatchValidateLayerNode_p) { - auto lnode = - nntrainer::createLayerNode(expected_type, {}, ComputeEngine::GPU); + auto lnode = nntrainer::createLayerNode(expected_type, {"engine=gpu"}); std::vector props = {"name=test"}; std::string input_shape = "input_shape=1:1:1"; std::string input_layers = "input_layers=a";