Skip to content

Commit

Permalink
[TensorV2] multiply_strided() skeleton
Browse files Browse the repository at this point in the history
This pull request introduces a basic structure of tensor multiplication operations that support different strided inputs and outputs.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
  • Loading branch information
djeong20 committed Jan 23, 2024
1 parent 07db2d5 commit ac2eb9f
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 0 deletions.
11 changes: 11 additions & 0 deletions nntrainer/tensor/float_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,17 @@ class FloatTensor : public TensorBase {
TensorV2 &apply(std::function<float(float)> f,
TensorV2 &output) const override;

/**
* @copydoc TensorV2::multiply_strided(TensorV2 const &m, TensorV2 &output,
* const float beta)
* @todo Need implementation and unit tests
*/
TensorV2 multiply_strided(TensorV2 const &m, TensorV2 &output,
const float beta) const override {
throw std::logic_error("multiply_strided is not implemented yet");
return output;
}

/**
* @copydoc TensorV2::multiply_i(float const &value)
* @todo Need implementation and unit tests
Expand Down
11 changes: 11 additions & 0 deletions nntrainer/tensor/half_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,17 @@ class HalfTensor : public TensorBase {
TensorV2 &apply(std::function<_FP16(_FP16)> f,
TensorV2 &output) const override;

/**
* @copydoc TensorV2::multiply_strided(TensorV2 const &m, TensorV2 &output,
* const float beta)
* @todo Need implementation and unit tests
*/
TensorV2 multiply_strided(TensorV2 const &m, TensorV2 &output,
const float beta) const override {
throw std::logic_error("multiply_strided is not implemented yet");
return output;
}

/**
* @copydoc TensorV2::multiply_i(float const &value)
* @todo Need implementation and unit tests
Expand Down
7 changes: 7 additions & 0 deletions nntrainer/tensor/tensor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ class TensorBase {
*/
virtual void initialize(Initializer init) = 0;

/**
* @copydoc TensorV2::multiply_strided(TensorV2 const &m, TensorV2 &output,
* const float beta)
*/
virtual TensorV2 multiply_strided(TensorV2 const &m, TensorV2 &output,
const float beta) const = 0;

/**
* @copydoc TensorV2::multiply_i(float const &value)
*/
Expand Down
22 changes: 22 additions & 0 deletions nntrainer/tensor/tensor_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,28 @@ TensorV2 &TensorV2::apply(std::function<TensorV2 &(TensorV2, TensorV2 &)> f,
return f(*this, output);
}

int TensorV2::multiply_i_strided(TensorV2 const &m, const float beta) {
try {
this->multiply_strided(m, *this, beta);
} catch (std::exception &err) {
ml_loge("%s %s", typeid(err).name(), err.what());
return ML_ERROR_INVALID_PARAMETER;
}

return ML_ERROR_NONE;
}

TensorV2 TensorV2::multiply_strided(TensorV2 const &m, const float beta) const {
TensorV2 t;
return this->multiply_strided(m, t, beta);
}

TensorV2 &TensorV2::multiply_strided(TensorV2 const &m, TensorV2 &output,
const float beta) const {
throw std::logic_error("multiply_strided is not implemented yet");
return output;
}

int TensorV2::multiply_i(float const &value) { return ML_ERROR_NONE; }

TensorV2 TensorV2::multiply(float const &value) const {
Expand Down
41 changes: 41 additions & 0 deletions nntrainer/tensor/tensor_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,47 @@ class TensorV2 {
TensorV2 &apply(std::function<TensorV2 &(TensorV2, TensorV2 &)> f,
TensorV2 &output) const;

/**
* @brief Multiply Tensor Elementwise
* @param[in] m Tensor to be multiplied
* @param[in] beta scalar to multiply output with and add
* @retval #ML_ERROR_NONE successful
*
* @note support different strided inputs and output
* @note does not support broadcasting
*
* @todo merge this to multiply_i
*/
int multiply_i_strided(TensorV2 const &m, const float beta = 0.0);

/**
* @brief Multiply Tensor Element by Element ( Not the MxM )
* @param[in] m Tensor to be multiplied
* @param[in] beta scalar to multiply output with and add
* @retval Calculated Tensor
*
* @note support different strided inputs and output
* @note does not support broadcasting
*
* @todo merge this to multiply
*/
TensorV2 multiply_strided(TensorV2 const &m, const float beta = 0.0) const;

/**
* @brief Multiply Tensor Element by Element ( Not the MxM )
* @param[in] m Tensor to be multiplied
* @param[out] output Tensor to store the result
* @param[in] beta scalar to multiply output with and add
* @retval Calculated Tensor
*
* @note support different strided inputs and output
* @note does not support broadcasting
*
* @todo merge this to multiply
*/
TensorV2 &multiply_strided(TensorV2 const &m, TensorV2 &output,
const float beta = 0.0) const;

/**
* @brief Multiply value element by element immediately
* @param[in] value multiplier
Expand Down

0 comments on commit ac2eb9f

Please sign in to comment.