Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tensor] Update with scale factor support for quantized data types #2779

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nntrainer/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,8 @@ size_t Tensor::height() const { return itensor->height(); }

size_t Tensor::width() const { return itensor->width(); }

size_t Tensor::scale_size() const { return itensor->scale_size(); }

void Tensor::mergeAxis(unsigned int axis1, unsigned int axis2) {
NNTR_THROW_IF(!getContiguous(), std::invalid_argument)
<< getName() << " is not contiguous, cannot merge axis";
Expand Down
22 changes: 22 additions & 0 deletions nntrainer/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,22 @@ class Tensor {
return (T *)itensor->getData(idx);
}

/**
* @brief return scale pointer of Tensor
* @retval template T pointer
*/
template <typename T = float> T *getScale() const {
return (T *)itensor->getScale();
}

/**
* @brief return scale pointer of Tensor
* @retval template T pointer
*/
template <typename T = float> T *getScale(size_t idx) const {
return (T *)itensor->getScale(idx);
}

/**
* @brief i data index
* @retval template T pointer (address of ith data)
Expand Down Expand Up @@ -1579,6 +1595,12 @@ class Tensor {
*/
size_t width() const;

/**
* @brief return Tensor scale factor size if exists
* @retval scale factor size
*/
size_t scale_size() const;

/**
* @brief Merge the given two axis for tensor at second axis inplace
*
Expand Down
27 changes: 26 additions & 1 deletion nntrainer/tensor/tensor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,24 @@ class TensorBase {
*/
virtual void *getData(size_t idx) const = 0;

/**
* @copydoc Tensor::getScale()
*/
virtual void *getScale() const {
throw std::invalid_argument(
"Tensor::getScale() is not supported in tensor data type " +
getStringDataType());
}

/**
* @copydoc Tensor::getScale(size_t idx)
*/
virtual void *getScale(size_t idx) const {
throw std::invalid_argument(
"Tensor::getScale() is not supported in tensor data type " +
getStringDataType());
}

/**
* @brief i data index
* @retval address of ith data
Expand Down Expand Up @@ -568,7 +586,7 @@ class TensorBase {
* @brief Get size of current tensor
* @retval unsigned int size of the current tensor
*/
size_t size() const { return dim.getDataLen(); }
virtual size_t size() const { return dim.getDataLen(); }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

size() is now a virtual function since it will be overridden by a quantized tensor that utilizes data packing.


/**
* @brief Get if the tensor is empty
Expand Down Expand Up @@ -606,6 +624,13 @@ class TensorBase {
*/
size_t width() const { return dim.width(); }

/**
* @brief return Tensor scale factor size if exists
* @retval scale factor size
* @note Override for quantize tensor
*/
virtual size_t scale_size() const { return 0; }

/**
* @brief Merge the given two axis for tensor at second axis inplace
*
Expand Down
Loading