Skip to content

Commit

Permalink
[Tensor] Update with scale factor support for quantized data types
Browse files Browse the repository at this point in the history
This pull request aims to enhance the functionality of the Tensor class by enabling it to handle scale factors specifically designed for quantized data types.
This patch ensures that the Tensor class accurately represents and processes quantized data while maintaining its original features.
This update will provide developers with more flexibility when working with quantized models.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
  • Loading branch information
djeong20 authored and jijoongmoon committed Nov 4, 2024
1 parent b1a3c75 commit ad23950
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
2 changes: 2 additions & 0 deletions nntrainer/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,8 @@ size_t Tensor::height() const { return itensor->height(); }

size_t Tensor::width() const {
  // Delegates to the underlying tensor implementation.
  return itensor->width();
}

size_t Tensor::scale_size() const {
  // Delegates to the underlying tensor implementation.
  return itensor->scale_size();
}

void Tensor::mergeAxis(unsigned int axis1, unsigned int axis2) {
NNTR_THROW_IF(!getContiguous(), std::invalid_argument)
<< getName() << " is not contiguous, cannot merge axis";
Expand Down
22 changes: 22 additions & 0 deletions nntrainer/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,22 @@ class Tensor {
return (T *)itensor->getData(idx);
}

/**
 * @brief Get the scale pointer of the Tensor, viewed as a pointer to T
 * @tparam T type the returned pointer is cast to (float by default)
 * @retval T pointer obtained by casting the result of itensor->getScale()
 */
template <typename T = float> T *getScale() const {
  return static_cast<T *>(itensor->getScale());
}

/**
 * @brief Get the scale pointer of the Tensor for a given index, viewed as a
 *        pointer to T
 * @tparam T type the returned pointer is cast to (float by default)
 * @param[in] idx index forwarded unchanged to itensor->getScale(idx)
 * @retval T pointer obtained by casting the result of itensor->getScale(idx)
 */
template <typename T = float> T *getScale(size_t idx) const {
  return static_cast<T *>(itensor->getScale(idx));
}

/**
* @brief i data index
* @retval template T pointer (address of ith data)
Expand Down Expand Up @@ -1588,6 +1604,12 @@ class Tensor {
*/
size_t width() const;

/**
 * @brief return the scale factor size of the Tensor, if it has one
 * @retval scale factor size, as reported by the underlying tensor
 *         implementation (itensor->scale_size())
 */
size_t scale_size() const;

/**
* @brief Merge the given two axis for tensor at second axis inplace
*
Expand Down
27 changes: 26 additions & 1 deletion nntrainer/tensor/tensor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,24 @@ class TensorBase {
*/
virtual void *getData(size_t idx) const = 0;

/**
 * @copydoc Tensor::getScale()
 * @note This base implementation always throws std::invalid_argument; the
 *       message includes the result of getStringDataType().
 */
virtual void *getScale() const {
  const auto message =
    "Tensor::getScale() is not supported in tensor data type " +
    getStringDataType();
  throw std::invalid_argument(message);
}

/**
 * @copydoc Tensor::getScale(size_t idx)
 * @note This base implementation ignores idx and always throws
 *       std::invalid_argument; the message includes the result of
 *       getStringDataType().
 */
virtual void *getScale(size_t idx) const {
  const auto message =
    "Tensor::getScale() is not supported in tensor data type " +
    getStringDataType();
  throw std::invalid_argument(message);
}

/**
* @brief i data index
* @retval address of ith data
Expand Down Expand Up @@ -568,7 +586,7 @@ class TensorBase {
* @brief Get size of current tensor
* @retval size_t size of the current tensor, as reported by dim.getDataLen()
*/
size_t size() const { return dim.getDataLen(); }
virtual size_t size() const { return dim.getDataLen(); }

/**
* @brief Get if the tensor is empty
Expand Down Expand Up @@ -606,6 +624,13 @@ class TensorBase {
*/
size_t width() const {
  // Width is read from the tensor's dimension object.
  return dim.width();
}

/**
 * @brief Get the scale factor size of the tensor, if it has one
 * @retval 0 in this base implementation
 * @note Meant to be overridden by quantized tensor types
 */
virtual size_t scale_size() const {
  return 0;
}

/**
* @brief Merge the given two axis for tensor at second axis inplace
*
Expand Down

0 comments on commit ad23950

Please sign in to comment.