From 51906cf724578389d1f16fd20f3705cba6777709 Mon Sep 17 00:00:00 2001 From: Anatoliy Talamanov Date: Wed, 13 Nov 2024 16:49:34 +0000 Subject: [PATCH] NPUW: Support NF4 DCOFF for CW models (#27518) --- src/plugins/intel_cpu/src/plugin.cpp | 3 +- .../npuw/partitioning/patterns/dcoff.cpp | 13 +++- .../npuw/partitioning/patterns/dcoff.hpp | 2 +- .../intel_npu/src/plugin/npuw/util.cpp | 62 +++++++++++++++++++ 4 files changed, 76 insertions(+), 4 deletions(-) diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp index 5c88772eeedabc..b74d4f7c8acbbb 100644 --- a/src/plugins/intel_cpu/src/plugin.cpp +++ b/src/plugins/intel_cpu/src/plugin.cpp @@ -229,7 +229,8 @@ std::shared_ptr Plugin::compile_model(const std::shared_ptr< ov::element::Type_t::f32, ov::element::Type_t::f64, ov::element::Type_t::boolean, - ov::element::Type_t::string}; + ov::element::Type_t::string, + ov::element::Type_t::nf4}; if (!supported_precisions.count(input_precision)) { OPENVINO_THROW_NOT_IMPLEMENTED("CPU plugin: Input image format ", diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp index 60f705a0c8f26c..f464f216eadb67 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp @@ -16,6 +16,7 @@ #include "openvino/op/subtract.hpp" #include "openvino/op/util/op_types.hpp" #include "openvino/pass/pattern/op/label.hpp" // any_input +#include "openvino/pass/pattern/op/optional.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/util/common_util.hpp" @@ -248,7 +249,7 @@ bool DCOFFPassBase::matcher_callback(ov::pass::pattern::Matcher& m) { auto matched_paramA = std::static_pointer_cast(matched_nodeA); auto element_type = matched_paramA->get_element_type(); - if (element_type == ov::element::i4 || element_type == ov::element::i8) { + if (element_type == ov::element::i4 || element_type == ov::element::i8 || element_type == ov::element::nf4) { LOG_DEBUG("Matched: " << matched_paramA << ", set element type to " << m_dcoff_type); matched_paramA->set_element_type(m_dcoff_type); @@ -296,7 +297,8 @@ bool DCOFFPassBase::matcher_callback(ov::pass::pattern::Matcher& m) { void DCOFFPassMatMul::build() { DCOFFPassBase::build(); auto _mmin1 = opp::any_input(); - matmul = opp::wrap_type({_mmin1, mulply}); + cvtopt = opp::optional({mulply->output(0)}); + matmul = opp::wrap_type({_mmin1, cvtopt}); register_matcher(std::make_shared(matmul, "TagDCOFFMatMul"), std::bind(&DCOFFPassMatMul::matcher_callback, this, std::placeholders::_1)); } @@ -306,6 +308,13 @@ void DCOFFPassMatMul::reconnect_root_to_convert(ov::pass::pattern::Matcher& m) { auto& node_to_output = m.get_pattern_value_map(); auto matched_convrt = node_to_output.at(toFP32).get_node_shared_ptr(); auto matched_matmul = node_to_output.at(matmul).get_node_shared_ptr(); + + auto cvt = std::static_pointer_cast(matched_convrt); + auto matmul = std::static_pointer_cast(matched_matmul); + + // NB: In case convert and matmul types don't match + cvt->set_destination_type(matmul->inputs()[1].get_element_type()); + matched_matmul->input(1).replace_source_output(matched_convrt); } diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.hpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.hpp index 55ec9ccd58835c..da06a5304c8bd7 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.hpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.hpp @@ -59,7 +59,7 @@ class DCOFFPassBase : public ov::pass::MatcherPass { ov::element::Type m_dcoff_type; DCOFFParamRef m_params_to; - std::shared_ptr paramA, paramB, toFP32, mulply; + std::shared_ptr paramA, paramB, toFP32, mulply, cvtopt; bool matcher_callback(ov::pass::pattern::Matcher& m); public: diff --git a/src/plugins/intel_npu/src/plugin/npuw/util.cpp b/src/plugins/intel_npu/src/plugin/npuw/util.cpp index 99a53430295a89..a878b244bc41e9 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/util.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/util.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include "logging.hpp" @@ -50,6 +51,59 @@ inline uint8_t hi4(uint8_t x) { inline uint8_t lo4(uint8_t x) { return x & 0xF; } + +void unpack_nf4f16(const ov::SoPtr& from, + const ov::SoPtr& scale, + const ov::SoPtr& to, + const ov::npuw::util::UnpackOptions& unpack_options) { + auto from_shape = from->get_shape(); + auto scale_shape = scale->get_shape(); + + NPUW_ASSERT(from->is_continuous()); + NPUW_ASSERT(to->is_continuous()); + NPUW_ASSERT(scale->is_continuous()); + NPUW_ASSERT(from->get_size() == to->get_size()); + NPUW_ASSERT(from_shape[0] == scale_shape[0]); + + const auto* from_ptr = static_cast(from->data()); + const auto* scale_ptr = scale->data(); + auto* to_ptr = to->data(); + + const auto size = from->get_size(); + ov::parallel_for(size / 2, [&](size_t idx) { + const uint8_t nf4_2xval = from_ptr[idx]; + const float low_scale = scale_ptr[(idx * 2) / from_shape[1]]; + const float high_scale = scale_ptr[(idx * 2 + 1) / from_shape[1]]; + to_ptr[idx * 2] = ov::ConvertNF4::dequantize(lo4(nf4_2xval)) * low_scale; + to_ptr[idx * 2 + 1] = ov::ConvertNF4::dequantize(hi4(nf4_2xval)) * high_scale; + }); + if (size % 2 != 0) { + const float low_scale = scale_ptr[size - 1 / from_shape[1]]; + to_ptr[size - 1] = ov::ConvertNF4::dequantize(lo4(from_ptr[size / 2 + 1])) * low_scale; + } +} + +void unpack_nf4f16(const ov::SoPtr& from, + const ov::SoPtr& to, + const ov::npuw::util::UnpackOptions& unpack_options) { + NPUW_ASSERT(from->is_continuous()); + NPUW_ASSERT(to->is_continuous()); + NPUW_ASSERT(from->get_size() == to->get_size()); + + const auto* from_ptr = static_cast(from->data()); + auto* to_ptr = to->data(); + + const auto size = from->get_size(); + ov::parallel_for(size / 2, [&](size_t idx) { + const uint8_t nf4_2xval = from_ptr[idx]; + to_ptr[idx * 2] = ov::ConvertNF4::dequantize(lo4(nf4_2xval)); + to_ptr[idx * 2 + 1] = ov::ConvertNF4::dequantize(hi4(nf4_2xval)); + }); + if (size % 2 != 0) { + to_ptr[size - 1] = ov::ConvertNF4::dequantize(lo4(from_ptr[size / 2 + 1])); + } +} + } // namespace ov::Tensor ov::npuw::util::tensor_from_const(const std::shared_ptr& node) { @@ -81,6 +135,12 @@ void ov::npuw::util::unpack(const ov::SoPtr& from, auto type_from = from->get_element_type(); auto type_to = to->get_element_type(); + // FIXME: Move under common switch when XARCH::unpack is implemented + if (type_from == ov::element::nf4 && type_to == ov::element::f16) { + unpack_nf4f16(from, to, unpack_options); + return; + } + namespace ove = ov::element; #define CAST(x) static_cast((x).operator ove::Type_t()) #define PAIR(f, t) (CAST(f) << 16 | CAST(t)) @@ -128,6 +188,8 @@ void ov::npuw::util::unpack(const ov::SoPtr& from, } } else if (type_from == ov::element::i8) { ov::npuw::util::XARCH::unpack_i8f16_scale(from, scale, to, unpack_options); + } else if (type_from == ov::element::nf4) { + unpack_nf4f16(from, scale, to, unpack_options); } else { NPUW_ASSERT(false && "Unsupported combination"); }