Skip to content

Commit

Permalink
Merge branch 'feature/identity_cpu' of https://github.com/PiotrKrzem/…
Browse files Browse the repository at this point in the history
…openvino into feature/identity_cpu
  • Loading branch information
PiotrKrzem committed Nov 28, 2024
2 parents 8157b2c + 64ddbf2 commit 9479d1b
Show file tree
Hide file tree
Showing 94 changed files with 2,211 additions and 1,022 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/merge_queue_stub.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
on:
merge_group:

jobs:
merge_group_stub_check:
name: ci/jenkins
runs-on: ubuntu-latest
defaults:
run:
shell: bash
if: ${{ github.event_name == 'merge_group' }}
steps:
- run: echo "Just a stub check to keep Jenkins running in pre-commits but not in merge queue"
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ The transformation function is a function that takes a sample from the dataset a
:language: python
:fragment: [dataset]

.. tab-item:: TorchFX
:sync: torch_fx

.. doxygensnippet:: docs/optimization_guide/nncf/ptq/code/ptq_torch_fx.py
:language: python
:fragment: [dataset]

If there is no framework dataset object, you can create your own entity that implements the ``Iterable`` interface in Python, for example the list of images, and returns data samples feasible for inference. In this case, a transformation function is not required.


Expand Down Expand Up @@ -102,6 +109,12 @@ See the `example section <#examples-of-how-to-apply-nncf-post-training-quantizat
:language: python
:fragment: [quantization]

.. tab-item:: TorchFX
:sync: torch_fx

.. doxygensnippet:: docs/optimization_guide/nncf/ptq/code/ptq_torch_fx.py
:language: python
:fragment: [quantization]

After that the model can be converted into the OpenVINO Intermediate Representation (IR) if needed, compiled and run with OpenVINO.
If you have not already installed OpenVINO developer tools, install it with ``pip install openvino``.
Expand Down Expand Up @@ -136,6 +149,17 @@ If you have not already installed OpenVINO developer tools, install it with ``pi
:language: python
:fragment: [inference]

TorchFX models can utilize OpenVINO optimizations using `torch.compile(..., backend="openvino") <https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html>`__ functionality:

.. tab-set::

.. tab-item:: TorchFX
:sync: torch_fx

.. doxygensnippet:: docs/optimization_guide/nncf/ptq/code/ptq_torch_fx.py
:language: python
:fragment: [inference]

Tune quantization parameters
############################

Expand Down
44 changes: 44 additions & 0 deletions docs/optimization_guide/nncf/ptq/code/ptq_torch_fx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

#! [dataset]
import nncf
import torch

calibration_loader = torch.utils.data.DataLoader(...)

def transform_fn(data_item):
images, _ = data_item
return images

calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)
#! [dataset]

#! [quantization]
import torchvision
from nncf.torch import disable_patching

input_fp32 = torch.ones((1, 3, 224, 224)) # FP32 model input
model = torchvision.models.resnet50(pretrained=True)

with disable_patching():
exported_model = torch.export.export_for_training(model, args=(input_fp32,)).module()
quantized_model = nncf.quantize(exported_model, calibration_dataset)
#! [quantization]

#! [inference]
import openvino.torch

input_fp32 = ... # FP32 model input

# compile quantized model using torch.compile API
with disable_patching():
compiled_model_int8 = torch.compile(quantized_model, backend="openvino")
# OpenVINO backend compiles the model during the first call,
# so the first call is expected to be slower than the following calls
res = compiled_model_int8(input_fp32)

# save the model
exported_program = torch.export.export(quantized_model, args=(input_fp32,))
torch.export.save(exported_program, 'exported_program.pt2')
#! [inference]
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/lowered/pass/runtime_optimizer.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
/**
* @class MHAParallelWAOptimizer
* @brief Optimizes the dynamic MHA execution increasing parallel work amount dy dividing Brgemm's "M" dimension to "parallel_m"
* and "kernel_m". Uses heuristics from snippets::pass::SplitDimensionM for dimension splitting.
* The optimizer performs the following steps:
* - Identifies applicable Brgemm operations within the LinearIR.
* - Finds parameters whose shapes and layouts need to be adjusted after the split.
* - Determines loops that should be adjusted.
*/
class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
public:
MHAParallelWAOptimizer() = default;
MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, const RuntimeConfigurator* configurator);

bool run(const lowered::LinearIR& linear_ir) override;
bool applicable() const override { return !m_loops_to_split.empty(); }

private:
static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir);
static std::unordered_set<size_t> find_unsqueezed_params(
const lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<lowered::ExpressionPtr>& brgemms);
static std::vector<lowered::ExpandedLoopInfoPtr> find_loops_to_split(
const lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<size_t>& unsqueezed_params);

std::vector<lowered::ExpandedLoopInfoPtr> m_loops_to_split{};
std::unordered_set<size_t> m_unsqueezed_params{};
std::vector<std::vector<size_t>> m_optimized_layouts{};
std::vector<size_t> m_dim_M_idces{};
size_t m_concurrency = 0;

static const size_t m_dim_M_idx;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
16 changes: 16 additions & 0 deletions src/common/snippets/include/snippets/lowered/pass/pass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ class Pass : public PassBase {
virtual bool run(lowered::LinearIR& linear_ir) = 0;
};

/**
* @interface ConstPass
* @brief Base class for LIR passes which are performed on a full LIR body but doesn't change it
* @ingroup snippets
*/
class ConstPass : public PassBase {
public:
/**
* @brief Apply the pass to the Linear IR
* @param linear_ir the target Linear IR
* @return status of the pass
*/
virtual bool run(const lowered::LinearIR& linear_ir) = 0;
};

/**
* @interface RangedPass
* @brief Base class for LIR passes which are performed on a range of a LIR body
Expand Down Expand Up @@ -114,6 +129,7 @@ class PassPipeline {
void register_positioned_passes(const std::vector<PositionedPassLowered>& pos_passes);

void run(lowered::LinearIR& linear_ir) const;
void run(const lowered::LinearIR& linear_ir) const;
void run(lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) const;

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/runtime_configurator.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
/**
* @class RuntimeOptimizer
* @brief Base class for runtime optimizers that operate on LinearIR and RuntimeConfigurator during
* RuntimeConfigurator::update stage.
*/
class RuntimeOptimizer : public ConstPass {
public:
RuntimeOptimizer() = default;
RuntimeOptimizer(const RuntimeConfigurator* configurator) : m_configurator(configurator) {
OPENVINO_ASSERT(configurator, "RuntimeConfigurator musn't be nullptr");
}
/**
* @brief Defines if this pass is applicable. If it is not applicable, its registration in pass pipeline can be skipped.
*/
virtual bool applicable() const = 0;

/**
* @brief Creates an instance of the specified pass type and checks if it is applicable.
* If the pass is applicable, it is registered in the provided pipeline.
* @param pipeline The pipeline in which the pass should be registered.
* @param args The arguments to be forwarded to the pass constructor.
*/
template <typename OptimizerType, typename... Args, typename = std::enable_if<std::is_base_of<RuntimeOptimizer, OptimizerType>::value>>
static void register_if_applicable(PassPipeline& pipeline, Args&&... args) {
auto pass = std::make_shared<OptimizerType>(std::forward<Args>(args)...);
if (pass->applicable()) {
pipeline.register_pass(pass);
}
}

protected:
const RuntimeConfigurator* m_configurator = nullptr;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ namespace pass {
* @brief Base class for LinearIR serialization passes
* @ingroup snippets
*/
class SerializeBase : public Pass {
class SerializeBase : public ConstPass {
public:
OPENVINO_RTTI("SerializeBase", "Pass")
OPENVINO_RTTI("SerializeBase", "ConstPass")
SerializeBase(const std::string& xml_path);

protected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,7 @@ class SerializeControlFlow : public SerializeBase {
OPENVINO_RTTI("SerializeControlFlow", "Pass", SerializeBase)
SerializeControlFlow(const std::string& xml_path, bool update_dynamic_ops = false) :
SerializeBase(xml_path), m_update_dynamic_ops{update_dynamic_ops} {}

bool run(LinearIR& linear_ir) override {
return run(const_cast<const LinearIR&>(linear_ir));
}
// We need a const method to run from functions that can't change LIR
bool run(const LinearIR& linear_ir);
bool run(const LinearIR& linear_ir) override;

private:
const bool m_update_dynamic_ops = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@ class SerializeDataFlow : public SerializeBase {
public:
OPENVINO_RTTI("SerializeDataFlow", "Pass", SerializeBase)
SerializeDataFlow(const std::string& xml_path) : SerializeBase(xml_path) {}

bool run(LinearIR& linear_ir) override {
return run(const_cast<const LinearIR&>(linear_ir));
}
// We need a const method to run from functions that can't change LIR
bool run(const LinearIR& linear_ir);
bool run(const LinearIR& linear_ir) override;
};

} // namespace pass
Expand Down
Loading

0 comments on commit 9479d1b

Please sign in to comment.