diff --git a/api/ccapi/include/tensor_dim.h b/api/ccapi/include/tensor_dim.h
index 19e0a9d4ca..10a1aab3f9 100644
--- a/api/ccapi/include/tensor_dim.h
+++ b/api/ccapi/include/tensor_dim.h
@@ -54,6 +54,7 @@ class TensorDim {
   enum class DataType {
     QINT4, /** quantized int 4*/
     QINT8, /** quantized int 8*/
+    BCQ,   /** binary-code-based quantized*/
     UINT8, /** unsigned int 8 bit */
     UINT16, /** unsigned int 16 bit */
     UINT32, /** unsigned int 32 bit */
diff --git a/meson.build b/meson.build
index 47f493650e..815a47f576 100644
--- a/meson.build
+++ b/meson.build
@@ -143,6 +143,17 @@ if get_option('opencl-kernel-path') != ''
   message ('OpenCL kernel path set to: @0@'.format(get_option('opencl-kernel-path')))
   extra_defines += '-DOPENCL_KERNEL_PATH=@0@'.format(get_option('opencl-kernel-path'))
 endif
+
+if get_option('enable-biqgemm')
+  # Check whether the BiQGEMM directory exists; otherwise, raise an error.
+  fs = import('fs')
+  if fs.is_dir('../BiQGEMM')
+    extra_defines += '-DENABLE_BIQGEMM=1'
+    biqgemm_inc = include_directories('../BiQGEMM')
+  else
+    error ('BiQGEMM cannot be enabled without the BiQGEMM library.')
+  endif
+endif
 
 foreach extra_arg : warning_flags
   if cc.has_argument (extra_arg)
diff --git a/meson_options.txt b/meson_options.txt
index 0095ba8a6a..316d8f2e1f 100644
--- a/meson_options.txt
+++ b/meson_options.txt
@@ -43,6 +43,7 @@ option('enable-openmp', type: 'boolean', value: true)
option('enable-neon', type: 'boolean', value: false)
option('enable-avx', type: 'boolean', value: true)
option('enable-opencl', type: 'boolean', value: false)
+option('enable-biqgemm', type: 'boolean', value: false)
 
 # ml-api dependency (to enable, install capi-inference from github.com/nnstreamer/api )
 # To inter-operate with nnstreamer and ML-API packages, you need to enable this.
diff --git a/nntrainer/tensor/bcq_tensor.cpp b/nntrainer/tensor/bcq_tensor.cpp
new file mode 100644
index 0000000000..c350f2c0f9
--- /dev/null
+++ b/nntrainer/tensor/bcq_tensor.cpp
@@ -0,0 +1,365 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * @file   bcq_tensor.cpp
+ * @date   06 December 2024
+ * @brief  This is BCQTensor class for binary-code-based quantization
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Donghyeon Jeong
+ * @bug    No known bugs except for NYI items
+ */
+
+// NOTE: the original include targets were lost; this set is a best-effort
+// reconstruction following the sibling char_tensor.cpp.
+#include <cstring>
+#include <iostream>
+
+#include <bcq_tensor.h>
+#include <blas_interface.h>
+#include <tensor.h>
+#include <util_func.h>
+
+#include "BiQGEMM.h"
+
+namespace nntrainer {
+
+BCQTensor::BCQTensor(std::string name_, Tformat fm) :
+  TensorBase(name_, fm, Tdatatype::BCQ) {}
+
+BCQTensor::BCQTensor(const TensorDim &d, bool alloc_now, Initializer init,
+                     std::string name) :
+  TensorBase(d, alloc_now, init, name) {
+  if (alloc_now)
+    allocate();
+}
+
+BCQTensor::BCQTensor(const TensorDim &d, const void *buf) : BCQTensor(d, true) {
+  if (d.getDataLen() != 0) {
+    if (buf != nullptr)
+      copy(buf);
+  }
+}
+
+bool BCQTensor::operator==(const BCQTensor &rhs) const {
+  const uint32_t *_data = (uint32_t *)getData();
+  const uint32_t *_rdata = (uint32_t *)rhs.getData();
+
+  for (size_t i = 0; i < size() + scale_size(); ++i) {
+    if (_data[i] != _rdata[i])
+      return false;
+  }
+
+  return true;
+}
+
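+/*
+ * Layout note: the packed binary codes and the per-row scale factors share
+ * a single allocation -- size() uint32_t code words first, followed by
+ * scale_size() float scale factors stored in uint32_t-sized slots. A rough
+ * sketch of reading the first scale factor through the accessors below
+ * (illustrative only, not part of this patch):
+ *
+ *   BCQTensor w(dim, true);
+ *   float s0 = *reinterpret_cast<float *>(w.getScale(0));
+ */
+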
+void BCQTensor::allocate() {
+  if (empty() || data)
+    return;
+
+  if (src_tensor) {
+    /// allocate data based on the source tensor
+    allocateSrcTensor();
+    /** as this memory is shared, do NOT initialize */
+  } else {
+    /// allocate new memory for the tensor data
+    MemoryData *mem_data;
+
+    mem_data = new MemoryData((void *)(new uint32_t[size() + scale_size()]{}));
+    data = std::shared_ptr<MemoryData>(mem_data, [](auto *mem_data) {
+      delete[] mem_data->template getAddr<uint32_t>();
+      delete mem_data;
+    });
+
+    offset = 0;
+    initialize();
+  }
+}
+
+void BCQTensor::deallocate() {
+  data = nullptr;
+  offset = 0;
+}
+
+void *BCQTensor::getData() const {
+  if (!data)
+    return nullptr;
+
+  data->validate();
+  return data->getAddr<uint32_t>() + offset;
+}
+
+void *BCQTensor::getData(size_t idx) const {
+  NNTR_THROW_IF(idx > dim.getDataLen(), std::invalid_argument)
+    << "Tensor::getData() index is not valid";
+
+  if (!data)
+    return nullptr;
+
+  data->validate();
+  return data->getAddr<uint32_t>() + offset + (idx / 32);
+}
+
+void *BCQTensor::getScale() const {
+  if (!data)
+    return nullptr;
+
+  data->validate();
+  return ((uint32_t *)getData()) + size();
+}
+
+void *BCQTensor::getScale(size_t idx) const {
+  NNTR_THROW_IF(idx > scale_size(), std::invalid_argument)
+    << "Tensor::getScale() index is not valid";
+
+  if (!data)
+    return nullptr;
+
+  data->validate();
+  return ((uint32_t *)getScale()) + idx;
+}
+
+void *BCQTensor::getAddress(unsigned int i) {
+  size_t index = getIndex(batch(), channel(), height(), width() / 32);
+  if (i > index) {
+    return nullptr;
+  }
+  return &((uint32_t *)getData())[i];
+}
+
+const void *BCQTensor::getAddress(unsigned int i) const {
+  size_t index = getIndex(batch(), channel(), height(), width() / 32);
+  if (i > index) {
+    return nullptr;
+  }
+  return &((uint32_t *)getData())[i];
+}
+
+const uint32_t &BCQTensor::getValue(unsigned int i) const {
+  return ((uint32_t *)getData())[i];
+}
+
+uint32_t &BCQTensor::getValue(unsigned int i) {
+  return ((uint32_t *)getData())[i];
+}
+
+const uint32_t &BCQTensor::getValue(unsigned int b, unsigned int c,
+                                    unsigned int h, unsigned int w) const {
+  return getValue(getIndex(b, c, h, w / 32));
+}
+
+uint32_t &BCQTensor::getValue(unsigned int b, unsigned int c, unsigned int h,
+                              unsigned int w) {
+  return getValue(getIndex(b, c, h, w / 32));
+}
+
+void BCQTensor::setValue(float value) {
+  uint32_t *data = (uint32_t *)getData();
+  std::fill(data, data + size(), (uint32_t)value);
+}
+
+void BCQTensor::setValue(unsigned int b, unsigned int c, unsigned int h,
+                         unsigned int w, float value) {
+  ((uint32_t *)getData())[getIndex(b, c, h, w / 32)] = (uint32_t)value;
+}
+
+void BCQTensor::addValue(unsigned int b, unsigned int c, unsigned int h,
+                         unsigned int w, float value, float beta) {
+  throw std::invalid_argument("addValue() is not valid for " +
+                              getStringDataType());
+}
+
+void BCQTensor::setZero() {
+  /// @todo replace with apply_i or scal
+  setValue(0);
+}
+
+void BCQTensor::initialize() {
+  if (empty() || !isAllocated())
+    return;
+
+  /// @note Sampling from the normal/uniform distribution is invalid
+  switch (initializer) {
+  case Initializer::ZEROS:
+    setZero();
+    break;
+  case Initializer::ONES:
+    setValue(1.0f);
+    break;
+  case Initializer::NONE:
+    break;
+  default:
+    throw std::invalid_argument("Initializer not valid for " +
+                                getStringDataType());
+    break;
+  }
+
+  putData();
+}
+
+void BCQTensor::initialize(Initializer init) {
+  initializer = init;
+  initialize();
+}
+
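+/*
+ * The dot product below packs the whole weight into a single BiQGEMM
+ * cluster (number_of_cluster == 1) and delegates the computation to
+ * BiQGEMM::matrixDotMatrix. A rough usage sketch through the public
+ * Tensor API (shapes and names are illustrative, not from this patch):
+ *
+ *   Tensor weight(TensorDim(1, 1, 64, 64, {Tformat::NCHW, Tdatatype::BCQ}));
+ *   Tensor input(TensorDim(1, 1, 64, 4));   // FP32 activations
+ *   Tensor output(TensorDim(1, 1, 64, 4));  // FP32 result
+ *   weight.dot(input, output, false, false, 0.0f);
+ */
+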
+Tensor &BCQTensor::dot(Tensor const &input, Tensor &output, bool trans,
+                       bool trans_in, float beta) const {
+  size_t qbit_of_clusters[] = {quantized_bit_size};
+  size_t size_of_clusters[] = {height()};
+  const size_t number_of_cluster = 1;
+
+  /// @note hidden_tile_size must be a multiple of 32. It affects the speed
+  /// of matrixDotMatrix; the optimal value should be determined empirically
+  /// for the target environment.
+  size_t hidden_tile_size = 32;
+
+  BiQGEMM::BCQW bcq_weight = BiQGEMM::BCQW(
+    (uint32_t *)getData(), (float *)getScale(), height(), width(),
+    number_of_cluster, qbit_of_clusters, size_of_clusters, hidden_tile_size);
+
+  BiQGEMM::matrixDotMatrix(output.getData<float>(), bcq_weight,
+                           input.getData<float>(), input.width());
+  return output;
+}
+
+void BCQTensor::copy(const Tensor &from) {
+  reshape(from.getDim());
+  copy(from.getData());
+}
+
+void BCQTensor::copyData(const Tensor &from) {
+  NNTR_THROW_IF(!contiguous, std::invalid_argument)
+    << getName() << " is not contiguous, cannot copy.";
+
+  NNTR_THROW_IF(size() != from.size(), std::invalid_argument)
+    << "Size of tensor to copy must match";
+
+  /// @todo support copy from other data types
+  /// @todo check data type properly
+  switch (from.getDataType()) {
+  case ml::train::TensorDim::DataType::BCQ:
+    copy(from.getData());
+    break; // missing break would fall through to the throw below
+  default:
+    throw std::invalid_argument("Error: Unsupported data type");
+    break;
+  }
+}
+
+void BCQTensor::copy_with_stride(const Tensor &input, Tensor &output) {
+  for (unsigned int b = 0; b < output.batch(); ++b) {
+    for (unsigned int c = 0; c < output.channel(); ++c) {
+      for (unsigned int h = 0; h < output.height(); ++h) {
+        for (unsigned int w = 0; w < output.width() / 32; ++w) {
+          output.setValue(b, c, h, w, input.getValue<uint32_t>(b, c, h, w));
+        }
+      }
+    }
+  }
+}
+
+std::vector<unsigned int> BCQTensor::argmax() const {
+  std::vector<unsigned int> result;
+  const uint32_t *data = (uint32_t *)getData();
+  size_t batch_size = batch();
+  size_t feature_len = dim.getFeatureLen();
+
+  result.resize(batch_size);
+
+  for (unsigned int b = 0; b < batch_size; b++) {
+    auto max_iter =
+      std::max_element(data + b * feature_len, data + (b + 1) * feature_len);
+    result[b] = std::distance(data, max_iter) - (b * feature_len);
+  }
+  return result;
+}
+
+// void BCQTensor::save_quant_bit(std::ostream &file) { return; }
+
+// void BCQTensor::read_quant_bit(std::ifstream &file) {
+//   file.read((char *)&quantized_bit_size, sizeof(uint16_t));
+// }
+
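+/*
+ * Sizing sketch (illustrative numbers, not from this patch): a 3-bit BCQ
+ * weight with height 4 and width 40 packs each row of each bit-plane into
+ * ceil(40 / 32) = 2 uint32_t words, so
+ *   size()       = 3 * 4 * 2 = 24 packed code words, and
+ *   scale_size() = 4 * 3     = 12 per-row, per-bit scale factors,
+ * which together form the single allocation created in allocate().
+ */
+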
" << (int)data[len - 3] << ' ' << (int)data[len - 2] << ' ' + << (int)data[len - 1] << ']' << std::endl; + printScales(out); + return; + } + + std::ios init(NULL); + init.copyfmt(out); + + size_t idx = 0; + for (unsigned int bit = 0; bit < quantized_bit_size; ++bit) { + for (unsigned int k = 0; k < batch(); k++) { + for (unsigned int l = 0; l < channel(); l++) { + for (unsigned int i = 0; i < height(); i++) { + for (unsigned int j = 0; j < (width() + 31) / 32; j++) { + out << data[idx++] << " "; + } + out << std::endl; + } + } + } + out << "-------" << std::endl; + } + printScales(out); +} + +size_t BCQTensor::scale_size() const { return height() * quantized_bit_size; } + +void BCQTensor::copy(const void *buf) { + NNTR_THROW_IF(!contiguous, std::invalid_argument) + << getName() << " is not contiguous, cannot copy."; + + if (buf == getData()) { + return; + } + + /// @todo need to optimize + for (unsigned int i = 0; i < size() + scale_size(); ++i) { + ((uint32_t *)getData())[i] = ((uint32_t *)buf)[i]; + } +} + +std::string BCQTensor::getStringDataType() const { return "BCQ"; } + +void BCQTensor::printScales(std::ostream &out) const { + const float *q_scales = (float *)getScale(); + unsigned int len = scale_size(); + + if (len > 50) { + out << "Scale factors: [" << (int)q_scales[0] << ' ' << (int)q_scales[1] + << ' ' << (int)q_scales[2] << " ... " << (int)q_scales[len - 3] << ' ' + << (int)q_scales[len - 2] << ' ' << (int)q_scales[len - 1] << ']' + << std::endl; + return; + } + + out << "Scale factors: "; + for (unsigned i = 0; i < scale_size(); ++i) { + out << q_scales[i] << " "; + } + out << std::endl; +} + +} // namespace nntrainer diff --git a/nntrainer/tensor/bcq_tensor.h b/nntrainer/tensor/bcq_tensor.h new file mode 100644 index 0000000000..ba42fd4fda --- /dev/null +++ b/nntrainer/tensor/bcq_tensor.h @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * @file bcq_tensor.h + * @date 06 December 2024 + * @brief This is BCQTensor class for binary-code-based quantization + * @see https://github.com/nnstreamer/nntrainer + * @author Donghyeon Jeong + * @bug No known bugs except for NYI items + */ + +#ifndef __BCQ_TENSOR_H__ +#define __BCQ_TENSOR_H__ +#ifdef __cplusplus + +#include + +namespace nntrainer { + +/** + * @class BCQTensor class + * @brief BCQTensor class for the binary-code-based quantized weight (BCQ) + */ +class BCQTensor : public TensorBase { +public: + /** + * @brief Basic Constructor of Tensor + */ + BCQTensor(std::string name_ = "", Tformat fm = Tformat::NCHW); + + /** + * @brief Construct a new BCQTensor object + * + * @param d Tensor dim for this bcq tensor + * @param alloc_now Allocate memory to this tensor or not + * @param init Initializer for the tensor + * @param name Name of the tensor + */ + BCQTensor(const TensorDim &d, bool alloc_now, + Initializer init = Initializer::NONE, std::string name = ""); + + /** + * @brief Construct a new BCQTensor object + * + * @param d Tensor dim for this tensor + * @param buf buffer + */ + BCQTensor(const TensorDim &d, const void *buf = nullptr); + + /** + * @brief Construct a new BCQTensor object + * @param rhs TensorBase object to copy + */ + BCQTensor(TensorBase &rhs) : TensorBase(rhs) {} + + /** + * @brief Basic Destructor + */ + ~BCQTensor() {} + + /** + * @brief Comparison operator overload + * @param[in] rhs Tensor to be compared with + * @note Only compares Tensor data + */ + bool operator==(const BCQTensor &rhs) const; + + /** + * @brief Comparison operator overload + * @param[in] rhs Tensor to be compared with + 
+ */
+class BCQTensor : public TensorBase {
+public:
+  /**
+   * @brief Basic Constructor of Tensor
+   */
+  BCQTensor(std::string name_ = "", Tformat fm = Tformat::NCHW);
+
+  /**
+   * @brief Construct a new BCQTensor object
+   *
+   * @param d Tensor dim for this bcq tensor
+   * @param alloc_now Allocate memory to this tensor or not
+   * @param init Initializer for the tensor
+   * @param name Name of the tensor
+   */
+  BCQTensor(const TensorDim &d, bool alloc_now,
+            Initializer init = Initializer::NONE, std::string name = "");
+
+  /**
+   * @brief Construct a new BCQTensor object
+   *
+   * @param d Tensor dim for this tensor
+   * @param buf buffer
+   */
+  BCQTensor(const TensorDim &d, const void *buf = nullptr);
+
+  /**
+   * @brief Construct a new BCQTensor object
+   * @param rhs TensorBase object to copy
+   */
+  BCQTensor(TensorBase &rhs) : TensorBase(rhs) {}
+
+  /**
+   * @brief Basic Destructor
+   */
+  ~BCQTensor() {}
+
+  /**
+   * @brief Comparison operator overload
+   * @param[in] rhs Tensor to be compared with
+   * @note Only compares Tensor data
+   */
+  bool operator==(const BCQTensor &rhs) const;
+
+  /**
+   * @brief Comparison operator overload
+   * @param[in] rhs Tensor to be compared with
+   * @note Only compares Tensor data
+   */
+  bool operator!=(const BCQTensor &rhs) const;
+
+  /**
+   * @copydoc Tensor::allocate()
+   */
+  void allocate() override;
+
+  /**
+   * @copydoc Tensor::deallocate()
+   */
+  void deallocate() override;
+
+  /**
+   * @copydoc Tensor::getData()
+   */
+  void *getData() const override;
+
+  /**
+   * @copydoc Tensor::getData(size_t idx)
+   */
+  void *getData(size_t idx) const override;
+
+  /**
+   * @copydoc Tensor::getScale()
+   */
+  void *getScale() const override;
+
+  /**
+   * @copydoc Tensor::getScale(size_t idx)
+   */
+  void *getScale(size_t idx) const override;
+
+  /**
+   * @brief i data index
+   * @retval address of ith data
+   */
+  void *getAddress(unsigned int i) override;
+
+  /**
+   * @brief i data index
+   * @retval address of ith data
+   */
+  const void *getAddress(unsigned int i) const override;
+
+  /**
+   * @brief return value at specific location
+   * @param[in] i index
+   */
+  const uint32_t &getValue(unsigned int i) const;
+
+  /**
+   * @brief return value at specific location
+   * @param[in] i index
+   */
+  uint32_t &getValue(unsigned int i);
+
+  /**
+   * @brief return value at specific location
+   * @param[in] b batch location
+   * @param[in] c channel location
+   * @param[in] h height location
+   * @param[in] w width location
+   */
+  const uint32_t &getValue(unsigned int b, unsigned int c, unsigned int h,
+                           unsigned int w) const;
+
+  /**
+   * @brief return value at specific location
+   * @param[in] b batch location
+   * @param[in] c channel location
+   * @param[in] h height location
+   * @param[in] w width location
+   */
+  uint32_t &getValue(unsigned int b, unsigned int c, unsigned int h,
+                     unsigned int w);
+
+  /**
+   * @copydoc Tensor::setValue(float value)
+   */
+  void setValue(float value) override;
+
+  /**
+   * @copydoc Tensor::setValue(b, c, h, w, value)
+   */
+  void setValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
+                float value) override;
+
+  /**
+   * @copydoc Tensor::addValue(b, c, h, w, value, beta)
+   */
+  void addValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
+                float value, float beta) override;
+
+  /**
+   * @copydoc Tensor::setZero()
+   */
+  void setZero() override;
+
+  /**
+   * @copydoc Tensor::initialize()
+   */
+  void initialize() override;
+
+  /**
+   * @copydoc Tensor::initialize(Initializer init)
+   */
+  void initialize(Initializer init) override;
+
+  /**
+   * @copydoc Tensor::dot(Tensor const &input, Tensor &output, bool
+   * trans, bool trans_in, float beta)
+   *
+   * @note BCQTensor::dot currently ignores trans, trans_in, and beta
+   */
+  Tensor &dot(Tensor const &input, Tensor &output, bool trans, bool trans_in,
+              float beta) const override;
+
+  /**
+   * @copydoc Tensor::copy(const Tensor &from)
+   */
+  void copy(const Tensor &from) override;
+
+  /**
+   * @copydoc Tensor::copyData(const Tensor &from)
+   */
+  void copyData(const Tensor &from) override;
+
+  /**
+   * @copydoc Tensor::copy_with_stride()
+   */
+  void copy_with_stride(const Tensor &input, Tensor &output) override;
+
+  /**
+   * @copydoc Tensor::argmax()
+   */
+  std::vector<unsigned int> argmax() const override;
+
+  /**
+   * @copydoc TensorBase::save_quant_bit(std::ostream &file)
+   */
+  // void save_quant_bit(std::ostream &file) override;
+
+  /**
+   * @copydoc TensorBase::read_quant_bit(std::ifstream &file)
+   */
+  // void read_quant_bit(std::ifstream &file) override;
+
+  /**
+   * @copydoc TensorBase::size()
+   */
+  size_t size() const override;
+
+  /**
+   * @copydoc Tensor::max_abs()
+   */
+  float max_abs() const override;
+
+  /**
+   * @copydoc Tensor::maxValue()
+   */
+  float maxValue() const override;
+
+  /**
+   * @copydoc Tensor::minValue()
+   */
+  float minValue() const override;
+
+  /**
+   * @copydoc Tensor::print(std::ostream &out)
+   */
+  void print(std::ostream &out) const override;
+
+  /**
+   * @copydoc Tensor::scale_size()
+   */
+  size_t scale_size() const override;
+
+private:
+  /// @note this default is an arbitrary value
+  uint16_t quantized_bit_size = 3;
+
+  /**
+   * @brief copy a buffer to @a this, the caller has to ensure that @a this is
+   * initialized otherwise undefined behavior
+   *
+   * @param buf buffer to copy from
+   */
+  void copy(const void *buf);
+
+  /**
+   * @brief Get the Data Type String object
+   * @return std::string of tensor data type
+   */
+  std::string getStringDataType() const override;
+
+  /**
+   * @copydoc Tensor::isValid()
+   */
+  bool isValid() const override { return true; };
+
+  /**
+   * @brief print quantization scale factors
+   */
+  void printScales(std::ostream &out) const;
+};
+
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __BCQ_TENSOR_H__ */
diff --git a/nntrainer/tensor/meson.build b/nntrainer/tensor/meson.build
index a9d05043c0..83f25d2e95 100644
--- a/nntrainer/tensor/meson.build
+++ b/nntrainer/tensor/meson.build
@@ -76,6 +76,12 @@ if get_option('enable-fp16')
   tensor_sources += 'half_tensor.cpp'
 endif
 
+if get_option('enable-biqgemm')
+  tensor_headers += 'bcq_tensor.h'
+  tensor_sources += 'bcq_tensor.cpp'
+  nntrainer_inc += biqgemm_inc
+endif
+
 if get_option('enable-opencl')
   subdir('cl_operations')
   nntrainer_inc += include_directories('cl_operations')
diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp
index 95ed5faf62..6d0f6041b3 100644
--- a/nntrainer/tensor/tensor.cpp
+++ b/nntrainer/tensor/tensor.cpp
@@ -9,6 +9,7 @@
  * @bug No known bugs except for NYI items
  */
 
+#include <bcq_tensor.h>
 #include <char_tensor.h>
 #include <float_tensor.h>
 #include <half_tensor.h>
@@ -80,6 +81,14 @@ Tensor::Tensor(std::string name_, Tformat fm, Tdatatype d_type) {
   } else if (d_type == Tdatatype::QINT8) {
     itensor = std::shared_ptr<CharTensor>(new CharTensor(name_, fm),
                                           std::default_delete<CharTensor>());
+  } else if (d_type == Tdatatype::BCQ) {
+#ifdef ENABLE_BIQGEMM
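+    // An illustrative way to reach this branch (not part of this patch):
+    //   Tensor t("bcq_w", Tformat::NCHW, Tdatatype::BCQ);
+    // The same BCQ dispatch recurs in the other constructors, operator=,
+    // and operator== below.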
" + "Enable only if your system supports BiQGEMM."); +#endif } else { throw std::invalid_argument( "Error: Tensor cannot be constructed because the given d_type is not " @@ -184,6 +210,14 @@ Tensor::Tensor(const Tensor &rhs) { } else if (rhs.getDataType() == Tdatatype::QINT8) { itensor = std::shared_ptr(new CharTensor(*rhs.itensor), std::default_delete()); + } else if (rhs.getDataType() == Tdatatype::BCQ) { +#ifdef ENABLE_BIQGEMM + itensor = std::shared_ptr(new BCQTensor(*rhs.itensor), + std::default_delete()); +#else + throw std::invalid_argument("Error: enable-biqgemm is not activated. " + "Enable only if your system supports BiQGEMM."); +#endif } } @@ -217,6 +251,14 @@ Tensor &Tensor::operator=(const Tensor &rhs) { } else if (rhs.getDataType() == Tdatatype::QINT8) { itensor = std::shared_ptr(new CharTensor(*rhs.itensor), std::default_delete()); + } else if (rhs.getDataType() == Tdatatype::BCQ) { +#ifdef ENABLE_BIQGEMM + itensor = std::shared_ptr(new BCQTensor(*rhs.itensor), + std::default_delete()); +#else + throw std::invalid_argument("Error: enable-biqgemm is not activated. " + "Enable only if your system supports BiQGEMM."); +#endif } return *this; } @@ -249,6 +291,15 @@ bool Tensor::operator==(const Tensor &rhs) const { } else if (getDataType() == Tdatatype::QINT8) { return *std::dynamic_pointer_cast(itensor) == *std::dynamic_pointer_cast(rhs.itensor); + } else if (getDataType() == Tdatatype::BCQ) { +#ifdef ENABLE_BIQGEMM + return *std::dynamic_pointer_cast(itensor) == + *std::dynamic_pointer_cast(rhs.itensor); +#else + throw std::invalid_argument( + "Error: enable-biqgemm is not activated. " + "Enable only if your system supports BiQGEMM."); +#endif } } return false; diff --git a/nntrainer/tensor/tensor_dim.cpp b/nntrainer/tensor/tensor_dim.cpp index df2a4f210b..cf403b9af1 100644 --- a/nntrainer/tensor/tensor_dim.cpp +++ b/nntrainer/tensor/tensor_dim.cpp @@ -165,6 +165,8 @@ uint TensorDim::getDataTypeSize() const { return sizeof(int8_t); case TensorDim::DataType::QINT4: return sizeof(int8_t); + case TensorDim::DataType::BCQ: + return sizeof(uint32_t); default: return sizeof(float); } @@ -392,6 +394,8 @@ std::ostream &operator<<(std::ostream &out, TensorDim const &d) { type_ = "QINT8"; } else if (d.getDataType() == ml::train::TensorDim::DataType::QINT4) { type_ = "QINT4"; + } else if (d.getDataType() == ml::train::TensorDim::DataType::BCQ) { + type_ = "BCQ"; } std::string format_ =