From 44ee96b0d5ec083c530a6c3e6b82e3d7e493b312 Mon Sep 17 00:00:00 2001
From: Donghyeon Jeong
Date: Mon, 9 Dec 2024 09:42:30 +0900
Subject: [PATCH] [Tensor] Binary-code-based quantized tensor

This PR adds a Tensor class that supports the binary-code-based
quantization type. Binary-code-based quantization (BCQ) enables
extremely low-bit quantization, and BCQTensor stores and manages such
quantized data. Note that BiQGEMM is used for quantized inference, so
the BiQGEMM library is required to use BCQTensor fully.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong
---
 api/ccapi/include/tensor_dim.h  |   1 +
 meson.build                     |  11 +
 meson_options.txt               |   1 +
 nntrainer/tensor/bcq_tensor.cpp | 365 ++++++++++++++++++++++++++++++++
 nntrainer/tensor/bcq_tensor.h   | 280 ++++++++++++++++++++++++
 nntrainer/tensor/meson.build    |   6 +
 nntrainer/tensor/tensor.cpp     |  51 +++++
 nntrainer/tensor/tensor_dim.cpp |   4 +
 8 files changed, 719 insertions(+)
 create mode 100644 nntrainer/tensor/bcq_tensor.cpp
 create mode 100644 nntrainer/tensor/bcq_tensor.h

diff --git a/api/ccapi/include/tensor_dim.h b/api/ccapi/include/tensor_dim.h
index 19e0a9d4ca..10a1aab3f9 100644
--- a/api/ccapi/include/tensor_dim.h
+++ b/api/ccapi/include/tensor_dim.h
@@ -54,6 +54,7 @@ class TensorDim {
   enum class DataType {
     QINT4,  /** quantized int 4*/
     QINT8,  /** quantized int 8*/
+    BCQ,    /** binary-code-based quantized*/
     UINT8,  /** unsigned int 8 bit */
     UINT16, /** unsigned int 16 bit */
     UINT32, /** unsigned int 32 bit */
diff --git a/meson.build b/meson.build
index 47f493650e..815a47f576 100644
--- a/meson.build
+++ b/meson.build
@@ -143,6 +143,17 @@ if get_option('opencl-kernel-path') != ''
   message ('OpenCL kernel path set to: @0@'.format(get_option('opencl-kernel-path')))
   extra_defines += '-DOPENCL_KERNEL_PATH=@0@'.format(get_option('opencl-kernel-path'))
 endif
+
+if get_option('enable-biqgemm')
+  # check if the BiQGEMM directory exists; otherwise, raise an error
+  fs = import('fs')
+  if fs.is_dir('../BiQGEMM')
+    extra_defines += '-DENABLE_BIQGEMM=1'
+    biqgemm_inc = include_directories('../BiQGEMM')
+  else
+    error ('BiQGEMM cannot be enabled without the BiQGEMM library.')
+  endif
+endif
 
 foreach extra_arg : warning_flags
   if cc.has_argument (extra_arg)
diff --git a/meson_options.txt b/meson_options.txt
index 0095ba8a6a..316d8f2e1f 100644
--- a/meson_options.txt
+++ b/meson_options.txt
@@ -43,6 +43,7 @@ option('enable-openmp', type: 'boolean', value: true)
 option('enable-neon', type: 'boolean', value: false)
 option('enable-avx', type: 'boolean', value: true)
 option('enable-opencl', type: 'boolean', value: false)
+option('enable-biqgemm', type: 'boolean', value: false)
 
 # ml-api dependency (to enable, install capi-inference from github.com/nnstreamer/api )
 # To inter-operate with nnstreamer and ML-API packages, you need to enable this.
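For reference, a minimal sketch of how the new data type is selected once
nntrainer is configured with -Denable-biqgemm=true. The helper name
make_bcq_weight and the chosen dimensions are illustrative assumptions, not
part of this patch:

#include <tensor.h>

using namespace nntrainer;

// Illustrative helper (not part of this patch): create an empty BCQ weight.
// Without -Denable-biqgemm=true, the Tensor constructor below throws
// std::invalid_argument (see the tensor.cpp changes later in this patch).
Tensor make_bcq_weight(unsigned int height, unsigned int width) {
  TensorDim dim(1, 1, height, width, {Tformat::NCHW, Tdatatype::BCQ});
  return Tensor(dim, /*alloc_now=*/true);
}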
diff --git a/nntrainer/tensor/bcq_tensor.cpp b/nntrainer/tensor/bcq_tensor.cpp
new file mode 100644
index 0000000000..c350f2c0f9
--- /dev/null
+++ b/nntrainer/tensor/bcq_tensor.cpp
@@ -0,0 +1,365 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * @file   bcq_tensor.cpp
+ * @date   06 December 2024
+ * @brief  This is BCQTensor class for binary-code-based quantization
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Donghyeon Jeong
+ * @bug    No known bugs except for NYI items
+ */
+
+#include <algorithm>
+#include <iostream>
+
+#include <bcq_tensor.h>
+#include <blas_interface.h>
+#include <tensor.h>
+#include <util_func.h>
+
+#include "BiQGEMM.h"
+
+namespace nntrainer {
+
+BCQTensor::BCQTensor(std::string name_, Tformat fm) :
+  TensorBase(name_, fm, Tdatatype::BCQ) {}
+
+BCQTensor::BCQTensor(const TensorDim &d, bool alloc_now, Initializer init,
+                     std::string name) :
+  TensorBase(d, alloc_now, init, name) {
+  if (alloc_now)
+    allocate();
+}
+
+BCQTensor::BCQTensor(const TensorDim &d, const void *buf) : BCQTensor(d, true) {
+  if (d.getDataLen() != 0) {
+    if (buf != nullptr)
+      copy(buf);
+  }
+}
+
+bool BCQTensor::operator==(const BCQTensor &rhs) const {
+  const uint32_t *_data = (uint32_t *)getData();
+  const uint32_t *_rdata = (uint32_t *)rhs.getData();
+
+  for (size_t i = 0; i < size() + scale_size(); ++i) {
+    if (_data[i] != _rdata[i])
+      return false;
+  }
+
+  return true;
+}
+
+void BCQTensor::allocate() {
+  if (empty() || data)
+    return;
+
+  if (src_tensor) {
+    /// allocate data based on the source tensor
+    allocateSrcTensor();
+    /** as this memory is shared, do NOT initialize */
+  } else {
+    /// allocate new memory for the tensor data
+    MemoryData *mem_data;
+
+    mem_data = new MemoryData((void *)(new uint32_t[size() + scale_size()]{}));
+    data = std::shared_ptr<MemoryData>(mem_data, [](auto *mem_data) {
+      delete[] mem_data->template getAddr<uint32_t>();
+      delete mem_data;
+    });
+
+    offset = 0;
+    initialize();
+  }
+}
+
+void BCQTensor::deallocate() {
+  data = nullptr;
+  offset = 0;
+}
+
+void *BCQTensor::getData() const {
+  if (!data)
+    return nullptr;
+
+  data->validate();
+  return data->getAddr<uint32_t>() + offset;
+}
+
+void *BCQTensor::getData(size_t idx) const {
+  NNTR_THROW_IF(idx > dim.getDataLen(), std::invalid_argument)
+    << "Tensor::getData() index is not valid";
+
+  if (!data)
+    return nullptr;
+
+  data->validate();
+  return data->getAddr<uint32_t>() + offset + (idx / 32);
+}
+
+void *BCQTensor::getScale() const {
+  if (!data)
+    return nullptr;
+
+  data->validate();
+  return ((uint32_t *)getData()) + size();
+}
+
+void *BCQTensor::getScale(size_t idx) const {
+  NNTR_THROW_IF(idx > scale_size(), std::invalid_argument)
+    << "Tensor::getScale() index is not valid";
+
+  if (!data)
+    return nullptr;
+
+  data->validate();
+  return ((uint32_t *)getScale()) + idx;
+}
+
+void *BCQTensor::getAddress(unsigned int i) {
+  size_t index = getIndex(batch(), channel(), height(), width() / 32);
+  if (i > index) {
+    return nullptr;
+  }
+  return &((uint32_t *)getData())[i];
+}
+
+const void *BCQTensor::getAddress(unsigned int i) const {
+  size_t index = getIndex(batch(), channel(), height(), width() / 32);
+  if (i > index) {
+    return nullptr;
+  }
+  return &((uint32_t *)getData())[i];
+}
+
+const uint32_t &BCQTensor::getValue(unsigned int i) const {
+  return ((uint32_t *)getData())[i];
+}
+
+uint32_t &BCQTensor::getValue(unsigned int i) {
+  return ((uint32_t *)getData())[i];
+}
+
+const uint32_t &BCQTensor::getValue(unsigned int b, unsigned int c,
+                                    unsigned int h, unsigned int w) const {
+  return getValue(getIndex(b, c, h, w / 32));
+}
+
+uint32_t &BCQTensor::getValue(unsigned int b, unsigned int c, unsigned int h,
+                              unsigned int w) {
+  return getValue(getIndex(b, c, h, w / 32));
+}
+
+void BCQTensor::setValue(float value) {
+  uint32_t *data = (uint32_t *)getData();
+  std::fill(data, data + size(), (uint32_t)value);
+}
+
+void BCQTensor::setValue(unsigned int b, unsigned int c, unsigned int h,
+                         unsigned int w, float value) {
+  ((uint32_t *)getData())[getIndex(b, c, h, w / 32)] = (uint32_t)value;
+}
+
+void BCQTensor::addValue(unsigned int b, unsigned int c, unsigned int h,
+                         unsigned int w, float value, float beta) {
+  throw std::invalid_argument("addValue() is not valid for " +
+                              getStringDataType());
+}
+
+void BCQTensor::setZero() {
+  /// @todo replace with apply_i or scal
+  setValue(0);
+}
+
+void BCQTensor::initialize() {
+  if (empty() || !isAllocated())
+    return;
+
+  /// @note Sampling from the normal/uniform distribution is invalid
+  switch (initializer) {
+  case Initializer::ZEROS:
+    setZero();
+    break;
+  case Initializer::ONES:
+    setValue(1.0f);
+    break;
+  case Initializer::NONE:
+    break;
+  default:
+    throw std::invalid_argument("Initializer not valid for " +
+                                getStringDataType());
+    break;
+  }
+
+  putData();
+}
+
+void BCQTensor::initialize(Initializer init) {
+  initializer = init;
+  initialize();
+}
+
+Tensor &BCQTensor::dot(Tensor const &input, Tensor &output, bool trans,
+                       bool trans_in, float beta) const {
+  size_t qbit_of_clusters[] = {quantized_bit_size};
+  size_t size_of_clusters[] = {height()};
+  const size_t number_of_cluster = 1;
+
+  /// @note hidden_tile_size should be set as a multiple of 32. This variable
+  /// affects the speed of matrixDotMatrix; the optimal value should be found
+  /// empirically for the target usage environment.
+  size_t hidden_tile_size = 32;
+
+  BiQGEMM::BCQW bcq_weight = BiQGEMM::BCQW(
+    (uint32_t *)getData(), (float *)getScale(), height(), width(),
+    number_of_cluster, qbit_of_clusters, size_of_clusters, hidden_tile_size);
+
+  BiQGEMM::matrixDotMatrix(output.getData(), bcq_weight, input.getData(),
+                           input.width());
+  return output;
+}
+
+void BCQTensor::copy(const Tensor &from) {
+  reshape(from.getDim());
+  copy(from.getData<uint32_t>());
+}
+
+void BCQTensor::copyData(const Tensor &from) {
+  NNTR_THROW_IF(!contiguous, std::invalid_argument)
+    << getName() << " is not contiguous, cannot copy.";
+
+  NNTR_THROW_IF(size() != from.size(), std::invalid_argument)
+    << "Size of tensor to copy must match";
+
+  /// @todo support copy from other data types
+  /// @todo check data type properly
+  switch (from.getDataType()) {
+  case ml::train::TensorDim::DataType::BCQ:
+    copy(from.getData<uint32_t>());
+    break;
+  default:
+    throw std::invalid_argument("Error: Unsupported data type");
+    break;
+  }
+}
+
+void BCQTensor::copy_with_stride(const Tensor &input, Tensor &output) {
+  for (unsigned int b = 0; b < output.batch(); ++b) {
+    for (unsigned int c = 0; c < output.channel(); ++c) {
+      for (unsigned int h = 0; h < output.height(); ++h) {
+        for (unsigned int w = 0; w < output.width() / 32; ++w) {
+          output.setValue(b, c, h, w, input.getValue<uint32_t>(b, c, h, w));
+        }
+      }
+    }
+  }
+}
+
+std::vector<unsigned int> BCQTensor::argmax() const {
+  std::vector<unsigned int> result;
+  const uint32_t *data = (uint32_t *)getData();
+  size_t batch_size = batch();
+  size_t feature_len = dim.getFeatureLen();
+
+  result.resize(batch_size);
+
+  for (unsigned int b = 0; b < batch_size; b++) {
+    auto max_iter =
+      std::max_element(data + b * feature_len, data + (b + 1) * feature_len);
+    result[b] = std::distance(data, max_iter) - (b * feature_len);
+  }
+  return result;
+}
+
+// void BCQTensor::save_quant_bit(std::ostream &file) { return; }
+
+// void BCQTensor::read_quant_bit(std::ifstream &file) {
+//   file.read((char *)&quantized_bit_size, sizeof(uint16_t));
+// }
+
+size_t BCQTensor::size() const {
+  return quantized_bit_size * dim.height() * ((dim.width() + 31) / 32);
+}
+
+float BCQTensor::max_abs() const { return maxValue(); }
+
+float BCQTensor::maxValue() const {
+  const uint32_t *data = (uint32_t *)getData();
+  return *std::max_element(data, data + size());
+}
+
+float BCQTensor::minValue() const {
+  const uint32_t *data = (uint32_t *)getData();
+  return *std::min_element(data, data + size());
+}
+
+void BCQTensor::print(std::ostream &out) const {
+  const uint32_t *data = (uint32_t *)getData();
+  unsigned int len = size();
+  out << "data addr: " << reinterpret_cast<const void *>(data) << '\n';
+  out << dim;
+
+  if (len > 512) {
+    out << '[' << (int)data[0] << ' ' << (int)data[1] << ' ' << (int)data[2]
+        << " ... " << (int)data[len - 3] << ' ' << (int)data[len - 2] << ' '
+        << (int)data[len - 1] << ']' << std::endl;
+    printScales(out);
+    return;
+  }
+
+  std::ios init(NULL);
+  init.copyfmt(out);
+
+  size_t idx = 0;
+  for (unsigned int bit = 0; bit < quantized_bit_size; ++bit) {
+    for (unsigned int k = 0; k < batch(); k++) {
+      for (unsigned int l = 0; l < channel(); l++) {
+        for (unsigned int i = 0; i < height(); i++) {
+          for (unsigned int j = 0; j < (width() + 31) / 32; j++) {
+            out << data[idx++] << " ";
+          }
+          out << std::endl;
+        }
+      }
+    }
+    out << "-------" << std::endl;
+  }
+  printScales(out);
+}
+
+size_t BCQTensor::scale_size() const { return height() * quantized_bit_size; }
+
+void BCQTensor::copy(const void *buf) {
+  NNTR_THROW_IF(!contiguous, std::invalid_argument)
+    << getName() << " is not contiguous, cannot copy.";
+
+  if (buf == getData()) {
+    return;
+  }
+
+  /// @todo need to optimize
+  for (unsigned int i = 0; i < size() + scale_size(); ++i) {
+    ((uint32_t *)getData())[i] = ((uint32_t *)buf)[i];
+  }
+}
+
+std::string BCQTensor::getStringDataType() const { return "BCQ"; }
+
+void BCQTensor::printScales(std::ostream &out) const {
+  const float *q_scales = (float *)getScale();
+  unsigned int len = scale_size();
+
+  if (len > 50) {
+    out << "Scale factors: [" << q_scales[0] << ' ' << q_scales[1] << ' '
+        << q_scales[2] << " ... " << q_scales[len - 3] << ' '
+        << q_scales[len - 2] << ' ' << q_scales[len - 1] << ']' << std::endl;
+    return;
+  }
+
+  out << "Scale factors: ";
+  for (unsigned i = 0; i < scale_size(); ++i) {
+    out << q_scales[i] << " ";
+  }
+  out << std::endl;
+}
+
+} // namespace nntrainer
diff --git a/nntrainer/tensor/bcq_tensor.h b/nntrainer/tensor/bcq_tensor.h
new file mode 100644
index 0000000000..ba42fd4fda
--- /dev/null
+++ b/nntrainer/tensor/bcq_tensor.h
@@ -0,0 +1,280 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * @file   bcq_tensor.h
+ * @date   06 December 2024
+ * @brief  This is BCQTensor class for binary-code-based quantization
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Donghyeon Jeong
+ * @bug    No known bugs except for NYI items
+ */
+
+#ifndef __BCQ_TENSOR_H__
+#define __BCQ_TENSOR_H__
+#ifdef __cplusplus
+
+#include <tensor_base.h>
+
+namespace nntrainer {
+
+/**
+ * @class BCQTensor class
+ * @brief BCQTensor class for the binary-code-based quantized weight (BCQ)
+ */
+class BCQTensor : public TensorBase {
+public:
+  /**
+   * @brief Basic Constructor of Tensor
+   */
+  BCQTensor(std::string name_ = "", Tformat fm = Tformat::NCHW);
+
+  /**
+   * @brief Construct a new BCQTensor object
+   *
+   * @param d Tensor dim for this bcq tensor
+   * @param alloc_now Allocate memory to this tensor or not
+   * @param init Initializer for the tensor
+   * @param name Name of the tensor
+   */
+  BCQTensor(const TensorDim &d, bool alloc_now,
+            Initializer init = Initializer::NONE, std::string name = "");
+
+  /**
+   * @brief Construct a new BCQTensor object
+   *
+   * @param d Tensor dim for this tensor
+   * @param buf buffer
+   */
+  BCQTensor(const TensorDim &d, const void *buf = nullptr);
+
+  /**
+   * @brief Construct a new BCQTensor object
+   * @param rhs TensorBase object to copy
+   */
+  BCQTensor(TensorBase &rhs) : TensorBase(rhs) {}
+
+  /**
+   * @brief Basic Destructor
+   */
+  ~BCQTensor() {}
+
+  /**
+   * @brief Comparison operator overload
+   * @param[in] rhs Tensor to be compared with
+   * @note Only compares Tensor data
+   */
+  bool operator==(const BCQTensor &rhs) const;
+
+  /**
+   * @brief Comparison operator overload
+   * @param[in] rhs Tensor to be compared with
+   * @note Only compares Tensor data
+   */
+  bool operator!=(const BCQTensor &rhs) const { return !(*this == rhs); }
+
+  /**
+   * @copydoc Tensor::allocate()
+   */
+  void allocate() override;
+
+  /**
+   * @copydoc Tensor::deallocate()
+   */
+  void deallocate() override;
+
+  /**
+   * @copydoc Tensor::getData()
+   */
+  void *getData() const override;
+
+  /**
+   * @copydoc Tensor::getData(size_t idx)
+   */
+  void *getData(size_t idx) const override;
+
+  /**
+   * @copydoc Tensor::getScale()
+   */
+  void *getScale() const override;
+
+  /**
+   * @copydoc Tensor::getScale(size_t idx)
+   */
+  void *getScale(size_t idx) const override;
+
+  /**
+   * @brief i data index
+   * @retval address of ith data
+   */
+  void *getAddress(unsigned int i) override;
+
+  /**
+   * @brief i data index
+   * @retval address of ith data
+   */
+  const void *getAddress(unsigned int i) const override;
+
+  /**
+   * @brief return value at specific location
+   * @param[in] i index
+   */
+  const uint32_t &getValue(unsigned int i) const;
+
+  /**
+   * @brief return value at specific location
+   * @param[in] i index
+   */
+  uint32_t &getValue(unsigned int i);
+
+  /**
+   * @brief return value at specific location
+   * @param[in] b batch location
+   * @param[in] c channel location
+   * @param[in] h height location
+   * @param[in] w width location
+   */
+  const uint32_t &getValue(unsigned int b, unsigned int c, unsigned int h,
+                           unsigned int w) const;
+
+  /**
+   * @brief return value at specific location
+   * @param[in] b batch location
+   * @param[in] c channel location
+   * @param[in] h height location
+   * @param[in] w width location
+   */
+  uint32_t &getValue(unsigned int b, unsigned int c, unsigned int h,
+                     unsigned int w);
+
+  /**
+   * @copydoc Tensor::setValue(float value)
+   */
+  void setValue(float value) override;
+
+  /**
+   * @copydoc Tensor::setValue(b, c, h, w, value)
+   */
+  void setValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
+                float value) override;
+
+  /**
+   * @copydoc Tensor::addValue(b, c, h, w, value, beta)
+   */
+  void addValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
+                float value, float beta) override;
+
+  /**
+   * @copydoc Tensor::setZero()
+   */
+  void setZero() override;
+
+  /**
+   * @copydoc Tensor::initialize()
+   */
+  void initialize() override;
+
+  /**
+   * @copydoc Tensor::initialize(Initializer init)
+   */
+  void initialize(Initializer init) override;
+
+  /**
+   * @copydoc Tensor::dot(Tensor const &input, Tensor &output, bool
+   * trans, bool trans_in, float beta)
+   *
+   * @note BCQTensor::dot ignores trans, trans_in, and beta currently
+   */
+  Tensor &dot(Tensor const &input, Tensor &output, bool trans, bool trans_in,
+              float beta) const override;
+
+  /**
+   * @copydoc Tensor::copy(const Tensor &from)
+   */
+  void copy(const Tensor &from) override;
+
+  /**
+   * @copydoc Tensor::copyData(const Tensor &from)
+   */
+  void copyData(const Tensor &from) override;
+
+  /**
+   * @copydoc Tensor::copy_with_stride()
+   */
+  void copy_with_stride(const Tensor &input, Tensor &output) override;
+
+  /**
+   * @copydoc Tensor::argmax()
+   */
+  std::vector<unsigned int> argmax() const override;
+
+  /**
+   * @copydoc TensorBase::save_quant_bit(std::ostream &file)
+   */
+  // void save_quant_bit(std::ostream &file) override;
+
+  /**
+   * @copydoc TensorBase::read_quant_bit(std::ifstream &file)
+   */
+  // void read_quant_bit(std::ifstream &file) override;
+
+  /**
+   * @copydoc TensorBase::size()
+   */
+  size_t size() const override;
+
+  /**
+   * @copydoc Tensor::max_abs()
+   */
+  float max_abs() const override;
+
+  /**
+   * @copydoc Tensor::maxValue()
+   */
+  float maxValue() const override;
+
+  /**
+   * @copydoc Tensor::minValue()
+   */
+  float minValue() const override;
+
+  /**
+   * @copydoc Tensor::print(std::ostream &out)
+   */
+  void print(std::ostream &out) const override;
+
+  /**
+   * @copydoc Tensor::scale_size()
+   */
+  size_t scale_size() const override;
+
+private:
+  /// @note this is an arbitrary value
+  uint16_t quantized_bit_size = 3;
+
+  /**
+   * @brief copy a buffer to @a this, the caller has to ensure that @a this is
+   * initialized otherwise undefined behavior
+   *
+   * @param buf buffer to copy from
+   */
+  void copy(const void *buf);
+
+  /**
+   * @brief Get the Data Type String object
+   * @return std::string of tensor data type
+   */
+  std::string getStringDataType() const override;
+
+  /**
+   * @copydoc Tensor::isValid()
+   */
+  bool isValid() const override { return true; }
+
+  /**
+   * @brief print quantization scale factors
+   */
+  void printScales(std::ostream &out) const;
+};
+
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __BCQ_TENSOR_H__ */
diff --git a/nntrainer/tensor/meson.build b/nntrainer/tensor/meson.build
index a9d05043c0..83f25d2e95 100644
--- a/nntrainer/tensor/meson.build
+++ b/nntrainer/tensor/meson.build
@@ -76,6 +76,12 @@ if get_option('enable-fp16')
   tensor_sources += 'half_tensor.cpp'
 endif
 
+if get_option('enable-biqgemm')
+  tensor_headers += 'bcq_tensor.h'
+  tensor_sources += 'bcq_tensor.cpp'
+  nntrainer_inc += biqgemm_inc
+endif
+
 if get_option('enable-opencl')
   subdir('cl_operations')
   nntrainer_inc += include_directories('cl_operations')
diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp
index 95ed5faf62..6d0f6041b3 100644
--- a/nntrainer/tensor/tensor.cpp
+++ b/nntrainer/tensor/tensor.cpp
@@ -9,6 +9,7 @@
  * @bug No known bugs except for NYI items
  */
 
+#include <bcq_tensor.h>
 #include <char_tensor.h>
 #include <float_tensor.h>
 #include <tensor.h>
@@ -80,6 +81,14 @@ Tensor::Tensor(std::string name_, Tformat fm, Tdatatype d_type) {
   } else if (d_type == Tdatatype::QINT8) {
     itensor = std::shared_ptr<TensorBase>(new CharTensor(name_, fm),
                                           std::default_delete<CharTensor>());
+  } else if (d_type == Tdatatype::BCQ) {
+#ifdef ENABLE_BIQGEMM
+    itensor = std::shared_ptr<TensorBase>(new BCQTensor(name_, fm),
+                                          std::default_delete<BCQTensor>());
+#else
+    throw std::invalid_argument("Error: enable-biqgemm is not activated. "
+                                "Enable only if your system supports BiQGEMM.");
+#endif
   } else {
     throw std::invalid_argument(
       "Error: Tensor cannot be constructed because the given d_type is not "
@@ -120,6 +129,15 @@ Tensor::Tensor(const TensorDim &d, bool alloc_now, Initializer init,
     itensor =
       std::shared_ptr<TensorBase>(new CharTensor(d, alloc_now, init, name),
                                   std::default_delete<CharTensor>());
+  } else if (d.getDataType() == Tdatatype::BCQ) {
+#ifdef ENABLE_BIQGEMM
+    itensor =
+      std::shared_ptr<TensorBase>(new BCQTensor(d, alloc_now, init, name),
+                                  std::default_delete<BCQTensor>());
+#else
+    throw std::invalid_argument("Error: enable-biqgemm is not activated. "
+                                "Enable only if your system supports BiQGEMM.");
+#endif
   } else {
     throw std::invalid_argument(
       "Error: Tensor cannot be constructed because the given d_type is not "
@@ -153,6 +171,14 @@ Tensor::Tensor(const TensorDim &d, const void *buf) {
   } else if (d.getDataType() == Tdatatype::QINT8) {
     itensor = std::shared_ptr<TensorBase>(new CharTensor(d, buf),
                                           std::default_delete<CharTensor>());
+  } else if (d.getDataType() == Tdatatype::BCQ) {
+#ifdef ENABLE_BIQGEMM
+    itensor = std::shared_ptr<TensorBase>(new BCQTensor(d, buf),
+                                          std::default_delete<BCQTensor>());
+#else
+    throw std::invalid_argument("Error: enable-biqgemm is not activated. "
+                                "Enable only if your system supports BiQGEMM.");
+#endif
   } else {
     throw std::invalid_argument(
       "Error: Tensor cannot be constructed because the given d_type is not "
@@ -184,6 +210,14 @@ Tensor::Tensor(const Tensor &rhs) {
   } else if (rhs.getDataType() == Tdatatype::QINT8) {
     itensor = std::shared_ptr<TensorBase>(new CharTensor(*rhs.itensor),
                                           std::default_delete<CharTensor>());
+  } else if (rhs.getDataType() == Tdatatype::BCQ) {
+#ifdef ENABLE_BIQGEMM
+    itensor = std::shared_ptr<TensorBase>(new BCQTensor(*rhs.itensor),
+                                          std::default_delete<BCQTensor>());
+#else
+    throw std::invalid_argument("Error: enable-biqgemm is not activated. "
+                                "Enable only if your system supports BiQGEMM.");
+#endif
   }
 }
 
@@ -217,6 +251,14 @@ Tensor &Tensor::operator=(const Tensor &rhs) {
   } else if (rhs.getDataType() == Tdatatype::QINT8) {
     itensor = std::shared_ptr<TensorBase>(new CharTensor(*rhs.itensor),
                                           std::default_delete<CharTensor>());
+  } else if (rhs.getDataType() == Tdatatype::BCQ) {
+#ifdef ENABLE_BIQGEMM
+    itensor = std::shared_ptr<TensorBase>(new BCQTensor(*rhs.itensor),
+                                          std::default_delete<BCQTensor>());
+#else
+    throw std::invalid_argument("Error: enable-biqgemm is not activated. "
+                                "Enable only if your system supports BiQGEMM.");
+#endif
   }
   return *this;
 }
@@ -249,6 +291,15 @@ bool Tensor::operator==(const Tensor &rhs) const {
   } else if (getDataType() == Tdatatype::QINT8) {
     return *std::dynamic_pointer_cast<CharTensor>(itensor) ==
            *std::dynamic_pointer_cast<CharTensor>(rhs.itensor);
+  } else if (getDataType() == Tdatatype::BCQ) {
+#ifdef ENABLE_BIQGEMM
+    return *std::dynamic_pointer_cast<BCQTensor>(itensor) ==
+           *std::dynamic_pointer_cast<BCQTensor>(rhs.itensor);
+#else
+    throw std::invalid_argument(
+      "Error: enable-biqgemm is not activated. "
+      "Enable only if your system supports BiQGEMM.");
+#endif
     }
   }
   return false;
diff --git a/nntrainer/tensor/tensor_dim.cpp b/nntrainer/tensor/tensor_dim.cpp
index df2a4f210b..cf403b9af1 100644
--- a/nntrainer/tensor/tensor_dim.cpp
+++ b/nntrainer/tensor/tensor_dim.cpp
@@ -165,6 +165,8 @@ uint TensorDim::getDataTypeSize() const {
     return sizeof(int8_t);
   case TensorDim::DataType::QINT4:
     return sizeof(int8_t);
+  case TensorDim::DataType::BCQ:
+    return sizeof(uint32_t);
   default:
     return sizeof(float);
   }
@@ -392,6 +394,8 @@ std::ostream &operator<<(std::ostream &out, TensorDim const &d) {
     type_ = "QINT8";
   } else if (d.getDataType() == ml::train::TensorDim::DataType::QINT4) {
     type_ = "QINT4";
+  } else if (d.getDataType() == ml::train::TensorDim::DataType::BCQ) {
+    type_ = "BCQ";
   }
 
   std::string format_ =