diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h
index c6f9c8a17d..7946d65a3e 100644
--- a/api/ccapi/include/layer.h
+++ b/api/ccapi/include/layer.h
@@ -39,6 +39,7 @@ enum LayerType {
   LAYER_WEIGHT = ML_TRAIN_LAYER_TYPE_WEIGHT, /**< Weight Layer type */
   LAYER_ADD = ML_TRAIN_LAYER_TYPE_ADD,       /**< Add Layer type */
   LAYER_SUBTRACT = ML_TRAIN_LAYER_TYPE_SUBTRACT, /**< Subtract Layer type */
+  LAYER_MULTIPLY = ML_TRAIN_LAYER_TYPE_MULTIPLY, /**< Multiply Layer type */
   LAYER_FC = ML_TRAIN_LAYER_TYPE_FC,         /**< Fully Connected Layer type */
   LAYER_SWIGLU = ML_TRAIN_LAYER_TYPE_SWIGLU, /**< Swiglu Layer type */
   LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */
@@ -323,6 +324,14 @@ SubtractLayer(const std::vector<std::string> &properties = {}) {
   return createLayer(LayerType::LAYER_SUBTRACT, properties);
 }
 
+/**
+ * @brief Helper function to create multiply layer
+ */
+inline std::unique_ptr<Layer>
+MultiplyLayer(const std::vector<std::string> &properties = {}) {
+  return createLayer(LayerType::LAYER_MULTIPLY, properties);
+}
+
 /**
  * @brief Helper function to create fully connected layer
  */
diff --git a/api/nntrainer-api-common.h b/api/nntrainer-api-common.h
index 261c9eb853..64fcbf9910 100644
--- a/api/nntrainer-api-common.h
+++ b/api/nntrainer-api-common.h
@@ -67,6 +67,7 @@ typedef enum {
   ML_TRAIN_LAYER_TYPE_WEIGHT = 31,   /**< Weight Layer type (Since 9.0)*/
   ML_TRAIN_LAYER_TYPE_ADD = 32,      /**< Add Layer type (Since 9.0)*/
   ML_TRAIN_LAYER_TYPE_SUBTRACT = 33, /**< Subtract Layer type (Since 9.0)*/
+  ML_TRAIN_LAYER_TYPE_MULTIPLY = 34, /**< Multiply Layer type (Since 9.0)*/
   ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP =
     300, /**< Preprocess flip Layer (Since 6.5) */
   ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE =
diff --git a/nntrainer/app_context.cpp b/nntrainer/app_context.cpp
index d806e3138d..ec4e8ea038 100644
--- a/nntrainer/app_context.cpp
+++ b/nntrainer/app_context.cpp
@@ -60,6 +60,7 @@
 #include <mse_loss_layer.h>
 #include <multi_head_attention_layer.h>
 #include <multiout_layer.h>
+#include <multiply_layer.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
 #include <permute_layer.h>
@@ -254,6 +255,8 @@ static void add_default_object(AppContext &ac) {
                      LayerType::LAYER_ADD);
   ac.registerFactory(nntrainer::createLayer<SubtractLayer>, SubtractLayer::type,
                      LayerType::LAYER_SUBTRACT);
+  ac.registerFactory(nntrainer::createLayer<MultiplyLayer>, MultiplyLayer::type,
+                     LayerType::LAYER_MULTIPLY);
   ac.registerFactory(nntrainer::createLayer<FullyConnectedLayer>,
                      FullyConnectedLayer::type, LayerType::LAYER_FC);
   ac.registerFactory(nntrainer::createLayer,
diff --git a/nntrainer/layers/meson.build b/nntrainer/layers/meson.build
index 149fbaafea..91fc66b956 100644
--- a/nntrainer/layers/meson.build
+++ b/nntrainer/layers/meson.build
@@ -7,6 +7,7 @@ layer_sources = [
   'weight_layer.cpp',
   'add_layer.cpp',
   'subtract_layer.cpp',
+  'multiply_layer.cpp',
   'addition_layer.cpp',
   'attention_layer.cpp',
   'mol_attention_layer.cpp',
diff --git a/nntrainer/layers/multiply_layer.cpp b/nntrainer/layers/multiply_layer.cpp
new file mode 100644
index 0000000000..7815de8fee
--- /dev/null
+++ b/nntrainer/layers/multiply_layer.cpp
@@ -0,0 +1,52 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 SeungBaek Hong
+ *
+ * @file multiply_layer.cpp
+ * @date 10 Oct 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author SeungBaek Hong
+ * @bug No known bugs except for NYI items
+ * @brief This is multiply layer class (operation layer)
+ *
+ */
+
+#include <multiply_layer.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+#include <util_func.h>
+
+#include <layer_context.h>
+
+namespace nntrainer {
+
+void MultiplyLayer::finalize(InitLayerContext &context) {
+  op_type = OperationType::BINARY;
+  context.setOutputDimensions({context.getInputDimensions()[0]});
+}
+
+void MultiplyLayer::forwarding_operation(const Tensor &input0,
+                                         const Tensor &input1, Tensor &hidden) {
+  input0.multiply(input1, hidden);
+}
+
+void MultiplyLayer::calcDerivative(RunLayerContext &context) {
+  context.getOutgoingDerivative(0).copy(
+    context.getIncomingDerivative(SINGLE_INOUT_IDX)
+      .multiply(context.getInput(1)));
+
+  context.getOutgoingDerivative(1).copy(
+    context.getIncomingDerivative(SINGLE_INOUT_IDX)
+      .multiply(context.getInput(0)));
+}
+
+void MultiplyLayer::setProperty(const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, multiply_props);
+  if (!remain_props.empty()) {
+    std::string msg = "[MultiplyLayer] Unknown Layer Properties count " +
+                      std::to_string(values.size());
+    throw exception::not_supported(msg);
+  }
+}
+} /* namespace nntrainer */
diff --git a/nntrainer/layers/multiply_layer.h b/nntrainer/layers/multiply_layer.h
new file mode 100644
index 0000000000..2345faa7a6
--- /dev/null
+++ b/nntrainer/layers/multiply_layer.h
@@ -0,0 +1,108 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 SeungBaek Hong
+ *
+ * @file multiply_layer.h
+ * @date 10 Oct 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author SeungBaek Hong
+ * @bug No known bugs except for NYI items
+ * @brief This is multiply layer class (operation layer)
+ *
+ */
+
+#ifndef __MULTIPLY_LAYER_H__
+#define __MULTIPLY_LAYER_H__
+#ifdef __cplusplus
+
+#include <common_properties.h>
+#include <layer_devel.h>
+#include <operation_layer.h>
+
+namespace nntrainer {
+
+/**
+ * @class Multiply Layer
+ * @brief Multiply Layer
+ */
+class MultiplyLayer : public OperationLayer {
+public:
+  /**
+   * @brief Constructor of Multiply Layer
+   */
+  MultiplyLayer() : OperationLayer(), multiply_props(props::Print()) {}
+
+  /**
+   * @brief Destructor of Multiply Layer
+   */
+  ~MultiplyLayer(){};
+
+  /**
+   * @brief Move constructor of Multiply Layer.
+   * @param[in] MultiplyLayer &&
+   */
+  MultiplyLayer(MultiplyLayer &&rhs) noexcept = default;
+
+  /**
+   * @brief Move assignment operator.
+   * @param[in] rhs MultiplyLayer to be moved.
+   */
+  MultiplyLayer &operator=(MultiplyLayer &&rhs) = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) final;
+
+  /**
+   * @copydoc OperationLayer::forwarding_operation(const Tensor &input, Tensor
+   * &hidden)
+   */
+  void forwarding_operation(const Tensor &input, Tensor &hidden) final{};
+
+  /**
+   * @brief forwarding operation for multiply
+   *
+   * @param input0 input tensor 0
+   * @param input1 input tensor 1
+   * @param hidden tensor to store the result of multiplication
+   */
+  void forwarding_operation(const Tensor &input0, const Tensor &input1,
+                            Tensor &hidden) final;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) final;
+
+  /**
+   * @copydoc bool supportBackwarding() const
+   */
+  bool supportBackwarding() const final { return true; };
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
+   * method)
+   */
+  void exportTo(Exporter &exporter,
+                const ml::train::ExportMethods &method) const final {}
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) final;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const final { return MultiplyLayer::type; };
+
+  std::tuple<props::Print> multiply_props;
+
+  inline static const std::string type = "multiply";
+};
+
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __MULTIPLY_LAYER_H__ */
diff --git a/test/ccapi/unittest_ccapi.cpp b/test/ccapi/unittest_ccapi.cpp
index 86a978fc22..32ad53dd7e 100644
--- a/test/ccapi/unittest_ccapi.cpp
+++ b/test/ccapi/unittest_ccapi.cpp
@@ -70,6 +70,9 @@ TEST(ccapi_layer, construct_02_p) {
   EXPECT_NO_THROW(layer = ml::train::layer::SubtractLayer());
   EXPECT_EQ(layer->getType(), "subtract");
 
+  EXPECT_NO_THROW(layer = ml::train::layer::MultiplyLayer());
+  EXPECT_EQ(layer->getType(), "multiply");
+
   EXPECT_NO_THROW(layer = ml::train::layer::FullyConnected());
   EXPECT_EQ(layer->getType(), "fully_connected");
 
diff --git a/test/input_gen/genModelTests_v2.py b/test/input_gen/genModelTests_v2.py
index 291f5e8151..dede100c48 100644
--- a/test/input_gen/genModelTests_v2.py
+++ b/test/input_gen/genModelTests_v2.py
@@ -455,6 +455,19 @@ def forward(self, inputs, labels):
         return out, loss
 
 
+class MultiplyOperation(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc = torch.nn.Linear(2, 2)
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, inputs, labels):
+        out = self.fc(inputs[0])
+        out = inputs[0] * out
+        loss = self.loss(out, labels[0])
+        return out, loss
+
+
 if __name__ == "__main__":
     record_v2(
         ReduceMeanLast(),
@@ -749,5 +762,15 @@ def forward(self, inputs, labels):
         name="subtract_operation",
     )
 
+    multiply_operation = MultiplyOperation()
+    record_v2(
+        multiply_operation,
+        iteration=2,
+        input_dims=[(1, 2)],
+        input_dtype=[float],
+        label_dims=[(1, 2)],
+        name="multiply_operation",
+    )
+
     # Function to check the created golden test file
     inspect_file("subtract_operation.nnmodelgolden")
diff --git a/test/unittest/layers/meson.build b/test/unittest/layers/meson.build
index dadf616bdc..fb6393161a 100644
--- a/test/unittest/layers/meson.build
+++ b/test/unittest/layers/meson.build
@@ -49,6 +49,7 @@ test_target = [
   'unittest_layers_addition.cpp',
   'unittest_layers_add.cpp',
   'unittest_layers_subtract.cpp',
+  'unittest_layers_multiply.cpp',
   'unittest_layers_multiout.cpp',
   'unittest_layers_rnn.cpp',
   'unittest_layers_rnncell.cpp',
diff --git a/test/unittest/layers/unittest_layers_multiply.cpp b/test/unittest/layers/unittest_layers_multiply.cpp
new file mode 100644
index 0000000000..f7492946bf
--- /dev/null
+++ b/test/unittest/layers/unittest_layers_multiply.cpp
@@ -0,0 +1,30 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 SeungBaek Hong
+ *
+ * @file unittest_layers_multiply.cpp
+ * @date 30 August 2024
+ * @brief Mul Layer Test
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author SeungBaek Hong
+ * @bug No known bugs except for NYI items
+ */
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <layers_common_tests.h>
+#include <multiply_layer.h>
+
+auto semantic_mul = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::MultiplyLayer>,
+  nntrainer::MultiplyLayer::type, {},
+  LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);
+
+auto semantic_mul_multi = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::MultiplyLayer>,
+  nntrainer::MultiplyLayer::type, {},
+  LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 2);
+
+GTEST_PARAMETER_TEST(Mul, LayerSemantics,
+                     ::testing::Values(semantic_mul, semantic_mul_multi));
diff --git a/test/unittest/models/unittest_models.cpp b/test/unittest/models/unittest_models.cpp
index 6edbd71173..4846c47423 100644
--- a/test/unittest/models/unittest_models.cpp
+++ b/test/unittest/models/unittest_models.cpp
@@ -910,6 +910,25 @@ static std::unique_ptr<NeuralNetwork> makeSubOperation() {
   return nn;
 }
 
+static std::unique_ptr<NeuralNetwork> makeMulOperation() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+
+  auto outer_graph =
+    makeGraph({{"input", {"name=in", "input_shape=1:1:2"}},
+               {"fully_connected", {"name=fc", "unit=2", "input_layers=in"}},
+               {"multiply", {"name=multiply_layer", "input_layers=in,fc"}},
+               {"mse", {"name=loss", "input_layers=multiply_layer"}}});
+
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  nn->setProperty({"batch_size=1"});
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate=0.1"}));
+
+  return nn;
+}
+
 GTEST_PARAMETER_TEST(
   model, nntrainerModelTest,
   ::testing::ValuesIn({
@@ -984,6 +1003,8 @@ GTEST_PARAMETER_TEST(
     mkModelTc_V2(makeAddOperation, "add_operation", ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSubOperation, "subtract_operation",
                  ModelTestOption::ALL_V2),
+    mkModelTc_V2(makeMulOperation, "multiply_operation",
+                 ModelTestOption::ALL_V2),
  }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info)
     -> const auto & { return std::get<1>(info.param); });
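
For reference, a minimal usage sketch of the ccapi helper added in layer.h. This is an illustration under assumptions, not part of the patch: it only exercises the MultiplyLayer() helper and getType(), mirroring the unittest_ccapi.cpp hunk, and the property strings are illustrative values reusing the name=/input_layers= convention from makeMulOperation().

#include <layer.h> // ccapi header patched above
#include <memory>

int main() {
  // Helper added in this patch; wraps createLayer(LayerType::LAYER_MULTIPLY).
  // "multiply_layer", "in", and "fc" are hypothetical names for illustration.
  std::unique_ptr<ml::train::Layer> mul = ml::train::layer::MultiplyLayer(
    {"name=multiply_layer", "input_layers=in,fc"});

  // The factory registered in app_context.cpp resolves the type string.
  return mul->getType() == "multiply" ? 0 : 1;
}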
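
The backward pass in MultiplyLayer::calcDerivative is the elementwise product rule: for z = x * y, dL/dx = dL/dz * y and dL/dy = dL/dz * x, which is why each outgoing derivative copies the incoming derivative multiplied by the other input. A self-contained scalar sanity check of that identity (plain C++, independent of nntrainer):

#include <cassert>
#include <cmath>

int main() {
  // Finite-difference check that d(x*y)/dx == y; calcDerivative relies on
  // this identity per element (and the symmetric one for the other input).
  const double x = 1.7, y = -0.3, eps = 1e-6;
  const double fd = ((x + eps) * y - (x - eps) * y) / (2.0 * eps);
  assert(std::fabs(fd - y) < 1e-9);
  return 0;
}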
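
The golden-data path ties the last three files together: genModelTests_v2.py records two iterations of MultiplyOperation into multiply_operation.nnmodelgolden, makeMulOperation() in unittest_models.cpp rebuilds the same in -> fc -> multiply -> mse graph on the nntrainer side, and the mkModelTc_V2(makeMulOperation, "multiply_operation", ModelTestOption::ALL_V2) entry compares the two run traces. The shapes must agree across both definitions: input_dims=[(1, 2)] with torch.nn.Linear(2, 2) on the PyTorch side corresponds to input_shape=1:1:2 plus unit=2 with batch_size=1 on the nntrainer side.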