Skip to content

Commit

Permalink
[TensorV2] Enable copying data from Tensor
Browse files Browse the repository at this point in the history
This PR enables deep-copy functionality for a contiguous tensor with the following functions: copy(), copyData(), and copy_with_stride().

The copy function completely copies the target tensor values regardless of the dimension of the input tensor. All elements and properties of the original tensor are copied to the new tensor. Therefore, if the copy function is used, a new tensor with the same size and shape as the original tensor is created.

On the other hand, the copyData function requires the sizes of the input and target tensors to match. This function only copies the data of the original tensor, so if the size or shape of the tensor is different, the copy may not be done properly.

Note that the copy and copyData functions support copying data from multiple tensor data types, while the copy_with_stride function only supports copying data from a tensor of the same data type.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
  • Loading branch information
djeong20 committed Jan 19, 2024
1 parent 25618e7 commit 4dcf6cb
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 4 deletions.
34 changes: 32 additions & 2 deletions nntrainer/tensor/float_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,37 @@ TensorV2 &FloatTensor::apply(std::function<float(float)> f,
return output;
}

void FloatTensor::copy(const TensorV2 &from) {
reshape(from.getDim());
copy(from.getData<float>());
}

void FloatTensor::copyData(const TensorV2 &from) {
if (!contiguous) {
throw std::runtime_error("Cannot copy non-contiguous tensor");
}

if (size() != from.size())
throw std::invalid_argument("Size of tensor to copy must match");

switch (from.getDataType()) {
case ml::train::TensorDim::DataType::FP32:
copy(from.getData<float>());
break;
case ml::train::TensorDim::DataType::FP16:
/// @todo remove #ifdef ENABLE_FP16
#ifdef ENABLE_FP16
scopy(size(), from.getData<_FP16>(), 1, (float *)getData(), 1);
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
break;
default:
throw std::invalid_argument("Error: Unsupported data type");
break;
}
}

