Commit 9f13e4e

Merge branch 'master' into yi3/sdpa_group_quant

zhangYiIntel committed Nov 22, 2024
2 parents a4798d9 + aca1bb4
Showing 39 changed files with 996 additions and 525 deletions.
@@ -63,6 +63,13 @@ The transformation function is a function that takes a sample from the dataset a
:language: python
:fragment: [dataset]

   .. tab-item:: TorchFX
      :sync: torch_fx

      .. doxygensnippet:: docs/optimization_guide/nncf/ptq/code/ptq_torch_fx.py
         :language: python
         :fragment: [dataset]

If there is no framework dataset object, you can create your own entity that implements the ``Iterable`` interface in Python (for example, a list of images) and returns data samples ready for inference. In this case, a transformation function is not required.

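A possible sketch of such a custom dataset (an illustrative assumption, not part of this commit): a plain Python list of pre-shaped arrays already satisfies the ``Iterable`` contract, so it can be passed to ``nncf.Dataset`` directly, with no transformation function.

.. code-block:: python

   import numpy as np
   import nncf

   # hypothetical in-memory calibration samples, already shaped for the model input
   images = [np.random.rand(1, 3, 224, 224).astype(np.float32) for _ in range(300)]
   calibration_dataset = nncf.Dataset(images)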

@@ -102,6 +109,12 @@ See the `example section <#examples-of-how-to-apply-nncf-post-training-quantizat
:language: python
:fragment: [quantization]

   .. tab-item:: TorchFX
      :sync: torch_fx

      .. doxygensnippet:: docs/optimization_guide/nncf/ptq/code/ptq_torch_fx.py
         :language: python
         :fragment: [quantization]

After that, the model can be converted into the OpenVINO Intermediate Representation (IR) if needed, then compiled and run with OpenVINO.
If you have not already installed the OpenVINO developer tools, install them with ``pip install openvino``.
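A minimal sketch of that flow (an illustration, not part of this commit; it assumes the ``quantized_model`` produced above and that ``ov.convert_model`` accepts the ``torch.export`` program, as in recent OpenVINO releases):

.. code-block:: python

   import openvino as ov
   import torch

   example_input = torch.ones((1, 3, 224, 224))
   exported = torch.export.export(quantized_model, args=(example_input,))
   ov_model = ov.convert_model(exported)               # in-memory OpenVINO IR
   compiled_model = ov.compile_model(ov_model, "CPU")  # compile for a target device
   result = compiled_model(example_input.numpy())      # run inference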
Expand Down Expand Up @@ -136,6 +149,17 @@ If you have not already installed OpenVINO developer tools, install it with ``pi
:language: python
:fragment: [inference]

TorchFX models can take advantage of OpenVINO optimizations through the `torch.compile(..., backend="openvino") <https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html>`__ functionality:

.. tab-set::

   .. tab-item:: TorchFX
      :sync: torch_fx

      .. doxygensnippet:: docs/optimization_guide/nncf/ptq/code/ptq_torch_fx.py
         :language: python
         :fragment: [inference]

Tune quantization parameters
############################

44 changes: 44 additions & 0 deletions docs/optimization_guide/nncf/ptq/code/ptq_torch_fx.py
@@ -0,0 +1,44 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

#! [dataset]
import nncf
import torch

calibration_loader = torch.utils.data.DataLoader(...)

def transform_fn(data_item):
    images, _ = data_item
    return images

calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)
#! [dataset]

#! [quantization]
import torchvision
from nncf.torch import disable_patching

input_fp32 = torch.ones((1, 3, 224, 224)) # FP32 model input
model = torchvision.models.resnet50(pretrained=True)

with disable_patching():
    exported_model = torch.export.export_for_training(model, args=(input_fp32,)).module()
    quantized_model = nncf.quantize(exported_model, calibration_dataset)
#! [quantization]

#! [inference]
import openvino.torch

input_fp32 = ... # FP32 model input

# compile quantized model using torch.compile API
with disable_patching():
    compiled_model_int8 = torch.compile(quantized_model, backend="openvino")
    # OpenVINO backend compiles the model during the first call,
    # so the first call is expected to be slower than the following calls
    res = compiled_model_int8(input_fp32)

    # save the model
    exported_program = torch.export.export(quantized_model, args=(input_fp32,))
    torch.export.save(exported_program, 'exported_program.pt2')
#! [inference]
8 changes: 0 additions & 8 deletions src/common/snippets/tests/src/pass/mha_tokenization.cpp
@@ -169,10 +169,6 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM) {
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
#if defined(WIN32) && !defined(NDEBUG)
test_skipped = true;
GTEST_SKIP() << "Skipping on Windows in Debug mode due to Issue 155258.";
#endif
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{2, 64, 12, 64}, {128, 12, 1, 64}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}},
@@ -195,10 +191,6 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) {
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) {
#if defined(WIN32) && !defined(NDEBUG)
test_skipped = true;
GTEST_SKIP() << "Skipping on Windows in Debug mode due to Issue 155258.";
#endif
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{1, 6, 64, 16, 64}, {1, 384, 16, 1, 64}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
4 changes: 4 additions & 0 deletions src/common/transformations/include/ov_ops/rms.hpp
@@ -43,6 +43,10 @@ class TRANSFORMATIONS_API RMS : public ov::op::Op {
m_epsilon = epsilon;
}

void set_output_type_attr(const element::Type& output_type) {
m_output_type = output_type;
}

private:
double m_epsilon{0};
ov::element::Type m_output_type;
@@ -5,22 +5,19 @@
#pragma once

#include "openvino/op/op.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace intel_gpu {
namespace op {
namespace internal {

/// \brief Operator performing Swish Gated Linear Unit Activation
/// This operation performs gated linear unit activation that combines swish or gelu activation function
class SwiGLU : public ov::op::Op {
class TRANSFORMATIONS_API SwiGLU : public ov::op::Op {
public:
OPENVINO_OP("SwiGLU", "gpu_opset");
OPENVINO_OP("SwiGLU", "ie_internal_opset");

enum GluType {
Swish = 0,
Gelu,
Gelu_Tanh
};
enum GluType { Swish = 0, Gelu, Gelu_Tanh };

SwiGLU() = default;
/// \brief Constructs an SwiGLU operation.
@@ -44,26 +41,44 @@ class SwiGLU : public ov::op::Op {

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

int64_t get_axis() const { return m_axis; }
int64_t get_split_lengths() const { return m_split_lengths; }
GluType get_glu_type() const { return m_glu_type; }
size_t get_split_to_glu_idx() const { return m_split_to_glu_idx; }
int64_t get_axis() const {
return m_axis;
}
int64_t get_split_lengths() const {
return m_split_lengths;
}
GluType get_glu_type() const {
return m_glu_type;
}
size_t get_split_to_glu_idx() const {
return m_split_to_glu_idx;
}

void set_axis(int64_t axis) { m_axis = axis; }
void set_split_lengths(int64_t split_lengths) { m_split_lengths = split_lengths; }
void set_glu_type(GluType glu_type) { m_glu_type = glu_type; }
void set_split_to_glu_idx(size_t split_to_glu_idx) { m_split_to_glu_idx = split_to_glu_idx; }
void set_axis(int64_t axis) {
m_axis = axis;
}
void set_split_lengths(int64_t split_lengths) {
m_split_lengths = split_lengths;
}
void set_glu_type(GluType glu_type) {
m_glu_type = glu_type;
}
void set_split_to_glu_idx(size_t split_to_glu_idx) {
m_split_to_glu_idx = split_to_glu_idx;
}

private:
int64_t m_axis = 0;
int64_t m_split_lengths = 0;
GluType m_glu_type = GluType::Swish;
size_t m_split_to_glu_idx = 0;
ov::element::Type m_output_type;
ov::element::Type m_output_type{};
};

std::vector<ov::PartialShape> shape_infer(const SwiGLU* op, std::vector<ov::PartialShape> input_shapes);
// TODO 157615: Move to shape_inference
TRANSFORMATIONS_API std::vector<ov::PartialShape> shape_infer(const SwiGLU* op,
std::vector<ov::PartialShape> input_shapes);

} // namespace op
} // namespace intel_gpu
} // namespace ov
} // namespace internal
} // namespace op
} // namespace ov
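
For reference, a rough functional sketch of what the SwiGLU op computes (an assumption inferred from the attributes above, not code from this commit): the input is split along m_axis into two chunks, the chunk selected by m_split_to_glu_idx passes through the gate activation, and the result is multiplied elementwise with the other chunk.

import torch
import torch.nn.functional as F

def swiglu_reference(x: torch.Tensor, axis: int = -1, split_to_glu_idx: int = 0) -> torch.Tensor:
    # split into the gated chunk and the linear chunk
    a, b = torch.chunk(x, 2, dim=axis)
    gate, linear = (a, b) if split_to_glu_idx == 0 else (b, a)
    # GluType::Swish variant; the Gelu / Gelu_Tanh variants would use F.gelu(gate)
    return F.silu(gate) * linear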
@@ -0,0 +1,21 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/manager.hpp"
#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API SwiGLUFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("SwiGLUFusion", "0");
SwiGLUFusion();
};

} // namespace pass
} // namespace ov
@@ -2,24 +2,29 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "intel_gpu/op/swiglu.hpp"
#include "ov_ops/swiglu.hpp"

#include "openvino/core/partial_shape.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/variadic_split.hpp"
#include "variadic_split_shape_inference.hpp"

namespace ov {
namespace intel_gpu {
namespace op {
namespace internal {

SwiGLU::SwiGLU(const Output<Node>& data,
int64_t axis,
int64_t split_lengths,
const GluType glu_type,
const size_t split_to_glu_idx,
const ov::element::Type output_type)
: Op({data}), m_axis(axis), m_split_lengths(split_lengths),
m_glu_type(glu_type), m_split_to_glu_idx(split_to_glu_idx), m_output_type(output_type) {
: Op({data}),
m_axis(axis),
m_split_lengths(split_lengths),
m_glu_type(glu_type),
m_split_to_glu_idx(split_to_glu_idx),
m_output_type(output_type) {
validate_and_infer_types();
}

@@ -33,11 +38,9 @@ bool SwiGLU::visit_attributes(ov::AttributeVisitor& visitor) {
void SwiGLU::validate_and_infer_types() {
auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type;

std::vector<ov::PartialShape> input_shapes = {
get_input_partial_shape(0),
ov::PartialShape(ov::Shape{}),
ov::PartialShape(ov::Shape{2})
};
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0),
ov::PartialShape(ov::Shape{}),
ov::PartialShape(ov::Shape{2})};

set_output_type(0, output_type, shape_infer(this, input_shapes)[0]);
}
@@ -54,16 +57,18 @@ std::shared_ptr<Node> SwiGLU::clone_with_new_inputs(const ov::OutputVector& new_

std::vector<ov::PartialShape> shape_infer(const SwiGLU* op, std::vector<ov::PartialShape> input_shapes) {
ov::op::v1::VariadicSplit variadic_split;
std::vector<int64_t> axis = { op->get_axis() };
std::vector<int64_t> split_lengths = { op->get_split_lengths(), -1 };
std::vector<int64_t> axis = {op->get_axis()};
std::vector<int64_t> split_lengths = {op->get_split_lengths(), -1};

std::unordered_map<size_t, ov::Tensor> const_data;
const_data.emplace(1, ov::Tensor(ov::element::i64, ov::Shape{}, static_cast<void*>(axis.data())));
const_data.emplace(2, ov::Tensor(ov::element::i64, ov::Shape{split_lengths.size()}, static_cast<void*>(split_lengths.data())));
const_data.emplace(
2,
ov::Tensor(ov::element::i64, ov::Shape{split_lengths.size()}, static_cast<void*>(split_lengths.data())));

return ov::op::v1::shape_infer(&variadic_split, input_shapes, ov::make_tensor_accessor(const_data));
}

} // namespace internal
} // namespace op
} // namespace intel_gpu
} // namespace ov
@@ -13,6 +13,8 @@
#include "openvino/op/reduce_mean.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "ov_ops/rms.hpp"
#include "transformations/utils/utils.hpp"
@@ -43,29 +45,38 @@ RMSFusion::RMSFusion(bool force_tail_convert) {

// x^2
auto const_power = wrap_type<ov::op::v0::Constant>(constant_value(2));
auto power = wrap_type<ov::op::v1::Power>({x, const_power});
auto const_power_convert = pattern::optional<ov::op::v0::Convert>(const_power);
auto power = wrap_type<ov::op::v1::Power>({x, const_power_convert});

// ReduceMean(x^2,axes)
auto mean_axes = wrap_type<ov::op::v0::Constant>(constant_value(-1));
auto mean = wrap_type<ov::op::v1::ReduceMean>({power, mean_axes});

// ReduceMean(x^2,axes)+eps
auto eps = wrap_type<ov::op::v0::Constant>();
auto add_eps = wrap_type<ov::op::v1::Add>({mean, eps});
auto eps_convert = pattern::optional<ov::op::v0::Convert>(eps);
auto add_eps = wrap_type<ov::op::v1::Add>({mean, eps_convert});

// Sqrt(ReduceMean(x^2,axes)+eps)
auto sqrt = wrap_type<ov::op::v0::Sqrt>({add_eps});

// 1/Sqrt(ReduceMean(x^2,axes)+eps)
auto const_div = wrap_type<ov::op::v0::Constant>(constant_value(-1));
auto div = wrap_type<ov::op::v1::Power>({sqrt, const_div});
auto const_pow = wrap_type<ov::op::v0::Constant>(constant_value(-1));
auto const_pow_convert = pattern::optional<ov::op::v0::Convert>(const_pow);
auto pow = wrap_type<ov::op::v1::Power>({sqrt, const_pow_convert});

auto const_div = wrap_type<ov::op::v0::Constant>(constant_value(1));
auto const_div_convert = pattern::optional<ov::op::v0::Convert>(const_div);
auto div = wrap_type<ov::op::v1::Divide>({const_div_convert, sqrt});
auto div_or_pow = std::make_shared<pattern::op::Or>(OutputVector{div, pow});

// x * 1/Sqrt(ReduceMean(x^2,axes)+eps)
auto mul1 = wrap_type<ov::op::v1::Multiply>({x, div});
auto mul1 = wrap_type<ov::op::v1::Multiply>({x, div_or_pow});

// x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma
auto gamma = wrap_type<ov::op::v0::Constant>(type_matches(element::f32));
auto mul2 = wrap_type<ov::op::v1::Multiply>({gamma, mul1});
auto gamma = wrap_type<ov::op::v0::Constant>();
auto gamma_convert = pattern::optional<ov::op::v0::Convert>(gamma);
auto mul2 = wrap_type<ov::op::v1::Multiply>({gamma_convert, mul1});

std::shared_ptr<ov::Node> comp = mul2;
if (force_tail_convert) {
@@ -88,7 +99,10 @@ RMSFusion::RMSFusion(bool force_tail_convert) {
return false;
}

const auto& gamma_node = pattern_map.at(gamma).get_node_shared_ptr();
auto gamma_node = pattern_map.at(gamma).get_node_shared_ptr();
if (pattern_map.find(gamma_convert) != pattern_map.end()) {
gamma_node = pattern_map.at(gamma_convert).get_node_shared_ptr();
}

const auto& mean_node = pattern_map.at(mean).get_node_shared_ptr();
const auto& axes = pattern_map.at(mean_axes).get_node_shared_ptr();
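
For reference, the computation this pattern ultimately matches is plain RMS normalization; the newly added Divide branch and the existing Power(sqrt, -1) branch are two graph-level spellings of the same reciprocal. A sketch (not code from this commit):

import numpy as np

def rms_norm_reference(x: np.ndarray, gamma: np.ndarray, eps: float) -> np.ndarray:
    # x * 1/sqrt(ReduceMean(x^2, axes) + eps) * gamma
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * gamma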