Skip to content

Commit

Permalink
NPUW: Support NF4 DCOFF for CW models (#27518)
Browse files Browse the repository at this point in the history
  • Loading branch information
TolyaTalamanov authored Nov 13, 2024
1 parent f417097 commit 51906cf
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/plugins/intel_cpu/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
ov::element::Type_t::f32,
ov::element::Type_t::f64,
ov::element::Type_t::boolean,
ov::element::Type_t::string};
ov::element::Type_t::string,
ov::element::Type_t::nf4};

if (!supported_precisions.count(input_precision)) {
OPENVINO_THROW_NOT_IMPLEMENTED("CPU plugin: Input image format ",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "openvino/op/subtract.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/pass/pattern/op/label.hpp" // any_input
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"

Expand Down Expand Up @@ -248,7 +249,7 @@ bool DCOFFPassBase::matcher_callback(ov::pass::pattern::Matcher& m) {

auto matched_paramA = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeA);
auto element_type = matched_paramA->get_element_type();
if (element_type == ov::element::i4 || element_type == ov::element::i8) {
if (element_type == ov::element::i4 || element_type == ov::element::i8 || element_type == ov::element::nf4) {
LOG_DEBUG("Matched: " << matched_paramA << ", set element type to " << m_dcoff_type);
matched_paramA->set_element_type(m_dcoff_type);

Expand Down Expand Up @@ -296,7 +297,8 @@ bool DCOFFPassBase::matcher_callback(ov::pass::pattern::Matcher& m) {
// Build the DCOFF pattern for the MatMul case: reuse the base pattern
// (which ends at `mulply`) and extend it so the dequantized weights —
// optionally passed through one more Convert — feed a MatMul's second input.
void DCOFFPassMatMul::build() {
    DCOFFPassBase::build();
    auto act_in = opp::any_input();
    // Some models carry an extra Convert between the Multiply and the MatMul;
    // make it optional so both graph shapes are matched.
    cvtopt = opp::optional<ov::op::v0::Convert>({mulply->output(0)});
    matmul = opp::wrap_type<ov::op::v0::MatMul>({act_in, cvtopt});
    auto matcher = std::make_shared<opp::Matcher>(matmul, "TagDCOFFMatMul");
    register_matcher(matcher,
                     std::bind(&DCOFFPassMatMul::matcher_callback, this, std::placeholders::_1));
}
Expand All @@ -306,6 +308,13 @@ void DCOFFPassMatMul::reconnect_root_to_convert(ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();
auto matched_convrt = node_to_output.at(toFP32).get_node_shared_ptr();
auto matched_matmul = node_to_output.at(matmul).get_node_shared_ptr();

auto cvt = std::static_pointer_cast<ov::op::v0::Convert>(matched_convrt);
auto matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_matmul);

// NB: In case convert and matmul types don't match
cvt->set_destination_type(matmul->inputs()[1].get_element_type());

matched_matmul->input(1).replace_source_output(matched_convrt);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class DCOFFPassBase : public ov::pass::MatcherPass {
ov::element::Type m_dcoff_type;
DCOFFParamRef m_params_to;

std::shared_ptr<ov::Node> paramA, paramB, toFP32, mulply;
std::shared_ptr<ov::Node> paramA, paramB, toFP32, mulply, cvtopt;
bool matcher_callback(ov::pass::pattern::Matcher& m);

public:
Expand Down
62 changes: 62 additions & 0 deletions src/plugins/intel_npu/src/plugin/npuw/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <openvino/core/parallel.hpp>
#include <openvino/core/type/bfloat16.hpp>
#include <openvino/core/type/float16.hpp>
#include <openvino/core/type/nf4.hpp>
#include <sstream>

#include "logging.hpp"
Expand Down Expand Up @@ -50,6 +51,59 @@ inline uint8_t hi4(uint8_t x) {
// Extract the low nibble (bits 0-3) of a packed byte.
inline uint8_t lo4(uint8_t x) {
    return static_cast<uint8_t>(x & 0x0F);
}

// Unpack NF4-quantized weights into f16 with a per-row (channel-wise) scale.
// `from` stores two 4-bit NF4 codes per byte (low nibble = even element,
// high nibble = odd element); `scale` provides one factor per output row
// (from_shape[0] rows of from_shape[1] elements each).
// `unpack_options` is currently unused by this scalar reference path.
void unpack_nf4f16(const ov::SoPtr<ov::ITensor>& from,
                   const ov::SoPtr<ov::ITensor>& scale,
                   const ov::SoPtr<ov::ITensor>& to,
                   const ov::npuw::util::UnpackOptions& unpack_options) {
    auto from_shape = from->get_shape();
    auto scale_shape = scale->get_shape();

    NPUW_ASSERT(from->is_continuous());
    NPUW_ASSERT(to->is_continuous());
    NPUW_ASSERT(scale->is_continuous());
    NPUW_ASSERT(from->get_size() == to->get_size());
    NPUW_ASSERT(from_shape[0] == scale_shape[0]);

    const auto* from_ptr = static_cast<const uint8_t*>(from->data());
    const auto* scale_ptr = scale->data<ov::float16>();
    auto* to_ptr = to->data<ov::float16>();

    const auto size = from->get_size();
    // Each packed byte expands to two consecutive f16 values; the scale row
    // is the element's flat index divided by the row length.
    ov::parallel_for(size / 2, [&](size_t idx) {
        const uint8_t nf4_2xval = from_ptr[idx];
        const float low_scale = scale_ptr[(idx * 2) / from_shape[1]];
        const float high_scale = scale_ptr[(idx * 2 + 1) / from_shape[1]];
        to_ptr[idx * 2] = ov::ConvertNF4::dequantize(lo4(nf4_2xval)) * low_scale;
        to_ptr[idx * 2 + 1] = ov::ConvertNF4::dequantize(hi4(nf4_2xval)) * high_scale;
    });
    if (size % 2 != 0) {
        // Odd tail: the last element sits in the low nibble of byte size/2
        // (not size/2 + 1, which is one past the packed buffer). The scale
        // index must parenthesize (size - 1): without parentheses the
        // expression parses as size - (1 / from_shape[1]) == size, reading
        // past the end of the scale tensor.
        const float low_scale = scale_ptr[(size - 1) / from_shape[1]];
        to_ptr[size - 1] = ov::ConvertNF4::dequantize(lo4(from_ptr[size / 2])) * low_scale;
    }
}

// Unpack NF4-quantized values into f16 without scaling.
// `from` stores two 4-bit NF4 codes per byte (low nibble = even element,
// high nibble = odd element). `unpack_options` is currently unused by this
// scalar reference path.
void unpack_nf4f16(const ov::SoPtr<ov::ITensor>& from,
                   const ov::SoPtr<ov::ITensor>& to,
                   const ov::npuw::util::UnpackOptions& unpack_options) {
    NPUW_ASSERT(from->is_continuous());
    NPUW_ASSERT(to->is_continuous());
    NPUW_ASSERT(from->get_size() == to->get_size());

    const auto* from_ptr = static_cast<const uint8_t*>(from->data());
    auto* to_ptr = to->data<ov::float16>();

    const auto size = from->get_size();
    ov::parallel_for(size / 2, [&](size_t idx) {
        const uint8_t nf4_2xval = from_ptr[idx];
        to_ptr[idx * 2] = ov::ConvertNF4::dequantize(lo4(nf4_2xval));
        to_ptr[idx * 2 + 1] = ov::ConvertNF4::dequantize(hi4(nf4_2xval));
    });
    if (size % 2 != 0) {
        // Odd tail: the last element sits in the low nibble of byte size/2
        // (size/2 + 1 would read one byte past the packed buffer).
        to_ptr[size - 1] = ov::ConvertNF4::dequantize(lo4(from_ptr[size / 2]));
    }
}

} // namespace

ov::Tensor ov::npuw::util::tensor_from_const(const std::shared_ptr<ov::Node>& node) {
Expand Down Expand Up @@ -81,6 +135,12 @@ void ov::npuw::util::unpack(const ov::SoPtr<ov::ITensor>& from,
auto type_from = from->get_element_type();
auto type_to = to->get_element_type();

// FIXME: Move under common switch when XARCH::unpack is implemented
if (type_from == ov::element::nf4 && type_to == ov::element::f16) {
unpack_nf4f16(from, to, unpack_options);
return;
}

namespace ove = ov::element;
#define CAST(x) static_cast<int>((x).operator ove::Type_t())
#define PAIR(f, t) (CAST(f) << 16 | CAST(t))
Expand Down Expand Up @@ -128,6 +188,8 @@ void ov::npuw::util::unpack(const ov::SoPtr<ov::ITensor>& from,
}
} else if (type_from == ov::element::i8) {
ov::npuw::util::XARCH::unpack_i8f16_scale(from, scale, to, unpack_options);
} else if (type_from == ov::element::nf4) {
unpack_nf4f16(from, scale, to, unpack_options);
} else {
NPUW_ASSERT(false && "Unsupported combination");
}
Expand Down

0 comments on commit 51906cf

Please sign in to comment.