Skip to content

Commit

Permalink
[TensorV2] Refactoring TensorBase pointer to shared_ptr
Browse files Browse the repository at this point in the history
This PR proposes refactoring the TensorV2 class to use a shared_ptr instead of a raw pointer for managing its TensorBase object.
By adopting this change, we can improve the safety and reliability of our code and reduce the likelihood of memory leaks and other issues related to manual memory management.

**Changes proposed in this PR:**
- Replace the raw TensorBase pointer in the TensorV2 class with a std::shared_ptr<TensorBase>.
- Update the code that used the raw pointer accordingly (e.g., dynamic_cast becomes std::dynamic_pointer_cast, and raw-pointer call sites use shared_ptr::get()).

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
  • Loading branch information
djeong20 committed Jan 26, 2024
1 parent 4a7f3c2 commit c0316f6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
40 changes: 26 additions & 14 deletions nntrainer/tensor/tensor_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ TensorV2::TensorV2(std::string name_, Tformat fm, Tdatatype d_type) {
itensor = nullptr;

if (d_type == Tdatatype::FP32) {
itensor = new FloatTensor(name_, fm);
itensor = std::shared_ptr<FloatTensor>(new FloatTensor(name_, fm),
std::default_delete<FloatTensor>());
} else if (d_type == Tdatatype::FP16) {
#ifdef ENABLE_FP16
itensor = new HalfTensor(name_, fm);
itensor = std::shared_ptr<HalfTensor>(new HalfTensor(name_, fm),
std::default_delete<HalfTensor>());
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
Expand All @@ -42,10 +44,14 @@ TensorV2::TensorV2(const TensorDim &d, bool alloc_now, Initializer init,
itensor = nullptr;

if (d.getDataType() == Tdatatype::FP32) {
itensor = new FloatTensor(d, alloc_now, init, name);
itensor =
std::shared_ptr<FloatTensor>(new FloatTensor(d, alloc_now, init, name),
std::default_delete<FloatTensor>());
} else if (d.getDataType() == Tdatatype::FP16) {
#ifdef ENABLE_FP16
itensor = new HalfTensor(d, alloc_now, init, name);
itensor =
std::shared_ptr<HalfTensor>(new HalfTensor(d, alloc_now, init, name),
std::default_delete<HalfTensor>());
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
Expand All @@ -61,10 +67,12 @@ TensorV2::TensorV2(const TensorDim &d, const void *buf) {
itensor = nullptr;

if (d.getDataType() == Tdatatype::FP32) {
itensor = new FloatTensor(d, buf);
itensor = std::shared_ptr<FloatTensor>(new FloatTensor(d, buf),
std::default_delete<FloatTensor>());
} else if (d.getDataType() == Tdatatype::FP16) {
#ifdef ENABLE_FP16
itensor = new HalfTensor(d, buf);
itensor = std::shared_ptr<HalfTensor>(new HalfTensor(d, buf),
std::default_delete<HalfTensor>());
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
Expand All @@ -79,14 +87,16 @@ TensorV2::TensorV2(const TensorDim &d, const void *buf) {
TensorV2::TensorV2(
std::vector<std::vector<std::vector<std::vector<float>>>> const &d,
ml::train::TensorDim::TensorType t_type) {
itensor = new FloatTensor(d, t_type.format);
itensor = std::shared_ptr<FloatTensor>(new FloatTensor(d, t_type.format),
std::default_delete<FloatTensor>());
}

#ifdef ENABLE_FP16
TensorV2::TensorV2(
std::vector<std::vector<std::vector<std::vector<_FP16>>>> const &d,
ml::train::TensorDim::TensorType t_type) {
itensor = new HalfTensor(d, t_type.format);
itensor = std::shared_ptr<HalfTensor>(new HalfTensor(d, t_type.format),
std::default_delete<HalfTensor>());
}
#endif

Expand All @@ -95,12 +105,12 @@ bool TensorV2::operator==(const TensorV2 &rhs) const {
if (*itensor == *rhs.itensor) {
/// compares tensor data
if (getDataType() == Tdatatype::FP32) {
return *dynamic_cast<FloatTensor *>(itensor) ==
*dynamic_cast<FloatTensor *>(rhs.itensor);
return *std::dynamic_pointer_cast<FloatTensor>(itensor) ==
*std::dynamic_pointer_cast<FloatTensor>(rhs.itensor);
} else if (getDataType() == Tdatatype::FP16) {
#ifdef ENABLE_FP16
return *dynamic_cast<HalfTensor *>(itensor) ==
*dynamic_cast<HalfTensor *>(rhs.itensor);
return *std::dynamic_pointer_cast<HalfTensor>(itensor) ==
*std::dynamic_pointer_cast<HalfTensor>(rhs.itensor);
#else
throw std::invalid_argument(
"Error: HalfTensor cannot be created or used when FP16 is not enabled. "
Expand Down Expand Up @@ -305,14 +315,16 @@ size_t TensorV2::width() const { return itensor->width(); }

void TensorV2::createSharedDataTensor(const TensorV2 &src, TensorV2 &dest,
size_t offset) const {
itensor->createSharedDataTensor(src.itensor, dest.itensor, offset);
itensor->createSharedDataTensor(src.itensor.get(), dest.itensor.get(),
offset);
}

TensorV2 TensorV2::getSharedDataTensor(const TensorDim dim_, size_t offset,
bool reset_stride,
const std::string &name_) const {
TensorV2 ret = *this;
ret.itensor = itensor->getSharedDataTensor(dim_, offset, reset_stride, name_);
ret.itensor = std::shared_ptr<TensorBase>(
itensor->getSharedDataTensor(dim_, offset, reset_stride, name_));
return ret;
}

Expand Down
2 changes: 1 addition & 1 deletion nntrainer/tensor/tensor_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ class TensorV2 {
}

private:
TensorBase *itensor;
std::shared_ptr<TensorBase> itensor;
};

} // namespace nntrainer
Expand Down

0 comments on commit c0316f6

Please sign in to comment.