Skip to content

Commit

Permalink
[TensorV2] Multiplication operation skeleton
Browse files Browse the repository at this point in the history
This pull request adds a basic implementation of tensor multiplication operations to our codebase.
The new functionality allows users to perform multiplication of tensors by simply calling a function.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
  • Loading branch information
djeong20 authored and jijoongmoon committed Jan 19, 2024
1 parent cd11a54 commit ac68c68
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 0 deletions.
24 changes: 24 additions & 0 deletions nntrainer/tensor/float_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,30 @@ class FloatTensor : public TensorBase {
TensorV2 &apply(std::function<float(float)> f,
TensorV2 &output) const override;

/**
* @copydoc TensorV2::multiply_i(float const &value)
* @todo Need implementation and unit tests
*/
int multiply_i(float const &value) override { return 0; }

/**
* @copydoc TensorV2::multiply(float const &value, TensorV2 &out)
* @todo Need implementation and unit tests
*/
TensorV2 &multiply(float const &value, TensorV2 &out) const override {
return out;
}

/**
* @copydoc TensorV2::multiply(TensorV2 const &m, TensorV2 &output, const
* float beta = 0.0)
* @todo Need implementation and unit tests
*/
TensorV2 &multiply(TensorV2 const &m, TensorV2 &output,
const float beta = 0.0) const override {
return output;
}

/**
* @copydoc TensorV2::print(std::ostream &out)
*/
Expand Down
24 changes: 24 additions & 0 deletions nntrainer/tensor/half_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,30 @@ class HalfTensor : public TensorBase {
TensorV2 &apply(std::function<_FP16(_FP16)> f,
TensorV2 &output) const override;

/**
* @copydoc TensorV2::multiply_i(float const &value)
* @todo Need implementation and unit tests
*/
int multiply_i(float const &value) override { return 0; }

/**
* @copydoc TensorV2::multiply(float const &value, TensorV2 &out)
* @todo Need implementation and unit tests
*/
TensorV2 &multiply(float const &value, TensorV2 &out) const override {
return out;
}

/**
* @copydoc TensorV2::multiply(TensorV2 const &m, TensorV2 &output, const
* float beta = 0.0)
* @todo Need implementation and unit tests
*/
TensorV2 &multiply(TensorV2 const &m, TensorV2 &output,
const float beta = 0.0) const override {
return output;
}

/**
* @copydoc TensorV2::print(std::ostream &out)
*/
Expand Down
17 changes: 17 additions & 0 deletions nntrainer/tensor/tensor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,23 @@ class TensorBase {
*/
virtual void initialize(Initializer init) = 0;

/**
* @copydoc TensorV2::multiply_i(float const &value)
*/
virtual int multiply_i(float const &value) = 0;

/**
* @copydoc TensorV2::multiply(float const &value, TensorV2 &out)
*/
virtual TensorV2 &multiply(float const &value, TensorV2 &out) const = 0;

/**
* @copydoc TensorV2::multiply(TensorV2 const &m, TensorV2 &output, const
* float beta = 0.0)
*/
virtual TensorV2 &multiply(TensorV2 const &m, TensorV2 &output,
const float beta = 0.0) const = 0;

/**
* @copydoc TensorV2::print(std::ostream &out)
*/
Expand Down
25 changes: 25 additions & 0 deletions nntrainer/tensor/tensor_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,31 @@ TensorV2 &TensorV2::apply(std::function<TensorV2 &(TensorV2, TensorV2 &)> f,
return f(*this, output);
}

int TensorV2::multiply_i(float const &value) { return ML_ERROR_NONE; }

TensorV2 TensorV2::multiply(float const &value) const {
TensorV2 t;
return multiply(value, t);
}

TensorV2 &TensorV2::multiply(float const &value, TensorV2 &out) const {
return out;
}

int TensorV2::multiply_i(TensorV2 const &m, const float beta) {
return ML_ERROR_NONE;
}

TensorV2 TensorV2::multiply(TensorV2 const &m, const float beta) const {
TensorV2 t("", this->getFormat());
return multiply(m, t, beta);
}

TensorV2 &TensorV2::multiply(TensorV2 const &m, TensorV2 &output,
const float beta) const {
return output;
}

void TensorV2::print(std::ostream &out) const { itensor->print(out); }

void TensorV2::putData() const { itensor->putData(); }
Expand Down
50 changes: 50 additions & 0 deletions nntrainer/tensor/tensor_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <cstddef>

#include <nntrainer_log.h>
#include <tensor_base.h>

namespace nntrainer {
Expand Down Expand Up @@ -467,6 +468,55 @@ class TensorV2 {
TensorV2 &apply(std::function<TensorV2 &(TensorV2, TensorV2 &)> f,
TensorV2 &output) const;

/**
* @brief Multiply value element by element immediately
* @param[in] value multiplier
* @retval #ML_ERROR_INVALID_PARAMETER Tensor dimension is not right
* @retval #ML_ERROR_NONE Successful
*/
int multiply_i(float const &value);

/**
* @brief Multiply value element by element
* @param[in] value multiplier
* @retval Calculated Tensor
*/
TensorV2 multiply(float const &value) const;

/**
* @brief multiply value element by element
* @param[in] value multiplier
* @param[out] out out tensor to store the result
* @retval Calculated Tensor
*/
TensorV2 &multiply(float const &value, TensorV2 &out) const;

/**
* @brief Multiply Tensor Elementwise
* @param[in] m Tensor to be multiplied
* @param[in] beta scalar to multiply output with and add
* @retval #ML_ERROR_NONE successful
*/
int multiply_i(TensorV2 const &m, const float beta = 0.0);

/**
* @brief Multiply Tensor Element by Element ( Not the MxM )
* @param[in] m Tensor to be multiplied
* @param[in] beta scalar to multiply output with and add
* @retval Calculated Tensor
*/
TensorV2 multiply(TensorV2 const &m, const float beta = 0.0) const;

/**
* @brief Multiply Tensor Element by Element ( Not the MxM )
* @param[in] m Tensor to be multiplied
* @param[out] output Tensor to store the result
* @param[in] beta scalar to multiply output with and add
* @retval Calculated Tensor
*/
TensorV2 &multiply(TensorV2 const &m, TensorV2 &output,
const float beta = 0.0) const;

/**
* @brief Print element
* @param[in] out out stream
Expand Down

0 comments on commit ac68c68

Please sign in to comment.