Skip to content

Commit

Permalink
[Bugfix] Read & save quantized tensor data.
Browse files Browse the repository at this point in the history
This PR resolves an issue that quantized tensors incorrectly read and save their data.
Note that scale factors are assumed to be in full-precision, and half-precision support will be later introduced.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
  • Loading branch information
djeong20 authored and jijoongmoon committed Dec 10, 2024
1 parent 467193a commit e55b7ed
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
12 changes: 8 additions & 4 deletions nntrainer/tensor/bcq_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,15 @@ std::vector<unsigned int> BCQTensor::argmax() const {
return result;
}

// void BCQTensor::save_quant_bit(std::ostream &file) { return; }
void BCQTensor::save_quantization_info(std::ostream &file) {
checkedWrite(file, (char *)&quantized_bit_size, sizeof(uint16_t),
"[BCQTensor::save] failed to write quantization information");
}

// void BCQTensor::read_quant_bit(std::ifstream &file) {
// file.read((char *)&quantized_bit_size, sizeof(uint16_t));
// }
void BCQTensor::read_quantization_info(std::ifstream &file) {
checkedRead(file, (char *)&quantized_bit_size, sizeof(uint16_t),
"[BCQTensor::read] failed to read quantization information");
}

size_t BCQTensor::size() const {
return quantized_bit_size * dim.height() * ((dim.width() + 31) / 32);
Expand Down
8 changes: 4 additions & 4 deletions nntrainer/tensor/bcq_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,14 @@ class BCQTensor : public TensorBase {
std::vector<unsigned int> argmax() const override;

/**
* @copydoc TensorBase::save_quant_bit(std::ostream &file)
* @copydoc TensorBase::save_quantization_info(std::ostream &file)
*/
// void save_quant_bit(std::ostream &file) override;
void save_quantization_info(std::ostream &file) override;

/**
* @copydoc TensorBase::read_quant_bit(std::ifstream &file)
* @copydoc TensorBase::read_quantization_info(std::ifstream &file)
*/
// void read_quant_bit(std::ifstream &file) override;
void read_quantization_info(std::ifstream &file) override;

/**
* @copydoc TensorBase::size()
Expand Down
12 changes: 8 additions & 4 deletions nntrainer/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1094,9 +1094,11 @@ void Tensor::save(std::ostream &file) {
/// @note Save quantization information which only works on Quantized Tensor
itensor->save_quantization_info(file);

std::streamsize sz = static_cast<std::streamsize>(bytes() + scale_size());
/// @note Scale factors are temporary fixed to float for now
std::streamsize sz =
static_cast<std::streamsize>(bytes() + scale_size() * sizeof(float));
NNTR_THROW_IF(sz < 0, std::invalid_argument)
<< "save size: " << bytes() + scale_size()
<< "save size: " << bytes() + scale_size() * sizeof(float)
<< " is too big. It cannot be represented by std::streamsize";

checkedWrite(file, getData<char>(), sz, "[Tensor::save] operation failed");
Expand All @@ -1110,10 +1112,12 @@ void Tensor::read(std::ifstream &file) {
/// @note Read quantization information which only works on Quantized Tensor
itensor->read_quantization_info(file);

std::streamsize sz = static_cast<std::streamsize>(bytes() + scale_size());
/// @note Scale factors are temporary fixed to float for now
std::streamsize sz =
static_cast<std::streamsize>(bytes() + scale_size() * sizeof(float));

NNTR_THROW_IF(sz < 0, std::invalid_argument)
<< "read size: " << bytes() + scale_size()
<< "read size: " << bytes() + scale_size() * sizeof(float)
<< " is too big. It cannot be represented by std::streamsize";

checkedRead(file, getData<char>(), sz, "[Tensor::read] operation failed");
Expand Down

0 comments on commit e55b7ed

Please sign in to comment.