[Tensor] Add broadcast support for operations
This PR adds broadcast support so that tensor operations can make use of broadcasting in the future.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
djeong20 authored and jijoongmoon committed Jan 19, 2024
1 parent ac68c68 commit c5ae995
Showing 4 changed files with 144 additions and 0 deletions.
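As context for the diff below: broadcasting here means combining a tensor with an operand whose axes may have size 1, reusing the operand's values along those axes instead of copying them, which amounts to giving the operand a stride of 0 on those axes. A minimal stand-alone sketch of the idea (plain C++, not nntrainer API; the shapes and values are made up for illustration):

#include <cstdio>

int main() {
  // "this" tensor of shape 2x3 and an operand of shape 1x3; the operand's
  // first axis has size 1, so the same row is reused for every row of `a`.
  const float a[2][3] = {{1, 2, 3}, {4, 5, 6}};
  const float b[3] = {10, 20, 30};
  float out[2][3];

  for (int i = 0; i < 2; ++i)    // broadcast axis: b does not advance
    for (int j = 0; j < 3; ++j)  // matching axis: normal indexing
      out[i][j] = a[i][j] + b[j];

  std::printf("%g %g %g\n", out[1][0], out[1][1], out[1][2]);  // prints 14 25 36
  return 0;
}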
106 changes: 106 additions & 0 deletions nntrainer/tensor/tensor_base.cpp
@@ -10,9 +10,37 @@
*/

#include <tensor_base.h>
#include <tensor_v2.h>

namespace nntrainer {

/**
 * @struct BroadcastInfoV2
 * @brief External loop info for broadcasted iteration. Please refer to
 * DISABLED_private_external_loop_n in unittest_nntrainer_tensor.
 * @note This would better be implemented in an iterator fashion before being
 * used extensively.
*/
struct TensorBase::BroadcastInfoV2 {

/**
* @brief Construct a new External Loop Info object
*
*/
BroadcastInfoV2() :
buffer_size(0),
buffer_axis(-1),
strides{0, 0, 0, 0},
tensor_type(nntrainer::TensorDim::TensorType()) {}

unsigned int buffer_size; /**< virtual size of the buffer */
int buffer_axis; /**< the smallest axis that should be looped.
-1 means no loop needed*/
std::array<unsigned int, TensorDim::MAXDIM>
strides; /**< modified strides for the loop */
nntrainer::TensorDim::TensorType tensor_type;
};
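// Worked example of the fields above (editorial illustration, not part of the
// committed code): for *this* of dim 2x3x4x5 and an operand m of dim 2x3x1x5
// with assumed contiguous NCHW strides {15, 5, 5, 1}, computeBroadcastInfo()
// below yields buffer_axis = 2 (axes 0..2 still need an outer loop),
// buffer_size = 5 (one contiguous row of *this* per inner call), and
// strides = {15, 5, 0, 1}, so m's offset does not advance along axis 2.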

TensorBase::TensorBase(const TensorDim &d, bool alloc_now, Initializer init,
std::string name_) :
TensorBase(name_, d.getFormat()) {
@@ -119,4 +147,82 @@ TensorBase *TensorBase::getSharedDataTensor(const TensorDim dim_, size_t offset,
return ret;
}

TensorBase::BroadcastInfoV2
TensorBase::computeBroadcastInfo(const TensorV2 &m) const {
if (m.size() > this->size())
throw exception::not_supported("broadcasting *this is not supported");

const TensorDim m_dim = m.getDim();

BroadcastInfoV2 e;
e.tensor_type = getTensorType();

uint continuity[4] = {0, 1, 2, 3};
if (getFormat() == Tformat::NHWC) {
continuity[1] = 2;
continuity[2] = 3;
continuity[3] = 1;
}

  /// check whether the given Tensors can be broadcasted
for (unsigned int i = 0; i < TensorDim::MAXDIM; ++i) {
if (dim.getTensorDim(continuity[i]) == m_dim.getTensorDim(continuity[i])) {
e.strides[i] = m.getStrides()[i];
continue;
}

    /// If the given dimension is 1, it can be reused with the stride remaining 0.
    /// Note that the dim[i] == 1 && m_dim[i] == 1 case must be checked first
    /// (done by the equality check above); in that case the strides do not change.
if (m_dim.getTensorDim(continuity[i]) == 1) {
continue;
}

std::stringstream ss;
ss << "[computeBroadcastInfo] broadcasting only allowed for "
"dimension value of 1 \n"
<< "this: " << dim << "target: " << m_dim;
throw std::invalid_argument(ss.str().c_str());
}

/// calculate inner loop size
e.buffer_size = 1;
e.buffer_axis = -1;
e.strides[3] = m.getStrides()[3];

  /// initialize buffer info with the matching-dimension strategy
for (int axis = 3; axis >= 0; --axis) {
if (dim.getTensorDim(continuity[axis]) !=
m_dim.getTensorDim(continuity[axis])) {
e.buffer_axis = axis;
break;
}

e.buffer_size *= dim.getTensorDim(continuity[axis]);
}

/// check strategy that uses consecutive ones
if (m_dim.getTensorDim(continuity[3]) == 1) {
unsigned int inner_loop_size = 1;
int axis;
for (axis = 3; axis >= 0; --axis) {
if (m_dim.getTensorDim(continuity[axis]) != 1) {
break;
}

inner_loop_size *= dim.getTensorDim(continuity[axis]);
}

/// if consecutive-one strategy has bigger chunk size, replace the
/// information
if (inner_loop_size > e.buffer_size) {
e.buffer_axis = axis;
e.buffer_size = inner_loop_size;
e.strides[3] = 0;
}
}

return e;
}

} // namespace nntrainer
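The loop info computed above is meant to be consumed by element-wise kernels, which this PR does not add yet. Below is a minimal sketch of such a consumer under explicit assumptions: broadcast_add and LoopInfo are hypothetical stand-ins (LoopInfo only mirrors the BroadcastInfoV2 fields shown above, and the helper is not nntrainer API), and both tensors are contiguous FP32 buffers.

#include <array>
#include <cstddef>

struct LoopInfo {                         // mirrors TensorBase::BroadcastInfoV2
  unsigned int buffer_size = 1;           // contiguous chunk per inner call
  int buffer_axis = -1;                   // -1: the whole tensor is one chunk
  std::array<unsigned int, 4> strides{};  // modified strides of the operand m
};

// buf, dim and strides describe *this* (the larger tensor); m_buf is the
// operand's data. On broadcast axes e.strides[axis] is 0, so m_offset simply
// does not advance there.
void broadcast_add(const float *buf, const float *m_buf, float *out,
                   const unsigned int dim[4], const size_t strides[4],
                   const LoopInfo &e, int cur_axis = -1, size_t offset = 0,
                   size_t m_offset = 0) {
  if (cur_axis == e.buffer_axis) {
    // inner chunk: e.strides[3] is 1 (contiguous) or 0 (one value reused)
    for (unsigned int i = 0; i < e.buffer_size; ++i)
      out[offset + i] = buf[offset + i] + m_buf[m_offset + i * e.strides[3]];
    return;
  }
  ++cur_axis;
  for (unsigned int i = 0; i < dim[cur_axis]; ++i)
    broadcast_add(buf, m_buf, out, dim, strides, e, cur_axis,
                  offset + i * strides[cur_axis],
                  m_offset + i * e.strides[cur_axis]);
}

The recursion starts at cur_axis = -1 to match the -1 sentinel of buffer_axis: when every dimension matches (or the operand is a single value), no outer loop runs and the inner loop covers the whole buffer in one call.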
21 changes: 21 additions & 0 deletions nntrainer/tensor/tensor_base.h
@@ -228,6 +228,17 @@
*/
void putData() const;

/**
* @brief return a copy of the Tensor Dim
* @retval TensorDim
*/
TensorDim getDim() const { return TensorDim(dim); }

/**
* @brief return Tensor Type
*/
TensorDim::TensorType getTensorType() const { return dim.getTensorType(); }

/**
* @brief Get initializer for the tensor
* @retval initializer of the tensor
@@ -371,6 +382,16 @@
* src_ptr is valid, this tensor will use the memory allocated by the src_ptr
*/
std::shared_ptr<SrcSharedTensorBase> src_tensor;

struct BroadcastInfoV2;

/**
 * @brief Compute loop info for broadcasting and vectorization
*
* @param m target tensor to be calculated against.
 * @return BroadcastInfoV2 loop info needed to run the external loop
*/
BroadcastInfoV2 computeBroadcastInfo(const TensorV2 &m) const;
};

/**
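For reference, the per-axis compatibility rule that computeBroadcastInfo() enforces in tensor_base.cpp above can be restated as a tiny stand-alone predicate; this helper is an editorial illustration and not part of the PR.

#include <array>

// An operand m can be broadcast against *this* iff every axis of m either
// matches the corresponding axis of *this* or has size 1.
bool broadcastable(const std::array<unsigned int, 4> &dim,      // dims of *this*
                   const std::array<unsigned int, 4> &m_dim) {  // dims of the operand m
  for (int i = 0; i < 4; ++i)
    if (m_dim[i] != dim[i] && m_dim[i] != 1)
      return false;
  return true;
}

// broadcastable({2, 3, 4, 5}, {1, 3, 1, 5}) -> true  (axes 0 and 2 broadcast)
// broadcastable({2, 3, 4, 5}, {2, 3, 4, 2}) -> false (axis 3 is neither equal nor 1)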
6 changes: 6 additions & 0 deletions nntrainer/tensor/tensor_v2.cpp
@@ -180,6 +180,12 @@ void TensorV2::print(std::ostream &out) const { itensor->print(out); }

void TensorV2::putData() const { itensor->putData(); }

TensorDim TensorV2::getDim() const { return itensor->getDim(); }

TensorDim::TensorType TensorV2::getTensorType() const {
return itensor->getTensorType();
}

Initializer TensorV2::getInitializer() const {
return itensor->getInitializer();
}
11 changes: 11 additions & 0 deletions nntrainer/tensor/tensor_v2.h
@@ -529,6 +529,17 @@
*/
void putData() const;

/**
* @brief return a copy of the Tensor Dim
* @retval TensorDim
*/
TensorDim getDim() const;

/**
* @brief return Tensor Type
*/
TensorDim::TensorType getTensorType() const;

/**
* @brief Get initializer for the tensor
*
