[Wait for #2724][Layer] add "multiply layer" @open sesame 11/12 10:52 #2725

Merged · 1 commit · Nov 12, 2024
9 changes: 9 additions & 0 deletions api/ccapi/include/layer.h
@@ -39,6 +39,7 @@ enum LayerType {
LAYER_WEIGHT = ML_TRAIN_LAYER_TYPE_WEIGHT, /**< Weight Layer type */
LAYER_ADD = ML_TRAIN_LAYER_TYPE_ADD, /**< Add Layer type */
LAYER_SUBTRACT = ML_TRAIN_LAYER_TYPE_SUBTRACT, /**< Subtract Layer type */
LAYER_MULTIPLY = ML_TRAIN_LAYER_TYPE_MULTIPLY, /**< Multiply Layer type */
LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */
LAYER_SWIGLU = ML_TRAIN_LAYER_TYPE_SWIGLU, /**< Swiglu Layer type */
LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */
@@ -323,6 +324,14 @@ SubtractLayer(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_SUBTRACT, properties);
}

/**
* @brief Helper function to create multiply layer
*/
inline std::unique_ptr<Layer>
MultiplyLayer(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_MULTIPLY, properties);
}

/**
* @brief Helper function to create fully connected layer
*/
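For context, a minimal sketch of how application code would call the new ccapi helper; this assumes `<layer.h>` is on the include path, and the `name=mul0` property is purely illustrative:

```cpp
#include <layer.h> // ccapi header patched above
#include <memory>

int main() {
  // Create a layer of type LayerType::LAYER_MULTIPLY through the new helper;
  // properties are passed as "key=value" strings, as with the other helpers.
  std::unique_ptr<ml::train::Layer> layer =
      ml::train::layer::MultiplyLayer({"name=mul0"});
  // The type string resolves to "multiply" (checked by the ccapi test below).
  return 0;
}
```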
1 change: 1 addition & 0 deletions api/nntrainer-api-common.h
@@ -67,6 +67,7 @@ typedef enum {
ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_ADD = 32, /**< Add Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_SUBTRACT = 33, /**< Subtract Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_MULTIPLY = 34, /**< Multiply Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP =
300, /**< Preprocess flip Layer (Since 6.5) */
ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE =
3 changes: 3 additions & 0 deletions nntrainer/app_context.cpp
@@ -61,6 +61,7 @@
#include <mse_loss_layer.h>
#include <multi_head_attention_layer.h>
#include <multiout_layer.h>
#include <multiply_layer.h>
#include <nntrainer_error.h>
#include <permute_layer.h>
#include <plugged_layer.h>
@@ -259,6 +260,8 @@ static void add_default_object(AppContext &ac) {
LayerType::LAYER_ADD);
ac.registerFactory(nntrainer::createLayer<SubtractLayer>, SubtractLayer::type,
LayerType::LAYER_SUBTRACT);
ac.registerFactory(nntrainer::createLayer<MultiplyLayer>, MultiplyLayer::type,
LayerType::LAYER_MULTIPLY);
ac.registerFactory(nntrainer::createLayer<FullyConnectedLayer>,
FullyConnectedLayer::type, LayerType::LAYER_FC);
ac.registerFactory(nntrainer::createLayer<BatchNormalizationLayer>,
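Note that the factory is registered under both the type string `"multiply"` (`MultiplyLayer::type`) and the enum `LayerType::LAYER_MULTIPLY`, so the layer can be instantiated either by name from a model description or through the API enum added above.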
1 change: 1 addition & 0 deletions nntrainer/layers/meson.build
@@ -7,6 +7,7 @@ layer_sources = [
'weight_layer.cpp',
'add_layer.cpp',
'subtract_layer.cpp',
'multiply_layer.cpp',
'addition_layer.cpp',
'attention_layer.cpp',
'mol_attention_layer.cpp',
51 changes: 51 additions & 0 deletions nntrainer/layers/multiply_layer.cpp
@@ -0,0 +1,51 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file multiply_layer.cpp
* @date 10 Oct 2024
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is multiply layer class (operation layer)
*
*/

#include <multiply_layer.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>
#include <util_func.h>

#include <layer_context.h>

namespace nntrainer {

void MultiplyLayer::finalize(InitLayerContext &context) {
context.setOutputDimensions({context.getInputDimensions()[0]});
}

void MultiplyLayer::forwarding_operation(const Tensor &input0,
const Tensor &input1, Tensor &hidden) {
input0.multiply(input1, hidden);
}

void MultiplyLayer::calcDerivative(RunLayerContext &context) {
context.getOutgoingDerivative(0).copy(
context.getIncomingDerivative(SINGLE_INOUT_IDX)
.multiply(context.getInput(1)));

context.getOutgoingDerivative(1).copy(
context.getIncomingDerivative(SINGLE_INOUT_IDX)
.multiply(context.getInput(0)));
}

void MultiplyLayer::setProperty(const std::vector<std::string> &values) {
auto remain_props = loadProperties(values, multiply_props);
if (!remain_props.empty()) {
std::string msg = "[MultiplyLayer] Unknown Layer Properties count " +
std::to_string(values.size());
throw exception::not_supported(msg);
}
}
} /* namespace nntrainer */
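For clarity, the two `copy` calls in `calcDerivative` are the elementwise product rule: for $z = x \odot y$,

$$
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \odot y,
\qquad
\frac{\partial L}{\partial y} = \frac{\partial L}{\partial z} \odot x,
$$

i.e. each outgoing derivative is the incoming derivative multiplied elementwise by the other input.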
102 changes: 102 additions & 0 deletions nntrainer/layers/multiply_layer.h
@@ -0,0 +1,102 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file multiply_layer.h
* @date 10 Oct 2024
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is multiply layer class (operation layer)
*
*/

#ifndef __MULTIPLY_LAYER_H__
#define __MULTIPLY_LAYER_H__
#ifdef __cplusplus

#include <common_properties.h>
#include <layer_devel.h>
#include <operation_layer.h>

namespace nntrainer {

/**
* @class Multiply Layer
* @brief Multiply Layer
*/
class MultiplyLayer : public BinaryOperationLayer {
public:
/**
* @brief Constructor of Multiply Layer
*/
MultiplyLayer() : BinaryOperationLayer(), multiply_props(props::Print()) {}

/**
* @brief Destructor of Multiply Layer
*/
~MultiplyLayer(){};

/**
* @brief Move constructor of Multiply Layer.
* @param[in] MultiplyLayer &&
*/
MultiplyLayer(MultiplyLayer &&rhs) noexcept = default;

/**
* @brief Move assignment operator.
* @param[in] rhs MultiplyLayer to be moved.
*/
MultiplyLayer &operator=(MultiplyLayer &&rhs) = default;

/**
* @copydoc Layer::finalize(InitLayerContext &context)
*/
void finalize(InitLayerContext &context) final;

/**
* @brief forwarding operation for multiply
*
* @param input0 input tensor 0
* @param input1 input tensor 1
* @param hidden tensor to store the result of multiplication
*/
void forwarding_operation(const Tensor &input0, const Tensor &input1,
Tensor &hidden) final;

/**
* @copydoc Layer::calcDerivative(RunLayerContext &context)
*/
void calcDerivative(RunLayerContext &context) final;

/**
* @copydoc bool supportBackwarding() const
*/
bool supportBackwarding() const final { return true; };

/**
* @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
* method)
*/
void exportTo(Exporter &exporter,
const ml::train::ExportMethods &method) const final {}

/**
* @copydoc Layer::setProperty(const std::vector<std::string> &values)
*/
void setProperty(const std::vector<std::string> &values) final;

/**
* @copydoc Layer::getType()
*/
const std::string getType() const final { return MultiplyLayer::type; };

std::tuple<props::Print> multiply_props;

inline static const std::string type = "multiply";
};

} // namespace nntrainer

#endif /* __cplusplus */
#endif /* __MULTIPLY_LAYER_H__ */
3 changes: 3 additions & 0 deletions test/ccapi/unittest_ccapi.cpp
@@ -70,6 +70,9 @@ TEST(ccapi_layer, construct_02_p) {
EXPECT_NO_THROW(layer = ml::train::layer::SubtractLayer());
EXPECT_EQ(layer->getType(), "subtract");

EXPECT_NO_THROW(layer = ml::train::layer::MultiplyLayer());
EXPECT_EQ(layer->getType(), "multiply");

EXPECT_NO_THROW(layer = ml::train::layer::FullyConnected());
EXPECT_EQ(layer->getType(), "fully_connected");

23 changes: 23 additions & 0 deletions test/input_gen/genModelTests_v2.py
@@ -492,6 +492,19 @@ def forward(self, inputs, labels):
return out, loss


class MultiplyOperation(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = torch.nn.Linear(2, 2)
self.loss = torch.nn.MSELoss()

def forward(self, inputs, labels):
out = self.fc(inputs[0])
out = inputs[0] * out
loss = self.loss(out, labels[0])
return out, loss


if __name__ == "__main__":
record_v2(
ReduceMeanLast(),
@@ -799,6 +812,16 @@ def forward(self, inputs, labels):
name="subtract_operation",
)

multiply_operation = MultiplyOperation()
record_v2(
multiply_operation,
iteration=2,
input_dims=[(1, 2)],
input_dtype=[float],
label_dims=[(1, 2)],
name="multiply_operation",
)

# Function to check the created golden test file
inspect_file("add_operation.nnmodelgolden")
fc_mixed_training_nan_sgd = LinearMixedPrecisionNaNSGD()
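One detail worth noting: the reference model multiplies the input by a function of itself (`out = inputs[0] * out` with `out = self.fc(inputs[0])`), so the recorded gradients flow along two paths. Ignoring the bias, with $o = x \odot (Wx)$ the input gradient is

$$
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial o} \odot (Wx) + W^{\top}\!\left(\frac{\partial L}{\partial o} \odot x\right),
$$

so the golden file also exercises gradient accumulation at the shared input, not just the multiply layer's local derivative.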
1 change: 1 addition & 0 deletions test/unittest/layers/meson.build
@@ -49,6 +49,7 @@ test_target = [
'unittest_layers_addition.cpp',
'unittest_layers_add.cpp',
'unittest_layers_subtract.cpp',
'unittest_layers_multiply.cpp',
'unittest_layers_multiout.cpp',
'unittest_layers_rnn.cpp',
'unittest_layers_rnncell.cpp',
31 changes: 31 additions & 0 deletions test/unittest/layers/unittest_layers_multiply.cpp
@@ -0,0 +1,31 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file unittest_layers_multiply.cpp
* @date 30 August 2024
* @brief Mul Layer Test
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
*/
#include <tuple>

#include <gtest/gtest.h>

#include <layers_common_tests.h>
#include <multiply_layer.h>

auto semantic_multiply = LayerSemanticsParamType(
nntrainer::createLayer<nntrainer::MultiplyLayer>,
nntrainer::MultiplyLayer::type, {},
LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);

auto semantic_multiply_multi = LayerSemanticsParamType(
nntrainer::createLayer<nntrainer::MultiplyLayer>,
nntrainer::MultiplyLayer::type, {},
LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 2);

GTEST_PARAMETER_TEST(Multiply, LayerSemantics,
::testing::Values(semantic_multiply,
semantic_multiply_multi));
21 changes: 21 additions & 0 deletions test/unittest/models/unittest_models.cpp
@@ -910,6 +910,25 @@ static std::unique_ptr<NeuralNetwork> makeSubtractOperation() {
return nn;
}

static std::unique_ptr<NeuralNetwork> makeMultiplyOperation() {
std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());

auto outer_graph =
makeGraph({{"input", {"name=in", "input_shape=1:1:2"}},
{"fully_connected", {"name=fc", "unit=2", "input_layers=in"}},
{"multiply", {"name=multiply_layer", "input_layers=in,fc"}},
{"mse", {"name=loss", "input_layers=multiply_layer"}}});

for (auto &node : outer_graph) {
nn->addLayer(node);
}

nn->setProperty({"batch_size=1"});
nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate=0.1"}));

return nn;
}

GTEST_PARAMETER_TEST(
model, nntrainerModelTest,
::testing::ValuesIn({
@@ -984,6 +1003,8 @@ GTEST_PARAMETER_TEST(
mkModelTc_V2(makeAddOperation, "add_operation", ModelTestOption::ALL_V2),
mkModelTc_V2(makeSubtractOperation, "subtract_operation",
ModelTestOption::ALL_V2),
mkModelTc_V2(makeMultiplyOperation, "multiply_operation",
ModelTestOption::ALL_V2),
}),
[](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info)
-> const auto & { return std::get<1>(info.param); });