void FloatTensor::print(std::ostream &out) const {
printInstance(out, this);
const float *data = (float *)getData();
Expand Down Expand Up @@ -342,10 +373,9 @@ void FloatTensor::print(std::ostream &out) const {
out.copyfmt(init);
}

/// @todo include getName()
void FloatTensor::copy(const void *buf) {
NNTR_THROW_IF(!contiguous, std::invalid_argument)
<< "Tensor is not contiguous, cannot copy.";
<< getName() << " is not contiguous, cannot copy.";

if (buf == getData()) {
return;
Expand Down
10 changes: 10 additions & 0 deletions nntrainer/tensor/float_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,16 @@ class FloatTensor : public TensorBase {
return output;
}

/**
 * @copydoc TensorV2::copy(const TensorV2 &from)
 * @note implements the pure virtual TensorBase::copy
 */
void copy(const TensorV2 &from);

/**
 * @copydoc TensorV2::copyData(const TensorV2 &from)
 * @note implements the pure virtual TensorBase::copyData
 */
void copyData(const TensorV2 &from);

/**
* @copydoc TensorV2::print(std::ostream &out)
*/
Expand Down
29 changes: 27 additions & 2 deletions nntrainer/tensor/half_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,32 @@ TensorV2 &HalfTensor::apply(std::function<_FP16(_FP16)> f,
return output;
}

void HalfTensor::copy(const TensorV2 &from) {
reshape(from.getDim());
copy(from.getData<_FP16>());
}

void HalfTensor::copyData(const TensorV2 &from) {
if (!contiguous) {
throw std::runtime_error("Cannot copy non-contiguous tensor");
}

if (size() != from.size())
throw std::invalid_argument("Size of tensor to copy must match");

switch (from.getDataType()) {
case ml::train::TensorDim::DataType::FP32:
scopy(size(), from.getData<float>(), 1, (_FP16 *)getData(), 1);
break;
case ml::train::TensorDim::DataType::FP16:
copy(from.getData<_FP16>());
break;
default:
throw std::invalid_argument("Error: Unsupported data type");
break;
}
}

void HalfTensor::print(std::ostream &out) const {
printInstance(out, this);
const _FP16 *data = (_FP16 *)getData();
Expand Down Expand Up @@ -343,10 +369,9 @@ void HalfTensor::print(std::ostream &out) const {
out.copyfmt(init);
}

/// @todo include getName()
void HalfTensor::copy(const void *buf) {
NNTR_THROW_IF(!contiguous, std::invalid_argument)
<< "Tensor is not contiguous, cannot copy.";
<< getName() << " is not contiguous, cannot copy.";

if (buf == getData()) {
return;
Expand Down
10 changes: 10 additions & 0 deletions nntrainer/tensor/half_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,16 @@ class HalfTensor : public TensorBase {
return output;
}

/**
 * @copydoc TensorV2::copy(const TensorV2 &from)
 * @note implements the pure virtual TensorBase::copy
 */
void copy(const TensorV2 &from);

/**
 * @copydoc TensorV2::copyData(const TensorV2 &from)
 * @note implements the pure virtual TensorBase::copyData
 */
void copyData(const TensorV2 &from);

/**
* @copydoc TensorV2::print(std::ostream &out)
*/
Expand Down
14 changes: 14 additions & 0 deletions nntrainer/tensor/tensor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,20 @@ class TensorBase {
}
#endif

/**
 * @brief Copy the Tensor
 * @param[in] from Tensor to be copied
 *
 * @note copy can reshape the tensor to match the shape of the given tensor
 */
virtual void copy(const TensorV2 &from) = 0;

/**
 * @brief Copy the data of the Tensor
 * @param[in] from Tensor to be copied
 * @note unlike copy(), the shape of this tensor is kept as-is; the sizes
 *       of the two tensors are expected to match (see implementations)
 */
virtual void copyData(const TensorV2 &from) = 0;

/**
* @brief put data of Tensor
* @note It is only effective when memory_swap is used
Expand Down
57 changes: 57 additions & 0 deletions nntrainer/tensor/tensor_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,63 @@ void TensorV2::print(std::ostream &out) const { itensor->print(out); }

void TensorV2::putData() const { itensor->putData(); }

/**
 * @brief Deep copy the given tensor into this tensor.
 * @param[in] from tensor to be copied
 *
 * @note If the sizes and data types match, data is copied in place;
 *       otherwise *this is replaced with a new tensor cloned from @a from.
 * @throws std::runtime_error if this tensor is not contiguous
 * @throws std::invalid_argument if the source data type is unsupported
 */
void TensorV2::copy(const TensorV2 &from) {
  /// @todo enable copy to non-contiguous tensor
  if (!itensor->getContiguous()) {
    throw std::runtime_error("Cannot copy non-contiguous tensor");
  }

  if (from.size() != 0 && size() == from.size() &&
      getDataType() == from.getDataType()) {
    // Fast path: same element count and data type — copy data in place.
    itensor->copy(from);
    return;
  }

  // Replace *this with a new tensor that matches the given tensor.
  if (from.getDataType() == ml::train::TensorDim::DataType::FP32) {
    TensorV2 t = TensorV2(from.getDim(), from.getData<float>());
    swap(t, *this);
  } else if (from.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    TensorV2 t = TensorV2(from.getDim(), from.getData<_FP16>());
    swap(t, *this);
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
  } else {
    // Previously an unsupported source type fell through silently, leaving
    // *this unchanged with no error; fail loudly, matching copyData().
    throw std::invalid_argument("Error: Unsupported data type");
  }
}

// Delegates the data copy to the underlying tensor implementation
// (FloatTensor/HalfTensor), which validates contiguity and size.
void TensorV2::copyData(const TensorV2 &from) { itensor->copyData(from); }

/**
 * @brief Copy the given tensor into this tensor, element by element.
 * @param[in] from tensor to be copied
 *
 * @note Only supports copying from a tensor of the same data type.
 * @note When dimensions differ, *this is replaced by a new tensor with
 *       @a from's dimensions, filled via getValue()/setValue() so the
 *       source strides are honored.
 * @throws std::invalid_argument if the data type is unsupported
 */
void TensorV2::copy_with_stride(const TensorV2 &from) {
  if (itensor->getDim() == from.getDim()) {
    // Dimensions already match; use the regular copy path.
    copy(from);
    return;
  }

  // Replace with a new tensor that has the same data as the given tensor.
  TensorV2 t = TensorV2(from.getDim(), true);
  const auto data_type = getDataType();

  // Hoist the data-type dispatch out of the element loop: the check is
  // loop-invariant and was previously evaluated once per element.
  if (data_type == ml::train::TensorDim::DataType::FP32) {
    for (unsigned int b = 0; b < t.batch(); ++b) {
      for (unsigned int c = 0; c < t.channel(); ++c) {
        for (unsigned int h = 0; h < t.height(); ++h) {
          for (unsigned int w = 0; w < t.width(); ++w) {
            t.setValue(b, c, h, w, from.getValue<float>(b, c, h, w));
          }
        }
      }
    }
  } else if (data_type == ml::train::TensorDim::DataType::FP16) {
    /// @todo remove #ifdef ENABLE_FP16
#ifdef ENABLE_FP16
    for (unsigned int b = 0; b < t.batch(); ++b) {
      for (unsigned int c = 0; c < t.channel(); ++c) {
        for (unsigned int h = 0; h < t.height(); ++h) {
          for (unsigned int w = 0; w < t.width(); ++w) {
            t.setValue(b, c, h, w, from.getValue<_FP16>(b, c, h, w));
          }
        }
      }
    }
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
  } else {
    // Previously an unsupported type skipped every element and then swapped
    // an uninitialized tensor into *this; fail loudly instead.
    throw std::invalid_argument("Error: Unsupported data type");
  }

  swap(t, *this);
}

void TensorV2::reshape(const TensorDim &d) { itensor->reshape(d); }

TensorDim TensorV2::getDim() const { return itensor->getDim(); }
Expand Down
23 changes: 23 additions & 0 deletions nntrainer/tensor/tensor_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,29 @@ class TensorV2 {
*/
void putData() const;

/**
 * @brief Copy the Tensor
 * @param[in] from Tensor to be copied
 *
 * @note copy can reshape this tensor to match the shape of the given tensor
 * @note supports copying data from tensors of different data types
 */
void copy(const TensorV2 &from);

/**
 * @brief Copy the data of the Tensor (shape is kept; sizes must match)
 * @param[in] from Tensor to be copied
 * @note supports copying data from tensors of different data types
 */
void copyData(const TensorV2 &from);

/**
 * @brief Copy the Tensor, honoring the source strides
 * @param[in] from Tensor to be copied
 * @note only supports copying data from a tensor with the same data type
 */
void copy_with_stride(const TensorV2 &from);

/**
* @brief set Tensor Dim
* @param[in] d TensorDim
Expand Down

0 comments on commit 4dcf6cb

Please sign in to comment.