[core] Convert limit NF4 conversion to FP16 -> NF4 (#23806)
### Details:
- Limit NF4 conversion to FP16 -> NF4 in the Convert operator, as these values are correctly quantized; NF4 conversion to/from other types is not supported (see the Python sketch after the ticket list).

### Tickets:
 - [CVS-135304](https://jira.devtools.intel.com/browse/CVS-135304)
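
For illustration, a minimal Python sketch of the supported decompression path, adapted from the updated unit test further below; the `opset13` import path, the `f32` target type, and the result-indexing call are assumptions, not part of this change:

```python
# Minimal sketch (assumptions noted above): quantize f32 data to NF4 and
# decompress it back to f32, the direction this commit keeps supported.
import numpy as np
import openvino as ov
from openvino.runtime import opset13 as opset  # opset version is an assumption

data = np.linspace(-1.5, 1.5, num=41, dtype=np.float32)

# Creating an NF4 constant quantizes the values to the NF4 levels.
compressed_const = opset.constant(data, dtype=ov.Type.nf4, name="nf4_constant")
# Per this change, an NF4 constant may only be converted to f16, f32, or nf4.
decompressed = opset.convert(compressed_const, ov.Type.f32)

parameter = opset.parameter(ov.PartialShape([-1]), ov.Type.f32)
model = ov.Model([opset.add(parameter, decompressed)], [parameter])

compiled = ov.compile_model(model)
# Adding zeros returns the dequantized approximation of the original data.
result = compiled(np.zeros(data.shape, dtype=np.float32))[0]
print(result)
```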
praasz authored Apr 17, 2024
1 parent fd59da5 commit d199994
Showing 6 changed files with 163 additions and 95 deletions.
@@ -18,11 +18,14 @@
def test_float_to_nf4_convert(ov_type, numpy_dtype):
data = np.linspace(-1.5, 1.5, num=41, dtype=numpy_dtype)

# Compress data to NF4
compressed_const = opset.constant(data, dtype=ov.Type.nf4, name="nf4_constant")
convert = opset.convert(compressed_const, data.dtype)
# Get the decompressed data as the tested OV type
decompressed = opset.convert(compressed_const, ov_type)

parameter = opset.parameter(ov.PartialShape([-1]), ov_type)
add_op = opset.add(parameter, convert)
model = ov.Model([add_op], [parameter])
output = opset.add(parameter, decompressed)
model = ov.Model([output], [parameter])

compiled = ov.compile_model(model)
tensor = np.zeros(data.shape, dtype=numpy_dtype)
59 changes: 49 additions & 10 deletions src/core/src/op/convert.cpp
@@ -10,6 +10,7 @@
#include "openvino/op/equal.hpp"
#include "openvino/op/select.hpp"
#include "openvino/reference/convert.hpp"
#include "openvino/reference/utils/type_util.hpp"

namespace ov {
namespace op {
@@ -18,10 +19,31 @@ namespace convert {
#define CONVERT_ET_LIST \
boolean, bf16, f16, f32, f64, i4, i8, i16, i32, i64, u1, u2, u3, u4, u6, u8, u16, u32, u64, nf4, f8e4m3, f8e5m2

#define CONVERT_TO_ANY_NO_NF4 \
boolean, bf16, f16, f32, f64, i4, i8, i16, i32, i64, u1, u2, u3, u4, u6, u8, u16, u32, u64, f8e4m3, f8e5m2

struct Evaluate : public element::NoAction<bool> {
using element::NoAction<bool>::visit;

template <element::Type_t ET_IN, class TI = fundamental_type_for<ET_IN>>
// convert from any (except F16, NF4) to any except NF4
template <element::Type_t ET_IN,
class TI = fundamental_type_for<ET_IN>,
typename std::enable_if<ET_IN != element::f16 && ET_IN != element::nf4>::type* = nullptr>
static result_type visit(const Tensor& arg, Tensor& out, const size_t count) {
using namespace ov::element;
return IF_TYPE_OF(Convert_out,
CONVERT_TO_ANY_NO_NF4,
EvalByOutputType,
out.get_element_type(),
iterator<ET_IN>(reinterpret_cast<const TI*>(arg.data())),
out,
count);
}

// convert from F16 to any
template <element::Type_t ET_IN,
class TI = fundamental_type_for<ET_IN>,
typename std::enable_if<ET_IN == element::f16>::type* = nullptr>
static result_type visit(const Tensor& arg, Tensor& out, const size_t count) {
using namespace ov::element;
return IF_TYPE_OF(Convert_out,
@@ -33,6 +55,21 @@ struct Evaluate : public element::NoAction<bool> {
count);
}

// convert from NF4
template <element::Type_t ET_IN,
class TI = fundamental_type_for<ET_IN>,
typename std::enable_if<ET_IN == element::nf4>::type* = nullptr>
static result_type visit(const Tensor& arg, Tensor& out, const size_t count) {
using namespace ov::element;
return IF_TYPE_OF(Convert_out,
OV_PP_ET_LIST(f16, f32, nf4),
EvalByOutputType,
out.get_element_type(),
iterator<ET_IN>(reinterpret_cast<const TI*>(arg.data())),
out,
count);
}

private:
struct EvalByOutputType : public element::NoAction<bool> {
using element::NoAction<bool>::visit;
@@ -118,16 +155,12 @@ bool Convert::evaluate(TensorVector& outputs, const TensorVector& inputs) const
if (auto& out = outputs[0]) {
const auto& in = inputs[0];
const auto& in_shape = in.get_shape();
const auto count = shape_size(in_shape);

out.set_shape(in_shape);

using namespace ov::element;
return IF_TYPE_OF(v0_Convert_in_et,
CONVERT_ET_LIST,
convert::Evaluate,
in.get_element_type(),
in,
out,
shape_size(in_shape));
return IF_TYPE_OF(v0_Convert_in_et, CONVERT_ET_LIST, convert::Evaluate, in.get_element_type(), in, out, count);
} else {
return false;
}
@@ -136,6 +169,10 @@ bool Convert::evaluate(TensorVector& outputs, const TensorVector& inputs) const
bool Convert::has_evaluate() const {
OV_OP_SCOPE(v0_Convert_has_evaluate);

const auto is_to_nf4_supported = [](const element::Type& from, const element::Type& to) {
return (from == element::nf4) && (to == element::f16 || to == element::f32 || to == element::nf4);
};

const auto is_valid_type = [](const element::Type& et) -> bool {
switch (et) {
case element::boolean:
@@ -154,7 +191,6 @@ bool Convert::has_evaluate() const {
case element::u16:
case element::u32:
case element::u64:
case element::nf4:
case element::f8e4m3:
case element::f8e5m2:
return true;
@@ -163,7 +199,10 @@ bool Convert::has_evaluate() const {
};
};

return is_valid_type(get_input_element_type(0)) && is_valid_type(get_output_element_type(0));
const auto& input_et = get_input_element_type(0);
const auto& output_et = get_output_element_type(0);

return (is_valid_type(input_et) && is_valid_type(output_et)) || is_to_nf4_supported(input_et, output_et);
}

bool Convert::evaluate_lower(TensorVector& output_values) const {
@@ -190,7 +190,8 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
if (decompression_subtract_type != DecompressionSubtractType::empty) {
auto subtract_shape = decompression_subtract_type == DecompressionSubtractType::full ? scaleshift_const_shape : Shape({});
auto shift_const = ov::test::utils::deprecated::make_constant<uint8_t>(weights_precision, subtract_shape, {}, true, up_to);
std::shared_ptr<ov::Node> shift_convert = std::make_shared<ov::op::v0::Convert>(shift_const, decompression_precision);
std::shared_ptr<ov::Node> shift_convert =
std::make_shared<ov::op::v0::Convert>(shift_const, decompression_precision);
if (reshape_on_decompression_constant) {
auto subtract_target_shape = decompression_subtract_type == DecompressionSubtractType::full
? scaleshift_target_shape : ov::Shape(scaleshift_const_shape.size(), 1);
@@ -343,7 +344,7 @@ const std::vector<ov::test::ElementType> decompression_precisions = {ov::element
const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8,
ov::element::u4,
ov::element::i4,
ov::element::nf4};
element::nf4};

const std::vector<ShapeParams> input_shapes_basic = {
{{{-1, -1, -1}, {{1, 4, 16}, {10, 16, 16}}}, {16, 32}},
159 changes: 81 additions & 78 deletions src/plugins/template/backend/ops/convert.cpp
@@ -7,121 +7,124 @@
#include "evaluate_node.hpp"
#include "openvino/core/type/element_iterator.hpp"

namespace convert_like_v1 {
namespace convert {
template <ov::element::Type_t ti, ov::element::Type_t to>
inline void evaluate(const std::shared_ptr<ov::op::v1::ConvertLike>& op,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
using T_I = typename ov::element_type_traits<ti>::value_type;
using T_O = typename ov::element_type_traits<to>::value_type;
bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) {
outputs[0].set_shape(inputs[0].get_shape());
size_t element_count = ov::shape_size(outputs[0].get_shape());
const auto element_count = ov::shape_size(outputs[0].get_shape());

ov::reference::convert(ov::element::iterator<ti>(inputs[0].data<const T_I>()),
ov::element::iterator<to>(outputs[0].data<T_O>()),
ov::reference::convert(ov::element::iterator<ti>(static_cast<const void*>(inputs[0].data())),
ov::element::iterator<to>(outputs[0].data()),
element_count);
return true;
}
} // namespace convert_like_v1

template <ov::element::Type_t OUT_ET>
bool evaluate(const std::shared_ptr<ov::op::v1::ConvertLike>& op,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
bool evaluate_by_input_type(ov::TensorVector& outputs, const ov::TensorVector& inputs) {
switch (inputs[0].get_element_type()) {
case ov::element::boolean:
convert_like_v1::evaluate<ov::element::boolean, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::boolean, OUT_ET>(outputs, inputs);
case ov::element::u1:
convert_like_v1::evaluate<ov::element::u1, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::u1, OUT_ET>(outputs, inputs);
case ov::element::u4:
convert_like_v1::evaluate<ov::element::u4, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::u4, OUT_ET>(outputs, inputs);
case ov::element::u8:
convert_like_v1::evaluate<ov::element::u8, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::u8, OUT_ET>(outputs, inputs);
case ov::element::u16:
convert_like_v1::evaluate<ov::element::u16, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::u16, OUT_ET>(outputs, inputs);
case ov::element::u32:
convert_like_v1::evaluate<ov::element::u32, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::u32, OUT_ET>(outputs, inputs);
case ov::element::u64:
convert_like_v1::evaluate<ov::element::u64, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::u64, OUT_ET>(outputs, inputs);
case ov::element::i4:
convert_like_v1::evaluate<ov::element::i4, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::i4, OUT_ET>(outputs, inputs);
case ov::element::i8:
convert_like_v1::evaluate<ov::element::i8, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::i8, OUT_ET>(outputs, inputs);
case ov::element::i16:
convert_like_v1::evaluate<ov::element::i16, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::i16, OUT_ET>(outputs, inputs);
case ov::element::i32:
convert_like_v1::evaluate<ov::element::i32, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::i32, OUT_ET>(outputs, inputs);
case ov::element::i64:
convert_like_v1::evaluate<ov::element::i64, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::i64, OUT_ET>(outputs, inputs);
case ov::element::bf16:
convert_like_v1::evaluate<ov::element::bf16, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::bf16, OUT_ET>(outputs, inputs);
case ov::element::f16:
convert_like_v1::evaluate<ov::element::f16, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::f16, OUT_ET>(outputs, inputs);
case ov::element::f32:
convert_like_v1::evaluate<ov::element::f32, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::f32, OUT_ET>(outputs, inputs);
case ov::element::f64:
convert_like_v1::evaluate<ov::element::f64, OUT_ET>(op, outputs, inputs);
break;
return evaluate<ov::element::f64, OUT_ET>(outputs, inputs);
case ov::element::nf4:
return evaluate<ov::element::nf4, OUT_ET>(outputs, inputs);
default:
return false;
}
return true;
}

template <>
bool evaluate_node<ov::op::v1::ConvertLike>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
const auto& element_type = node->get_output_element_type(0);

switch (element_type) {
namespace {
bool evaluate_by_output_type(const ov::element::Type& output_et,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
switch (output_et) {
case ov::element::boolean:
return evaluate<ov::element::boolean>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
case ov::element::bf16:
return evaluate<ov::element::bf16>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
case ov::element::f16:
return evaluate<ov::element::f16>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
case ov::element::f64:
return evaluate<ov::element::f64>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
case ov::element::f32:
return evaluate<ov::element::f32>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
case ov::element::i4:
return evaluate<ov::element::i4>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
case ov::element::i8:
return evaluate<ov::element::i8>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
case ov::element::i16:
return evaluate<ov::element::i16>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
case ov::element::i32:
return evaluate<ov::element::i32>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
case ov::element::i64:
return evaluate<ov::element::i64>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
return evaluate_by_input_type<ov::element::boolean>(outputs, inputs);
case ov::element::u1:
return evaluate<ov::element::u1>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
return evaluate_by_input_type<ov::element::u1>(outputs, inputs);
case ov::element::u4:
return evaluate<ov::element::u4>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
return evaluate_by_input_type<ov::element::u4>(outputs, inputs);
case ov::element::u8:
return evaluate<ov::element::u8>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
return evaluate_by_input_type<ov::element::u8>(outputs, inputs);
case ov::element::u16:
return evaluate<ov::element::u16>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
return evaluate_by_input_type<ov::element::u16>(outputs, inputs);
case ov::element::u32:
return evaluate<ov::element::u32>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
return evaluate_by_input_type<ov::element::u32>(outputs, inputs);
case ov::element::u64:
return evaluate<ov::element::u64>(ov::as_type_ptr<ov::op::v1::ConvertLike>(node), outputs, inputs);
return evaluate_by_input_type<ov::element::u64>(outputs, inputs);
case ov::element::i4:
return evaluate_by_input_type<ov::element::i4>(outputs, inputs);
case ov::element::i8:
return evaluate_by_input_type<ov::element::i8>(outputs, inputs);
case ov::element::i16:
return evaluate_by_input_type<ov::element::i16>(outputs, inputs);
case ov::element::i32:
return evaluate_by_input_type<ov::element::i32>(outputs, inputs);
case ov::element::i64:
return evaluate_by_input_type<ov::element::i64>(outputs, inputs);
case ov::element::bf16:
return evaluate_by_input_type<ov::element::bf16>(outputs, inputs);
case ov::element::f16:
return evaluate_by_input_type<ov::element::f16>(outputs, inputs);
case ov::element::f32:
return evaluate_by_input_type<ov::element::f32>(outputs, inputs);
case ov::element::f64:
return evaluate_by_input_type<ov::element::f64>(outputs, inputs);
case ov::element::nf4:
return evaluate_by_input_type<ov::element::nf4>(outputs, inputs);
default:
return false;
}
}
} // namespace
} // namespace convert

template <>
bool evaluate_node<ov::op::v1::ConvertLike>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
if (convert::evaluate_by_output_type(node->get_output_element_type(0), outputs, inputs)) {
return true;
} else {
OPENVINO_THROW("Unhandled data type ", node->get_element_type().get_type_name(), " in evaluate_node()");
}
}

template <>
bool evaluate_node<ov::op::v0::Convert>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
if (convert::evaluate_by_output_type(node->get_output_element_type(0), outputs, inputs)) {
return true;
} else {
OPENVINO_THROW("Unhandled data type ", node->get_element_type().get_type_name(), " in evaluate_node()");
}
}
4 changes: 4 additions & 0 deletions src/plugins/template/backend/ops/ops_evaluates.hpp
@@ -19,6 +19,10 @@ extern template bool evaluate_node<ov::op::v0::Ceiling>(std::shared_ptr<ov::Node
ov::TensorVector& outputs,
const ov::TensorVector& inputs);

extern template bool evaluate_node<ov::op::v0::Convert>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);

extern template bool evaluate_node<ov::op::v0::CTCGreedyDecoder>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);