[TensorV2] Refactoring TensorBase pointer to shared_ptr #2428

Merged
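
For context, here is a minimal, self-contained sketch of the ownership change this PR applies. It is not the nntrainer source: TensorBase, FloatTensor and TensorV2 below are toy stand-ins with simplified constructors. The point is the pattern visible throughout the diff: the raw TensorBase * member becomes a std::shared_ptr<TensorBase>, built with an explicit std::default_delete deleter, so the compiler-generated copy, assignment and destructor of TensorV2 manage the tensor's lifetime.

// Sketch only; class names mirror the PR but the bodies are simplified.
#include <memory>
#include <string>

struct TensorBase {
  virtual ~TensorBase() = default;
};

struct FloatTensor : TensorBase {
  explicit FloatTensor(std::string name) : name(std::move(name)) {}
  std::string name;
};

class TensorV2 {
public:
  explicit TensorV2(std::string name)
    // Same construction pattern as the diff: an explicit
    // std::default_delete<FloatTensor> deleter (which is also what
    // shared_ptr would use by default for this type).
    : itensor(std::shared_ptr<FloatTensor>(
        new FloatTensor(std::move(name)),
        std::default_delete<FloatTensor>())) {}

private:
  std::shared_ptr<TensorBase> itensor; // was: TensorBase *itensor;
};

int main() {
  TensorV2 a("weights");
  TensorV2 b = a; // compiler-generated copy: both objects share one FloatTensor
  return 0;
}
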
40 changes: 26 additions & 14 deletions nntrainer/tensor/tensor_v2.cpp
@@ -22,10 +22,12 @@ TensorV2::TensorV2(std::string name_, Tformat fm, Tdatatype d_type) {
   itensor = nullptr;
 
   if (d_type == Tdatatype::FP32) {
-    itensor = new FloatTensor(name_, fm);
+    itensor = std::shared_ptr<FloatTensor>(new FloatTensor(name_, fm),
+                                           std::default_delete<FloatTensor>());
   } else if (d_type == Tdatatype::FP16) {
 #ifdef ENABLE_FP16
-    itensor = new HalfTensor(name_, fm);
+    itensor = std::shared_ptr<HalfTensor>(new HalfTensor(name_, fm),
+                                          std::default_delete<HalfTensor>());
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -42,10 +44,14 @@ TensorV2::TensorV2(const TensorDim &d, bool alloc_now, Initializer init,
   itensor = nullptr;
 
   if (d.getDataType() == Tdatatype::FP32) {
-    itensor = new FloatTensor(d, alloc_now, init, name);
+    itensor =
+      std::shared_ptr<FloatTensor>(new FloatTensor(d, alloc_now, init, name),
+                                   std::default_delete<FloatTensor>());
   } else if (d.getDataType() == Tdatatype::FP16) {
 #ifdef ENABLE_FP16
-    itensor = new HalfTensor(d, alloc_now, init, name);
+    itensor =
+      std::shared_ptr<HalfTensor>(new HalfTensor(d, alloc_now, init, name),
+                                  std::default_delete<HalfTensor>());
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -61,10 +67,12 @@ TensorV2::TensorV2(const TensorDim &d, const void *buf) {
   itensor = nullptr;
 
   if (d.getDataType() == Tdatatype::FP32) {
-    itensor = new FloatTensor(d, buf);
+    itensor = std::shared_ptr<FloatTensor>(new FloatTensor(d, buf),
+                                           std::default_delete<FloatTensor>());
   } else if (d.getDataType() == Tdatatype::FP16) {
 #ifdef ENABLE_FP16
-    itensor = new HalfTensor(d, buf);
+    itensor = std::shared_ptr<HalfTensor>(new HalfTensor(d, buf),
+                                          std::default_delete<HalfTensor>());
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
@@ -79,14 +87,16 @@ TensorV2::TensorV2(const TensorDim &d, const void *buf) {
 TensorV2::TensorV2(
   std::vector<std::vector<std::vector<std::vector<float>>>> const &d,
   ml::train::TensorDim::TensorType t_type) {
-  itensor = new FloatTensor(d, t_type.format);
+  itensor = std::shared_ptr<FloatTensor>(new FloatTensor(d, t_type.format),
+                                         std::default_delete<FloatTensor>());
 }
 
 #ifdef ENABLE_FP16

Collaborator: It will be removed according to our offline discussion.

 TensorV2::TensorV2(
   std::vector<std::vector<std::vector<std::vector<_FP16>>>> const &d,
   ml::train::TensorDim::TensorType t_type) {
-  itensor = new HalfTensor(d, t_type.format);
+  itensor = std::shared_ptr<HalfTensor>(new HalfTensor(d, t_type.format),
+                                        std::default_delete<HalfTensor>());
 }
 #endif
 
@@ -95,12 +105,12 @@ bool TensorV2::operator==(const TensorV2 &rhs) const {
   if (*itensor == *rhs.itensor) {
     /// compares tensor data
     if (getDataType() == Tdatatype::FP32) {
-      return *dynamic_cast<FloatTensor *>(itensor) ==
-             *dynamic_cast<FloatTensor *>(rhs.itensor);
+      return *std::dynamic_pointer_cast<FloatTensor>(itensor) ==
+             *std::dynamic_pointer_cast<FloatTensor>(rhs.itensor);
     } else if (getDataType() == Tdatatype::FP16) {
 #ifdef ENABLE_FP16
-      return *dynamic_cast<HalfTensor *>(itensor) ==
-             *dynamic_cast<HalfTensor *>(rhs.itensor);
+      return *std::dynamic_pointer_cast<HalfTensor>(itensor) ==
+             *std::dynamic_pointer_cast<HalfTensor>(rhs.itensor);
 #else
       throw std::invalid_argument(
         "Error: HalfTensor cannot be created or used when FP16 is not enabled. "
@@ -305,14 +315,16 @@ size_t TensorV2::width() const { return itensor->width(); }
 
 void TensorV2::createSharedDataTensor(const TensorV2 &src, TensorV2 &dest,
                                       size_t offset) const {
-  itensor->createSharedDataTensor(src.itensor, dest.itensor, offset);
+  itensor->createSharedDataTensor(src.itensor.get(), dest.itensor.get(),
+                                  offset);
 }
 
 TensorV2 TensorV2::getSharedDataTensor(const TensorDim dim_, size_t offset,
                                        bool reset_stride,
                                        const std::string &name_) const {
   TensorV2 ret = *this;
-  ret.itensor = itensor->getSharedDataTensor(dim_, offset, reset_stride, name_);
+  ret.itensor = std::shared_ptr<TensorBase>(
+    itensor->getSharedDataTensor(dim_, offset, reset_stride, name_));
   return ret;
 }

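
A side note on the operator== change above: std::dynamic_pointer_cast is the shared_ptr counterpart of dynamic_cast. It returns a shared_ptr to the derived type that shares ownership with the source pointer, or an empty shared_ptr when the runtime type does not match. A small standalone sketch of that behavior (FloatTensor here is a hypothetical stand-in holding a single float, not the nntrainer class):

#include <cassert>
#include <memory>

struct TensorBase {
  virtual ~TensorBase() = default;
};

struct FloatTensor : TensorBase {
  explicit FloatTensor(float v) : v(v) {}
  bool operator==(const FloatTensor &rhs) const { return v == rhs.v; }
  float v;
};

struct OtherTensor : TensorBase {};

int main() {
  std::shared_ptr<TensorBase> lhs = std::make_shared<FloatTensor>(1.0f);
  std::shared_ptr<TensorBase> rhs = std::make_shared<FloatTensor>(1.0f);

  // Downcast before dereferencing, mirroring the comparison pattern above.
  auto l = std::dynamic_pointer_cast<FloatTensor>(lhs);
  auto r = std::dynamic_pointer_cast<FloatTensor>(rhs);
  assert(l && r);
  assert(*l == *r);

  // A mismatched dynamic type yields an empty pointer rather than throwing,
  // so callers should check the result before dereferencing it.
  assert(std::dynamic_pointer_cast<OtherTensor>(lhs) == nullptr);
  return 0;
}
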
2 changes: 1 addition & 1 deletion nntrainer/tensor/tensor_v2.h
@@ -708,7 +708,7 @@ class TensorV2 {
   }
 
 private:
-  TensorBase *itensor;
+  std::shared_ptr<TensorBase> itensor;
 };
 
 } // namespace nntrainer