[Tensor] Operational Improvements and Functionality Simplification
This commit moves the implementations of several operations into each tensor class for easier management.
This allows users to add a tensor class for a new data type without unnecessary modification of the Tensor class.

**Changes proposed in this PR:**
- The static function Tensor::cat() now delegates to each tensor's member function concat() (see the sketch after this list).
- Tensor::copy() is simplified: it no longer branches on the source tensor's data type.
- Tensor::copy_with_stride() delegates the element-wise copy to an internal per-type function.
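
A minimal sketch of the resulting call path, condensed from the tensor.cpp changes below (validation and the per-type implementations omitted): the data-type dispatch that previously lived in Tensor::cat() now goes through the TensorBase virtual table.

```cpp
// Sketch only, condensed from this commit's tensor.cpp diff.
Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
  Tensor input = tensors[0];
  return input.concat(tensors, axis);
}

Tensor Tensor::concat(const std::vector<Tensor> &tensors, int axis) {
  // itensor holds a TensorBase-derived object (FloatTensor, HalfTensor, ...),
  // so this virtual call selects the data-type-specific concat().
  return itensor->concat(tensors, axis);
}
```

This is why adding a new data type should only require a new TensorBase subclass that implements concat() and copy_with_stride(); Tensor itself stays unchanged.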

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
djeong20 committed Jul 12, 2024
1 parent acb6d4c commit 469121c
Showing 7 changed files with 79 additions and 64 deletions.
17 changes: 14 additions & 3 deletions nntrainer/tensor/float_tensor.cpp
@@ -814,6 +814,18 @@ void FloatTensor::copyData(const Tensor &from) {
   }
 }
 
+void FloatTensor::copy_with_stride(const Tensor &input, Tensor &output) {
+  for (unsigned int b = 0; b < output.batch(); ++b) {
+    for (unsigned int c = 0; c < output.channel(); ++c) {
+      for (unsigned int h = 0; h < output.height(); ++h) {
+        for (unsigned int w = 0; w < output.width(); ++w) {
+          output.setValue(b, c, h, w, input.getValue<float>(b, c, h, w));
+        }
+      }
+    }
+  }
+}
+
 std::vector<unsigned int> FloatTensor::argmax() const {
   std::vector<unsigned int> result;
   const float *data = (float *)getData();
@@ -1061,12 +1073,11 @@ std::vector<Tensor> FloatTensor::split(std::vector<size_t> sizes, int axis) {
   return ret;
 }
 
-Tensor FloatTensor::cat(const std::vector<Tensor> &tensors, int axis) {
+Tensor FloatTensor::concat(const std::vector<Tensor> &tensors, int axis) {
   if (axis == -1) {
     axis = 3;
   }
 
-  Tensor ret;
   auto ref_dim = tensors.front().getDim();
   bool is_format_nchw = (ref_dim.getFormat() == Tformat::NCHW);
   ref_dim.setTensorDim(axis, 1);
@@ -1106,7 +1117,7 @@ Tensor FloatTensor::cat(const std::vector<Tensor> &tensors, int axis) {
   auto ret_dim = ref_dim;
   ret_dim.setTensorDim(axis, axis_dim);
 
-  ret = Tensor(ret_dim);
+  Tensor ret = Tensor(ret_dim);
 
   std::array<unsigned, 4> loc = {0, 0, 0, 0};
   for (auto &t : tensors) {
9 changes: 8 additions & 1 deletion nntrainer/tensor/float_tensor.h
@@ -363,7 +363,7 @@ class FloatTensor : public TensorBase {
   /**
    * @copydoc Tensor::cat(const std::vector<Tensor> &tensors, int axis)
    */
-  static Tensor cat(const std::vector<Tensor> &tensors, int axis);
+  Tensor concat(const std::vector<Tensor> &tensors, int axis) override;
 
   /**
    * @copydoc Tensor::copy(const Tensor &from)
@@ -375,6 +375,13 @@
    */
   void copyData(const Tensor &from);
 
+  /**
+   * @brief Copy the Tensor
+   * @param[in] input Tensor to be copied
+   * @param[out] output output Tensor
+   */
+  void copy_with_stride(const Tensor &input, Tensor &output) override;
+
   /**
    * @copydoc Tensor::argmax()
    */
21 changes: 16 additions & 5 deletions nntrainer/tensor/half_tensor.cpp
@@ -887,11 +887,10 @@ std::vector<Tensor> HalfTensor::split(std::vector<size_t> sizes, int axis) {
   return ret;
 }
 
-Tensor HalfTensor::cat(const std::vector<Tensor> &tensors, int axis) {
+Tensor HalfTensor::concat(const std::vector<Tensor> &tensors, int axis) {
   if (axis == -1) {
     axis = 3;
   }
-  Tensor ret;
   auto ref_dim = tensors.front().getDim();
   bool is_format_nchw = (ref_dim.getFormat() == Tformat::NCHW);
   ref_dim.setTensorDim(axis, 1);
@@ -931,7 +930,7 @@ Tensor HalfTensor::cat(const std::vector<Tensor> &tensors, int axis) {
   auto ret_dim = ref_dim;
   ret_dim.setTensorDim(axis, axis_dim);
 
-  ret = Tensor(ret_dim);
+  Tensor output = Tensor(ret_dim);
 
   std::array<unsigned, 4> loc = {0, 0, 0, 0};
   for (auto &t : tensors) {
@@ -950,7 +949,7 @@
     }
 
     for (size_t i = 0u, sz = t.size(); i < sz; ++i) {
-      iter_value(loc, start_loc, ret, tensor_dim_arr) = t.getValue<_FP16>(i);
+      iter_value(loc, start_loc, output, tensor_dim_arr) = t.getValue<_FP16>(i);
     }
 
     if (is_format_nchw) {
@@ -965,7 +964,7 @@
       }
     }
   }
-  return ret;
+  return output;
 }
 
 void HalfTensor::print(std::ostream &out) const {
@@ -1060,6 +1059,18 @@ void HalfTensor::copyData(const Tensor &from) {
   }
 }
 
+void HalfTensor::copy_with_stride(const Tensor &input, Tensor &output) {
+  for (unsigned int b = 0; b < output.batch(); ++b) {
+    for (unsigned int c = 0; c < output.channel(); ++c) {
+      for (unsigned int h = 0; h < output.height(); ++h) {
+        for (unsigned int w = 0; w < output.width(); ++w) {
+          output.setValue(b, c, h, w, input.getValue<_FP16>(b, c, h, w));
+        }
+      }
+    }
+  }
+}
+
 std::vector<unsigned int> HalfTensor::argmax() const {
   std::vector<unsigned int> result;
   const _FP16 *data = (_FP16 *)getData();
9 changes: 8 additions & 1 deletion nntrainer/tensor/half_tensor.h
@@ -353,7 +353,7 @@ class HalfTensor : public TensorBase {
   /**
    * @copydoc Tensor::cat(const std::vector<Tensor> &tensors, int axis)
    */
-  static Tensor cat(const std::vector<Tensor> &tensors, int axis);
+  Tensor concat(const std::vector<Tensor> &tensors, int axis) override;
 
   /**
    * @copydoc Tensor::copy(const Tensor &from)
@@ -365,6 +365,13 @@
    */
   void copyData(const Tensor &from);
 
+  /**
+   * @brief Copy the Tensor
+   * @param[in] input Tensor to be copied
+   * @param[out] output output Tensor
+   */
+  void copy_with_stride(const Tensor &input, Tensor &output) override;
+
   /**
    * @copydoc Tensor::argmax()
    */
66 changes: 12 additions & 54 deletions nntrainer/tensor/tensor.cpp
@@ -818,27 +818,19 @@ std::vector<Tensor> Tensor::split(std::vector<size_t> sizes, int axis) {
   return itensor->split(sizes, axis);
 }
 
-Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
+Tensor Tensor::concat(const std::vector<Tensor> &tensors, int axis) {
   NNTR_THROW_IF(!(-1 <= axis && axis < 4), std::invalid_argument)
     << "cannot split axis of axis: " << axis;
 
   NNTR_THROW_IF(tensors.empty(), std::invalid_argument)
     << "given tensor vector is empty";
 
-  Tensor output;
-  Tdatatype dtype = tensors.front().getDim().getDataType();
-
-  if (dtype == Tdatatype::FP32) {
-    output = FloatTensor::cat(tensors, axis);
-  } else if (dtype == ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-    output = HalfTensor::cat(tensors, axis);
-#else
-    throw std::invalid_argument("Error: enable-fp16 is not enabled");
-#endif
-  }
+  return itensor->concat(tensors, axis);
+}
 
-  return output;
+Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
+  Tensor input = tensors[0];
+  return input.concat(tensors, axis);
 }
 
 void Tensor::print(std::ostream &out) const {
@@ -874,56 +866,22 @@ void Tensor::copy(const Tensor &from) {
     // if tensor size and data type match, copy data
     itensor->copy(from);
   } else {
-    // replace with a new tensor that are the same with the given tensor
-    if (from.getDataType() == ml::train::TensorDim::DataType::FP32) {
-      Tensor t = Tensor(from.getDim(), from.getData<float>());
-      swap(t, *this);
-    } else if (from.getDataType() == ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-      Tensor t = Tensor(from.getDim(), from.getData<_FP16>());
-      swap(t, *this);
-#else
-      throw std::invalid_argument("Error: enable-fp16 is not enabled");
-#endif
-    }
+    Tensor t = Tensor(from.getDim(), from.getData<char>());
+    swap(t, *this);
   }
 }
 
 void Tensor::copyData(const Tensor &from) { itensor->copyData(from); }
 
 void Tensor::copy_with_stride(const Tensor &from) {
   if (itensor->getDim() == from.getDim()) {
-    // if the tensor dim matches, copy the data
-    for (unsigned int b = 0; b < batch(); ++b) {
-      for (unsigned int c = 0; c < channel(); ++c) {
-        for (unsigned int h = 0; h < height(); ++h) {
-          for (unsigned int w = 0; w < width(); ++w) {
-            setValue(b, c, h, w, from.getValue<float>(b, c, h, w));
-          }
-        }
-      }
-    }
+    // If the tensor dim matches, copy the data. This also applies to
+    // uncontiguous tensors.
+    itensor->copy_with_stride(from, *this);
   } else {
     // replace with a new tensor that has the same data as the given tensor
    Tensor t = Tensor(from.getDim(), true);
-    for (unsigned int b = 0; b < t.batch(); ++b) {
-      for (unsigned int c = 0; c < t.channel(); ++c) {
-        for (unsigned int h = 0; h < t.height(); ++h) {
-          for (unsigned int w = 0; w < t.width(); ++w) {
-            if (getDataType() == ml::train::TensorDim::DataType::FP32) {
-              t.setValue(b, c, h, w, from.getValue<float>(b, c, h, w));
-            } else if (getDataType() == ml::train::TensorDim::DataType::FP16) {
-              /// @todo remove #ifdef ENABLE_FP16
-#ifdef ENABLE_FP16
-              t.setValue(b, c, h, w, from.getValue<_FP16>(b, c, h, w));
-#else
-              throw std::invalid_argument("Error: enable-fp16 is not enabled");
-#endif
-            }
-          }
-        }
-      }
-    }
+    itensor->copy_with_stride(from, t);
     swap(t, *this);
   }
 }
9 changes: 9 additions & 0 deletions nntrainer/tensor/tensor.h
@@ -1145,6 +1145,15 @@
    */
   std::vector<Tensor> split(std::vector<size_t> sizes, int axis = 0);
 
+  /**
+   * @brief concatenate tensors along axis
+   *
+   * @param tensors tensors to be concatenated to the first tensor
+   * @param axis axis
+   * @return Tensor concatenated tensor
+   */
+  Tensor concat(const std::vector<Tensor> &tensors, int axis = 0);
+
   /**
    * @brief concatenate tensors along axis
   *
12 changes: 12 additions & 0 deletions nntrainer/tensor/tensor_base.h
@@ -396,6 +396,11 @@
    */
   virtual std::vector<Tensor> split(std::vector<size_t> sizes, int axis) = 0;
 
+  /**
+   * @copydoc Tensor::concat(const std::vector<Tensor> &tensors, int axis)
+   */
+  virtual Tensor concat(const std::vector<Tensor> &tensors, int axis) = 0;
+
   /**
   * @copydoc Tensor::print(std::ostream &out)
   */
@@ -431,6 +436,13 @@
    */
   virtual void copyData(const Tensor &from) = 0;
 
+  /**
+   * @brief Copy the Tensor
+   * @param[in] input Tensor to be copied
+   * @param[out] output output Tensor
+   */
+  virtual void copy_with_stride(const Tensor &input, Tensor &output) = 0;
+
   /**
   * @copydoc Tensor::argmax()
   */
