Skip to content

Commit

Permalink
Rest review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Nov 20, 2024
1 parent 62e44fe commit 68b8b06
Show file tree
Hide file tree
Showing 21 changed files with 123 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ namespace pass {
class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
public:
MHAParallelWAOptimizer() = default;
MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, RuntimeConfigurator* configurator);
MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, const RuntimeConfigurator* configurator);

bool run(const lowered::LinearIR& linear_ir) override;
bool applicable() const override { return !m_loops_to_split.empty(); }

private:
static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,30 @@ namespace pass {
class RuntimeOptimizer : public ConstPass {
public:
RuntimeOptimizer() = default;
RuntimeOptimizer(RuntimeConfigurator* configurator) : m_configurator(configurator) {}
RuntimeOptimizer(const RuntimeConfigurator* configurator) : m_configurator(configurator) {
OPENVINO_ASSERT(configurator, "RuntimeConfigurator musn't be nullptr");
}
/**
* @brief Defines if this pass is applicable. If it is not applicable, its registration in pass pipeline can be skipped.
*/
virtual bool applicable() const = 0;

/**
* @brief Creates an instance of the specified pass type and checks if it is applicable.
* If the pass is applicable, it is registered in the provided pipeline.
* @param pipeline The pipeline in which the pass should be registered.
* @param args The arguments to be forwarded to the pass constructor.
*/
template <typename OptimizerType, typename... Args, typename = std::enable_if<std::is_base_of<RuntimeOptimizer, OptimizerType>::value>>
static void register_if_applicable(PassPipeline& pipeline, Args&&... args) {
auto pass = std::make_shared<OptimizerType>(std::forward<Args>(args)...);
if (pass->applicable()) {
pipeline.register_pass(pass);
}
}

protected:
RuntimeConfigurator* m_configurator = nullptr;
const RuntimeConfigurator* m_configurator = nullptr;
};

} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class RuntimeConfigurator {
* @brief Update tensor rank based on master shape
* @param master_shape Master shape
*/
virtual void update_tensor_rank(const ov::snippets::VectorDims& master_shape);
virtual void update_tensor_rank(const ov::snippets::VectorDims& master_shape) const;

protected:
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using namespace ov::snippets::pass;

const size_t MHAParallelWAOptimizer::m_dim_M_idx = 1;

MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, RuntimeConfigurator* configurator)
MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, const RuntimeConfigurator* configurator)
: lowered::pass::RuntimeOptimizer(configurator) {
if (linear_ir->get_config().m_enable_domain_optimization || !linear_ir->is_dynamic())
return;
Expand Down Expand Up @@ -47,9 +47,6 @@ MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& line

bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::MHAParallelWAOptimizer")
if (m_loops_to_split.empty())
return false;

const auto& config = m_configurator->get_config();
size_t new_batch_dim, new_kernel_dim;
if (!SplitDimensionM::split(config->master_shape, m_concurrency, new_batch_dim, new_kernel_dim))
Expand Down
5 changes: 3 additions & 2 deletions src/common/snippets/src/runtime_configurator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace snippets {

using namespace ov::snippets::pass;
using namespace ov::snippets::lowered;
using namespace ov::snippets::lowered::pass;

#ifdef SNIPPETS_DEBUG_CAPS
std::string RuntimeConfig::to_string() const {
Expand Down Expand Up @@ -65,7 +66,7 @@ void RuntimeConfigurator::initialization(const lowered::LinearIRCPtr& linear_ir)
m_config->tile_rank = linear_ir->get_config().m_loop_depth;

if (linear_ir->is_dynamic())
m_intermediate_optimizers.register_pass<ov::snippets::lowered::pass::MHAParallelWAOptimizer>(linear_ir, this);
RuntimeOptimizer::register_if_applicable<MHAParallelWAOptimizer>(m_intermediate_optimizers, linear_ir, this);
}

void RuntimeConfigurator::update(const lowered::LinearIRCPtr& linear_ir) {
Expand All @@ -86,7 +87,7 @@ void RuntimeConfigurator::update(const lowered::LinearIRCPtr& linear_ir) {
m_config->latest_shapes = std::move(m_config->io_shapes);
}

void RuntimeConfigurator::update_tensor_rank(const ov::snippets::VectorDims& master_shape) {
void RuntimeConfigurator::update_tensor_rank(const ov::snippets::VectorDims& master_shape) const {
m_config->tensor_rank = master_shape.size();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#endif
namespace ov {
namespace intel_cpu {
using namespace ov::snippets::lowered::pass;

const size_t CPURuntimeConfigurator::rank6D = 6;

Expand Down Expand Up @@ -44,8 +45,8 @@ void CPURuntimeConfigurator::initialization(const ov::snippets::lowered::LinearI
RuntimeConfigurator::initialization(linear_ir);
#ifndef OPENVINO_ARCH_ARM64
if (linear_ir->is_dynamic())
m_intermediate_optimizers.register_pass<BrgemmCopyBLoopPortsAdjuster>(linear_ir, this);
m_final_optimizers.register_pass<BrgemmExternalRepackingAdjuster>(linear_ir, this);
RuntimeOptimizer::register_if_applicable<BrgemmCopyBLoopPortsAdjuster>(m_intermediate_optimizers, linear_ir, this);
RuntimeOptimizer::register_if_applicable<BrgemmExternalRepackingAdjuster>(m_final_optimizers, linear_ir, this);
#endif
}

Expand All @@ -72,7 +73,7 @@ void CPURuntimeConfigurator::update(const ov::snippets::lowered::LinearIRCPtr& l
m_config->latest_shapes = std::move(m_config->io_shapes);
}

void CPURuntimeConfigurator::update_tensor_rank(const ov::snippets::VectorDims& master_shape) {
void CPURuntimeConfigurator::update_tensor_rank(const ov::snippets::VectorDims& master_shape) const {
m_config->tensor_rank = std::max(master_shape.size(), rank6D);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class CPURuntimeConfigurator : public ov::snippets::RuntimeConfigurator {
void update_loop_args(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const;
protected:
void update(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override;
void update_tensor_rank(const ov::snippets::VectorDims& master_shape) override;
void update_tensor_rank(const ov::snippets::VectorDims& master_shape) const override;
void init_tensor_rank(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const override;
void initialization(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override;

Expand Down
23 changes: 13 additions & 10 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
#include "transformations/snippets/x64/pass/lowered/insert_brgemm_copy_b_buffers.hpp"
#include "transformations/snippets/x64/pass/remove_converts.hpp"
#include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp"
#include "transformations/snippets/x64/pass/move_brgemm_repacking_out.hpp"
#include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp"
#include "transformations/snippets/x64/pass/enforce_precision.hpp"
#include "transformations/snippets/x64/shape_inference.hpp"
#include "transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.hpp"
Expand Down Expand Up @@ -650,7 +650,7 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before, ov::snippets::pass::PropagatePrecision,
ov::intel_cpu::pass::BrgemmToBrgemmCPU);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::intel_cpu::pass::BrgemmToBrgemmCPU,
ov::intel_cpu::pass::MoveBrgemmRepackingOut);
ov::intel_cpu::pass::EliminateBrgemmCopyB);
SNIPPETS_REGISTER_PASS_ABSOLUTE_X86_64(Place::PipelineEnd, ov::intel_cpu::pass::RemoveConverts);
SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineEnd, ov::intel_cpu::pass::MulAddToFMA);

Expand Down Expand Up @@ -992,26 +992,29 @@ void Subgraph::SubgraphExecutor::parallel_forNd(const std::function<void(jit_sni
});
}

void Subgraph::SubgraphExecutor::execute(dnnl::stream strm, std::vector<MemoryPtr>& inMemPtrs, std::vector<MemoryPtr>& outMemPtrs) {
if (m_in_requested_descs.empty())
void Subgraph::SubgraphExecutor::execute(dnnl::stream strm, const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) {
if (!m_in_requested_descs.empty()) {
auto reorderedInMemPtrs = exec_in_reorders(strm, inMemPtrs);
exec_impl(reorderedInMemPtrs, outMemPtrs);
} else {
exec_impl(inMemPtrs, outMemPtrs);
else
reorder_execute(strm, inMemPtrs, outMemPtrs);
}
}

void Subgraph::SubgraphExecutor::reorder_execute(dnnl::stream strm, std::vector<MemoryPtr> inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) {
std::vector<MemoryPtr> Subgraph::SubgraphExecutor::exec_in_reorders(dnnl::stream strm, const std::vector<MemoryPtr>& inMemPtrs) {
auto reordered_in_ptrs = inMemPtrs;
size_t offset = m_internal_buffer_size;
for (const auto& requested_descs_elem : m_in_requested_descs) {
const auto in_idx = requested_descs_elem.first;
const auto& requested_desc = requested_descs_elem.second;

const void* data_ptr = m_buffer_scratchpad->getDataAs<uint8_t>() + offset;
const auto scratch_mem = std::make_shared<Memory>(strm.get_engine(), requested_desc, data_ptr, false);
scratch_mem->load(*inMemPtrs[in_idx]);
inMemPtrs[in_idx] = scratch_mem;
scratch_mem->load(*reordered_in_ptrs[in_idx]);
reordered_in_ptrs[in_idx] = scratch_mem;
offset += requested_desc->getCurrentMemSize();
}
exec_impl(inMemPtrs, outMemPtrs);
return reordered_in_ptrs;
}

} // namespace node
Expand Down
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/nodes/subgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class Subgraph::SubgraphExecutor {
const BufferScratchpadAllocator& allocator);
virtual ~SubgraphExecutor() = default;

void execute(dnnl::stream strm, std::vector<MemoryPtr>& inMemPtrs, std::vector<MemoryPtr>& outMemPtrs);
void execute(dnnl::stream strm, const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs);

protected:
virtual void exec_impl(const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) = 0;
Expand Down Expand Up @@ -169,7 +169,7 @@ class Subgraph::SubgraphExecutor {
#endif

private:
void reorder_execute(dnnl::stream strm, std::vector<MemoryPtr> inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs);
std::vector<MemoryPtr> exec_in_reorders(dnnl::stream strm, const std::vector<MemoryPtr>& inMemPtrs);

std::unordered_map<size_t, CpuBlockedMemoryDescPtr> m_in_requested_descs = {};
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ ov::snippets::lowered::ExpressionPtr get_copy_b_expr(const ov::snippets::lowered
} else if (ov::is_type<snippets::lowered::BufferExpression>(b_input_expr)) {
OPENVINO_ASSERT(b_input_expr->get_input_count() >= 1, "BufferExpression on brgemm's B input must have at least one input");
const auto input_buffer_expr = b_input_expr->get_input_port_connector(0)->get_source().get_expr();
if (ov::is_type<BrgemmCopyB>(b_input_expr->get_node())) {
if (ov::is_type<BrgemmCopyB>(input_buffer_expr->get_node())) {
return input_buffer_expr;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ enum class BRGEMM_TYPE {
STAND_ALONE, // No extra requirements, used for f32|f32
WITH_AMX, // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad
WITH_COMPENSATIONS, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations
REPACKING_ONLY, // low precision or some specific f32 cases - needs BrgemmCopyB on second input for data repacking
REPACKING_ONLY, // u8|i8, or bf16|bf16 (non-AMX system), or brgemm with transpose_b=true - needs BrgemmCopyB on second input for data repacking
};

dnnl::impl::cpu::x64::cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "move_brgemm_repacking_out.hpp"
#include "eliminate_brgemm_copy_b.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "openvino/pass/pattern/matcher.hpp"
Expand All @@ -15,26 +15,26 @@
namespace ov {
namespace intel_cpu {

pass::MoveBrgemmRepackingOut::MoveBrgemmRepackingOut() {
MATCHER_SCOPE(MoveBrgemmRepackingOut);
pass::EliminateBrgemmCopyB::EliminateBrgemmCopyB() {
MATCHER_SCOPE(EliminateBrgemmCopyB);
auto m_param = ov::pass::pattern::wrap_type<ov::op::v0::Parameter>();
auto m_rank_norm = ov::pass::pattern::optional<ov::snippets::op::RankNormalization>(m_param);
auto m_copy_b = ov::pass::pattern::wrap_type<BrgemmCopyB>({m_param});

auto callback = [=](ov::pass::pattern::Matcher& m) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::MoveBrgemmRepackingOut")
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::EliminateBrgemmCopyB")
const auto& pattern_map = m.get_pattern_value_map();
const auto& copy_b_out = pattern_map.at(m_copy_b);
const auto copy_b_node = ov::as_type_ptr<BrgemmCopyB>(copy_b_out.get_node_shared_ptr());
OPENVINO_ASSERT(copy_b_node, "BrgemmCopyB node is null in MoveBrgemmRepackingOut transformation");
OPENVINO_ASSERT(copy_b_node, "BrgemmCopyB node is null in EliminateBrgemmCopyB transformation");

const auto& in_desc = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(copy_b_node->input(0));
const auto& layout = in_desc->get_layout();
// TODO:
// 1. Ticket 157340: support external repacking for copyB with compensations
// 2. Ticket 157339: support external repacking for non-planar layout
if (!ov::snippets::utils::is_planar_layout(layout) ||
copy_b_node->get_src_element_type() == ov::element::i8 || transformation_callback(copy_b_node))
brgemm_utils::with_compensations(copy_b_node->get_type()) || transformation_callback(copy_b_node))
return false;
return ov::replace_output_update_name(copy_b_out, copy_b_node->input_value(0));
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"

namespace ov {
namespace intel_cpu {
namespace pass {

/**
* @interface EliminateBrgemmCopyB
* @brief EliminateBrgemmCopyB identifies BrgemmCopyB nodes which can be inferred outside the Subgraph.
* If this is possible, CopyB node is removed, and the external repacking is configured on the further pipeline stages in RuntimeConfigurator.
*
* @ingroup snippets
*/
class EliminateBrgemmCopyB: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("EliminateBrgemmCopyB", "0");
EliminateBrgemmCopyB();
};


} // namespace pass
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -65,48 +65,35 @@ bool pass::AdjustBrgemmCopyBLoopPorts::run(const snippets::lowered::LinearIR& li

bool modified = false;

auto get_repacking_loop_idces = [](const snippets::lowered::ExpressionPtr& parent_expr) {
auto get_repacking_loop_idces = [](const snippets::lowered::ExpressionPtr& brgemm_expr) {
// Repacking may be extracted outside the snippets kernel. In this case, brgemm parent expression is a parameter.
if (is_type<ov::op::v0::Parameter>(parent_expr->get_node()))
if (is_type<ov::op::v0::Parameter>(brgemm_expr->get_input_port_connector(1)->get_source().get_expr()->get_node()))
return std::vector<size_t>{};

OPENVINO_ASSERT(is_type<snippets::lowered::BufferExpression>(parent_expr),
"In case of repacking brgemm expr must have BufferExpression on B input");
const auto buffer_parent_ports = parent_expr->get_input_port(0).get_connected_ports();
OPENVINO_ASSERT(buffer_parent_ports.size() == 1,
"Parent of brgemm repacking buffer must be connected only to the buffer");
const auto& repacking_expr = buffer_parent_ports.begin()->get_expr();
const auto repacking_expr = brgemm_utils::repacking::get_copy_b_expr(brgemm_expr);
OPENVINO_ASSERT(repacking_expr, "BrgemmCopyB expression is not found");
return repacking_expr->get_loop_ids();
};

for (const auto& expr : linear_ir) {
const auto brgemm = ov::as_type_ptr<BrgemmCPU>(expr->get_node());
if (!brgemm || !brgemm_utils::with_repacking(brgemm->get_type()))
continue;
const auto& parent_expr = expr->get_input_port_connector(1)->get_source().get_expr();
const auto& repacking_loop_ids = get_repacking_loop_idces(parent_expr);
for (const auto& target_port : parent_expr->get_output_port(0).get_connected_ports()) {
const auto& port_node = target_port.get_expr()->get_node();
if (!is_type<intel_cpu::BrgemmCPU>(port_node)) {
OPENVINO_ASSERT(is_type<snippets::op::LoopEnd>(port_node),
"Invalid grandchild of BrgemmCopyB");
continue;
}
const auto &brgemm_loop_ids = target_port.get_expr()->get_loop_ids();
// Continue if there is no blocking loop
if (brgemm_loop_ids.empty() && repacking_loop_ids.empty())
continue;
OPENVINO_ASSERT(brgemm_loop_ids.size() > repacking_loop_ids.size(), "Invalid BrgemmCopyB loop configuration");
const auto &loop_manager = linear_ir.get_loop_manager();
for (auto i = repacking_loop_ids.size(); i < brgemm_loop_ids.size(); i++) {
const auto &loop = loop_manager->get_loop_info(brgemm_loop_ids[i]);
auto uni_loop = ov::as_type_ptr<snippets::lowered::UnifiedLoopInfo>(loop);
if (!uni_loop)
uni_loop = ov::as_type_ptr<snippets::lowered::ExpandedLoopInfo>(loop)->get_unified_loop_info();
if (!m_affected_loops.count(uni_loop) && update_loop_info(uni_loop)) {
m_affected_loops.insert(uni_loop);
modified = true;
}
const auto& brgemm_loop_ids = expr->get_loop_ids();
const auto& repacking_loop_ids = get_repacking_loop_idces(expr);
// Continue if there is no blocking loop
if (brgemm_loop_ids.empty() && repacking_loop_ids.empty())
continue;

OPENVINO_ASSERT(brgemm_loop_ids.size() > repacking_loop_ids.size(), "Invalid BrgemmCopyB loop configuration");
const auto &loop_manager = linear_ir.get_loop_manager();
for (auto i = repacking_loop_ids.size(); i < brgemm_loop_ids.size(); i++) {
const auto &loop = loop_manager->get_loop_info(brgemm_loop_ids[i]);
auto uni_loop = ov::as_type_ptr<snippets::lowered::UnifiedLoopInfo>(loop);
if (!uni_loop)
uni_loop = ov::as_type_ptr<snippets::lowered::ExpandedLoopInfo>(loop)->get_unified_loop_info();
if (!m_affected_loops.count(uni_loop) && update_loop_info(uni_loop)) {
m_affected_loops.insert(uni_loop);
modified = true;
}
}
}
Expand Down
Loading

0 comments on commit 68b8b06

Please sign in to comment.