From d1999948064d95581a0b11c20265e7757423687b Mon Sep 17 00:00:00 2001 From: Pawel Raasz Date: Wed, 17 Apr 2024 14:35:53 +0200 Subject: [PATCH] [core] Convert limit NF4 conversion to FP16 -> NF4 (#23806) ### Details: - Limit NF4 conversion to FP16 -> NF4 in convert operator as values are correctly quantized. NF4 conversion to/from other types is not supported. ### Tickets: - [CVS-135304](https://jira.devtools.intel.com/browse/CVS-135304) --- .../test_compression_4bit.py | 9 +- src/core/src/op/convert.cpp | 59 +++++-- .../src/matmul_weights_decompression.cpp | 5 +- src/plugins/template/backend/ops/convert.cpp | 159 +++++++++--------- .../template/backend/ops/ops_evaluates.hpp | 4 + .../functional/op_reference/convert_like.cpp | 22 ++- 6 files changed, 163 insertions(+), 95 deletions(-) diff --git a/src/bindings/python/tests/test_transformations/test_compression_4bit.py b/src/bindings/python/tests/test_transformations/test_compression_4bit.py index c72818b0339f7d..4d167f4bd52781 100644 --- a/src/bindings/python/tests/test_transformations/test_compression_4bit.py +++ b/src/bindings/python/tests/test_transformations/test_compression_4bit.py @@ -18,11 +18,14 @@ def test_float_to_nf4_convert(ov_type, numpy_dtype): data = np.linspace(-1.5, 1.5, num=41, dtype=numpy_dtype) + # Compress data to NF4 compressed_const = opset.constant(data, dtype=ov.Type.nf4, name="nf4_constant") - convert = opset.convert(compressed_const, data.dtype) + # get decompressed data as tested OV type + decompressed = opset.convert(compressed_const, ov_type) + parameter = opset.parameter(ov.PartialShape([-1]), ov_type) - add_op = opset.add(parameter, convert) - model = ov.Model([add_op], [parameter]) + output = opset.add(parameter, decompressed) + model = ov.Model([output], [parameter]) compiled = ov.compile_model(model) tensor = np.zeros(data.shape, dtype=numpy_dtype) diff --git a/src/core/src/op/convert.cpp b/src/core/src/op/convert.cpp index 97df7026d818ec..5917ba191f0088 100644 --- 
a/src/core/src/op/convert.cpp +++ b/src/core/src/op/convert.cpp @@ -10,6 +10,7 @@ #include "openvino/op/equal.hpp" #include "openvino/op/select.hpp" #include "openvino/reference/convert.hpp" +#include "openvino/reference/utils/type_util.hpp" namespace ov { namespace op { @@ -18,10 +19,31 @@ namespace convert { #define CONVERT_ET_LIST \ boolean, bf16, f16, f32, f64, i4, i8, i16, i32, i64, u1, u2, u3, u4, u6, u8, u16, u32, u64, nf4, f8e4m3, f8e5m2 +#define CONVERT_TO_ANY_NO_NF4 \ + boolean, bf16, f16, f32, f64, i4, i8, i16, i32, i64, u1, u2, u3, u4, u6, u8, u16, u32, u64, f8e4m3, f8e5m2 + struct Evaluate : public element::NoAction { using element::NoAction::visit; - template > + // convert from any (except F16, NF4) to any except NF4 + template , + typename std::enable_if::type* = nullptr> + static result_type visit(const Tensor& arg, Tensor& out, const size_t count) { + using namespace ov::element; + return IF_TYPE_OF(Convert_out, + CONVERT_TO_ANY_NO_NF4, + EvalByOutputType, + out.get_element_type(), + iterator(reinterpret_cast(arg.data())), + out, + count); + } + + // convert from F16 to any + template , + typename std::enable_if::type* = nullptr> static result_type visit(const Tensor& arg, Tensor& out, const size_t count) { using namespace ov::element; return IF_TYPE_OF(Convert_out, @@ -33,6 +55,21 @@ struct Evaluate : public element::NoAction { count); } + // convert from NF4 + template , + typename std::enable_if::type* = nullptr> + static result_type visit(const Tensor& arg, Tensor& out, const size_t count) { + using namespace ov::element; + return IF_TYPE_OF(Convert_out, + OV_PP_ET_LIST(f16, f32, nf4), + EvalByOutputType, + out.get_element_type(), + iterator(reinterpret_cast(arg.data())), + out, + count); + } + private: struct EvalByOutputType : public element::NoAction { using element::NoAction::visit; @@ -118,16 +155,12 @@ bool Convert::evaluate(TensorVector& outputs, const TensorVector& inputs) const if (auto& out = outputs[0]) { const auto& in = inputs[0]; 
const auto& in_shape = in.get_shape(); + const auto count = shape_size(in_shape); + out.set_shape(in_shape); using namespace ov::element; - return IF_TYPE_OF(v0_Convert_in_et, - CONVERT_ET_LIST, - convert::Evaluate, - in.get_element_type(), - in, - out, - shape_size(in_shape)); + return IF_TYPE_OF(v0_Convert_in_et, CONVERT_ET_LIST, convert::Evaluate, in.get_element_type(), in, out, count); } else { return false; } @@ -136,6 +169,10 @@ bool Convert::evaluate(TensorVector& outputs, const TensorVector& inputs) const bool Convert::has_evaluate() const { OV_OP_SCOPE(v0_Convert_has_evaluate); + const auto is_to_nf4_supported = [](const element::Type& from, const element::Type& to) { + return (from == element::nf4) && (to == element::f16 || to == element::f32 || to == element::nf4); + }; + const auto is_valid_type = [](const element::Type& et) -> bool { switch (et) { case element::boolean: @@ -154,7 +191,6 @@ bool Convert::has_evaluate() const { case element::u16: case element::u32: case element::u64: - case element::nf4: case element::f8e4m3: case element::f8e5m2: return true; @@ -163,7 +199,10 @@ bool Convert::has_evaluate() const { }; }; - return is_valid_type(get_input_element_type(0)) && is_valid_type(get_output_element_type(0)); + const auto& input_et = get_input_element_type(0); + const auto& output_et = get_output_element_type(0); + + return (is_valid_type(input_et) && is_valid_type(output_et)) || is_to_nf4_supported(input_et, output_et); } bool Convert::evaluate_lower(TensorVector& output_values) const { diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/matmul_weights_decompression.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/matmul_weights_decompression.cpp index a42d6465108557..22d44b04f5fd2e 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/matmul_weights_decompression.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/matmul_weights_decompression.cpp @@ 
-190,7 +190,8 @@ class MatmulWeightsDecompression : public testing::WithParamInterface(weights_precision, subtract_shape, {}, true, up_to); - std::shared_ptr shift_convert = std::make_shared(shift_const, decompression_precision); + std::shared_ptr shift_convert = + std::make_shared(shift_const, decompression_precision); if (reshape_on_decompression_constant) { auto subtract_target_shape = decompression_subtract_type == DecompressionSubtractType::full ? scaleshift_target_shape : ov::Shape(scaleshift_const_shape.size(), 1); @@ -343,7 +344,7 @@ const std::vector decompression_precisions = {ov::element const std::vector weights_precisions = {ov::element::u8, ov::element::u4, ov::element::i4, - ov::element::nf4}; + element::nf4}; const std::vector input_shapes_basic = { {{{-1, -1, -1}, {{1, 4, 16}, {10, 16, 16}}}, {16, 32}}, diff --git a/src/plugins/template/backend/ops/convert.cpp b/src/plugins/template/backend/ops/convert.cpp index d0ba5a3f5c8deb..da1a1ec35a9432 100644 --- a/src/plugins/template/backend/ops/convert.cpp +++ b/src/plugins/template/backend/ops/convert.cpp @@ -7,121 +7,124 @@ #include "evaluate_node.hpp" #include "openvino/core/type/element_iterator.hpp" -namespace convert_like_v1 { +namespace convert { template -inline void evaluate(const std::shared_ptr& op, - ov::TensorVector& outputs, - const ov::TensorVector& inputs) { - using T_I = typename ov::element_type_traits::value_type; - using T_O = typename ov::element_type_traits::value_type; +bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) { outputs[0].set_shape(inputs[0].get_shape()); - size_t element_count = ov::shape_size(outputs[0].get_shape()); + const auto element_count = ov::shape_size(outputs[0].get_shape()); - ov::reference::convert(ov::element::iterator(inputs[0].data()), - ov::element::iterator(outputs[0].data()), + ov::reference::convert(ov::element::iterator(static_cast(inputs[0].data())), + ov::element::iterator(outputs[0].data()), element_count); + return true; } -} 
// namespace convert_like_v1 template -bool evaluate(const std::shared_ptr& op, - ov::TensorVector& outputs, - const ov::TensorVector& inputs) { +bool evaluate_by_input_type(ov::TensorVector& outputs, const ov::TensorVector& inputs) { switch (inputs[0].get_element_type()) { case ov::element::boolean: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::u1: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::u4: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::u8: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::u16: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::u32: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::u64: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::i4: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::i8: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::i16: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::i32: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::i64: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::bf16: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::f16: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); case ov::element::f32: - convert_like_v1::evaluate(op, outputs, inputs); 
- break; + return evaluate(outputs, inputs); case ov::element::f64: - convert_like_v1::evaluate(op, outputs, inputs); - break; + return evaluate(outputs, inputs); + case ov::element::nf4: + return evaluate(outputs, inputs); default: return false; } - return true; } -template <> -bool evaluate_node(std::shared_ptr node, - ov::TensorVector& outputs, - const ov::TensorVector& inputs) { - const auto& element_type = node->get_output_element_type(0); - - switch (element_type) { +namespace { +bool evaluate_by_output_type(const ov::element::Type& output_et, + ov::TensorVector& outputs, + const ov::TensorVector& inputs) { + switch (output_et) { case ov::element::boolean: - return evaluate(ov::as_type_ptr(node), outputs, inputs); - case ov::element::bf16: - return evaluate(ov::as_type_ptr(node), outputs, inputs); - case ov::element::f16: - return evaluate(ov::as_type_ptr(node), outputs, inputs); - case ov::element::f64: - return evaluate(ov::as_type_ptr(node), outputs, inputs); - case ov::element::f32: - return evaluate(ov::as_type_ptr(node), outputs, inputs); - case ov::element::i4: - return evaluate(ov::as_type_ptr(node), outputs, inputs); - case ov::element::i8: - return evaluate(ov::as_type_ptr(node), outputs, inputs); - case ov::element::i16: - return evaluate(ov::as_type_ptr(node), outputs, inputs); - case ov::element::i32: - return evaluate(ov::as_type_ptr(node), outputs, inputs); - case ov::element::i64: - return evaluate(ov::as_type_ptr(node), outputs, inputs); + return evaluate_by_input_type(outputs, inputs); case ov::element::u1: - return evaluate(ov::as_type_ptr(node), outputs, inputs); + return evaluate_by_input_type(outputs, inputs); case ov::element::u4: - return evaluate(ov::as_type_ptr(node), outputs, inputs); + return evaluate_by_input_type(outputs, inputs); case ov::element::u8: - return evaluate(ov::as_type_ptr(node), outputs, inputs); + return evaluate_by_input_type(outputs, inputs); case ov::element::u16: - return evaluate(ov::as_type_ptr(node), 
outputs, inputs); + return evaluate_by_input_type(outputs, inputs); case ov::element::u32: - return evaluate(ov::as_type_ptr(node), outputs, inputs); + return evaluate_by_input_type(outputs, inputs); case ov::element::u64: - return evaluate(ov::as_type_ptr(node), outputs, inputs); + return evaluate_by_input_type(outputs, inputs); + case ov::element::i4: + return evaluate_by_input_type(outputs, inputs); + case ov::element::i8: + return evaluate_by_input_type(outputs, inputs); + case ov::element::i16: + return evaluate_by_input_type(outputs, inputs); + case ov::element::i32: + return evaluate_by_input_type(outputs, inputs); + case ov::element::i64: + return evaluate_by_input_type(outputs, inputs); + case ov::element::bf16: + return evaluate_by_input_type(outputs, inputs); + case ov::element::f16: + return evaluate_by_input_type(outputs, inputs); + case ov::element::f32: + return evaluate_by_input_type(outputs, inputs); + case ov::element::f64: + return evaluate_by_input_type(outputs, inputs); + case ov::element::nf4: + return evaluate_by_input_type(outputs, inputs); default: + return false; + } +} +} // namespace +} // namespace convert + +template <> +bool evaluate_node(std::shared_ptr node, + ov::TensorVector& outputs, + const ov::TensorVector& inputs) { + if (convert::evaluate_by_output_type(node->get_output_element_type(0), outputs, inputs)) { + return true; + } else { + OPENVINO_THROW("Unhandled data type ", node->get_element_type().get_type_name(), " in evaluate_node()"); + } +} + +template <> +bool evaluate_node(std::shared_ptr node, + ov::TensorVector& outputs, + const ov::TensorVector& inputs) { + if (convert::evaluate_by_output_type(node->get_output_element_type(0), outputs, inputs)) { + return true; + } else { OPENVINO_THROW("Unhandled data type ", node->get_element_type().get_type_name(), " in evaluate_node()"); } } diff --git a/src/plugins/template/backend/ops/ops_evaluates.hpp b/src/plugins/template/backend/ops/ops_evaluates.hpp index 
b621e466d04262..b4977967f47e8b 100644 --- a/src/plugins/template/backend/ops/ops_evaluates.hpp +++ b/src/plugins/template/backend/ops/ops_evaluates.hpp @@ -19,6 +19,10 @@ extern template bool evaluate_node(std::shared_ptr(std::shared_ptr node, + ov::TensorVector& outputs, + const ov::TensorVector& inputs); + extern template bool evaluate_node(std::shared_ptr node, ov::TensorVector& outputs, const ov::TensorVector& inputs); diff --git a/src/plugins/template/tests/functional/op_reference/convert_like.cpp b/src/plugins/template/tests/functional/op_reference/convert_like.cpp index 43d50daecafa79..2868aaa2950f79 100644 --- a/src/plugins/template/tests/functional/op_reference/convert_like.cpp +++ b/src/plugins/template/tests/functional/op_reference/convert_like.cpp @@ -144,6 +144,12 @@ INSTANTIATE_TEST_SUITE_P( ov::element::f16, std::vector{0, 10, 15, 20, 43, 56, 78, 99, 102, 130, 142}, std::vector{0, 10, 15, 20, 43, 56, 78, 99, 102, 130, 142}), + ConvertParams(ConversionTypes::CONVERT_LIKE, + ov::PartialShape{4}, + ov::element::nf4, + ov::element::f16, + std::vector{0xE1, 0x1F}, + std::vector{-0.6961928009986877f, 0.7229568362236023f, 1.0f, -0.6961928009986877f}), // destination f32 ConvertParams(ConversionTypes::CONVERT_LIKE, @@ -307,6 +313,12 @@ INSTANTIATE_TEST_SUITE_P( vector{0.5f, 1.5f, 0.5f, 2.5f, 1.5f, 0.5f, 3.5f, 2.5f, 0.5f, 0.5f, 2.5f, 0.5f, 0.5f, 0.5f, 1.5f}, std::vector< float>{0.5f, 1.5f, 0.5f, 2.5f, 1.5f, 0.5f, 3.5f, 2.5f, 0.5f, 0.5f, 2.5f, 0.5f, 0.5f, 0.5f, 1.5f}), + ConvertParams(ConversionTypes::CONVERT_LIKE, + ov::PartialShape{4}, + ov::element::nf4, + ov::element::f32, + std::vector{0xE1, 0x1F}, + std::vector{-0.6961928009986877f, 0.7229568362236023f, 1.0f, -0.6961928009986877f}), // destination i4 ConvertParams(ConversionTypes::CONVERT_LIKE, ov::PartialShape{4}, @@ -1448,9 +1460,15 @@ INSTANTIATE_TEST_SUITE_P( // destination nf4 (use quantization) ConvertParams(ConversionTypes::CONVERT_LIKE, ov::PartialShape{4}, - ov::element::f32, + ov::element::f16, 
+ ov::element::nf4, + std::vector{-0.6961928009986877f, 0.7229568362236023f, 1.0f, -0.5250730514526367f}, + std::vector{0xE1, 0x2F}), + ConvertParams(ConversionTypes::CONVERT_LIKE, + ov::PartialShape{4}, + ov::element::nf4, ov::element::nf4, - std::vector{-0.6961928009986877f, 0.7229568362236023f, 1.0f, -0.5250730514526367f}, + std::vector{0xE1, 0x2f}, std::vector{0xE1, 0x2F}), // destination u2 ConvertParams(ConversionTypes::CONVERT_LIKE,