[CPU] Fuse SDPA before/after Reshape+Transpose Node to SDPA (openvinotoolkit#26819)

### Details:
- *Pattern: QKV_Reshape -> QKV_Transpose -> SDPA -> OUT_Transpose -> OUT_Reshape*
- *Fuse this pattern into a single SDPA node (the new SDPAWithTransposeReshape op); a sketch of the original subgraph follows below.*
- *This hotspot can be observed after openvinotoolkit#26130; this PR's implementation does not depend on it.*
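
Not part of this PR: a minimal opset13 sketch of the subgraph the new pass targets, using hypothetical dimensions B=1, L=1024, H=8, S=64 (so H*S=512). After fusion, the three Reshape+Transpose input chains and the output Transpose+Reshape collapse into one node that consumes and produces [B,L,H*S] tensors directly.

```cpp
#include <memory>

#include <openvino/openvino.hpp>
#include <openvino/opsets/opset13.hpp>

std::shared_ptr<ov::Model> make_unfused_sdpa_model() {
    using namespace ov::opset13;

    // Q, K, V arrive flattened as [B, L, H*S].
    auto q = std::make_shared<Parameter>(ov::element::f32, ov::Shape{1, 1024, 512});
    auto k = std::make_shared<Parameter>(ov::element::f32, ov::Shape{1, 1024, 512});
    auto v = std::make_shared<Parameter>(ov::element::f32, ov::Shape{1, 1024, 512});

    // Per-branch Reshape [B,L,H*S] -> [B,L,H,S], then Transpose [0,2,1,3] -> [B,H,L,S].
    auto to_bhls = [](const ov::Output<ov::Node>& in) -> ov::Output<ov::Node> {
        auto target = Constant::create(ov::element::i64, ov::Shape{4}, {1, 1024, 8, 64});
        auto reshaped = std::make_shared<Reshape>(in, target, false);
        auto order = Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3});
        return std::make_shared<Transpose>(reshaped, order);
    };

    auto sdpa = std::make_shared<ScaledDotProductAttention>(
        ov::OutputVector{to_bhls(q), to_bhls(k), to_bhls(v)}, false);

    // Output path: Transpose back to [B,L,H,S], then Reshape to [B,L,H*S].
    auto out_order = Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3});
    auto out_transpose = std::make_shared<Transpose>(sdpa, out_order);
    auto out_shape = Constant::create(ov::element::i64, ov::Shape{3}, {1, 1024, 512});
    auto out_reshape = std::make_shared<Reshape>(out_transpose, out_shape, false);

    return std::make_shared<ov::Model>(ov::OutputVector{out_reshape}, ov::ParameterVector{q, k, v});
}
```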

### Tickets:
 - *153616*

---------

Signed-off-by: xipingya <[email protected]>
xipingyan authored and CuriousPanCake committed Nov 6, 2024
1 parent 2092fa9 commit de05123
Showing 9 changed files with 569 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/cpu_types.cpp
@@ -245,6 +245,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{"Ngram", Type::Ngram},
{"ScaledDotProductAttention", Type::ScaledDotProductAttention},
{"ScaledDotProductAttentionWithKVCache", Type::ScaledDotProductAttention},
{"SDPAWithTransposeReshape", Type::ScaledDotProductAttention},
{"PagedAttentionExtension", Type::PagedAttention},
{"RoPE", Type::RoPE},
{"GatherCompressed", Type::Gather},
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/extension.cpp
@@ -75,6 +75,7 @@ class TypeRelaxedExtension : public ov::OpExtension<ov::op::TypeRelaxed<Op>> {
OP_EXTENSION(ov::intel_cpu::PowerStaticNode) \
OP_EXTENSION(ov::intel_cpu::CausalMaskPreprocessNode) \
OP_EXTENSION(ov::intel_cpu::SwishNode) \
OP_EXTENSION(ov::intel_cpu::SDPAWithTransposeReshape) \
OP_EXTENSION(ov::intel_cpu::NgramNode) \
OP_EXTENSION(ov::op::internal::GatherCompressed) \
OP_EXTENSION(ov::op::internal::NonMaxSuppressionIEInternal) \
49 changes: 39 additions & 10 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -866,6 +866,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
void execute(dnnl::stream strm, const Config& config, const std::vector<MemoryPtr>& inputs, const MemoryPtr output,
const MemoryPtr presentk_input, const MemoryPtr presentv_input, const MemoryPtr beam_input,
const PlainTensor& k_scale_zp, const PlainTensor& v_scale_zp) override {
bool has_in_reshape = config.config.input_BLHxS;
bool has_out_transpose = config.config.output_BLHxS;
bool fuse_causal_attn = config.config.fuse_causal_attn;
bool is_causal = config.config.is_causal;
@@ -881,11 +882,28 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
float scale_input = 0.0f;
size_t B, L1, L0, S, SV;

// B,L,H*S->B,L,H,S
auto get_reshape_shape = [&config](const PlainTensor& input) {
// [B,L,H*S]
auto inp_shape = input.shape();
// [B,L,H,S]
return VectorDims{inp_shape[0], inp_shape[1], config.config.order_HS[0], config.config.order_HS[1]};
};

q_input.reset(inputs[0]);
k_input.reset(inputs[1]);
v_input.reset(inputs[2]);
present_key.reset(presentk_input);
present_value.reset(presentv_input);
if (has_in_reshape) {
q_input = q_input.reshape(get_reshape_shape(q_input));
auto kv_shape = get_reshape_shape(k_input);
k_input = k_input.reshape(kv_shape);
v_input = v_input.reshape(kv_shape);
present_key = present_key.reshape(kv_shape);
present_value = present_value.reshape(kv_shape);
}

if (beam_input)
beam_table.reset(beam_input);
if (input_num > 3) {
@@ -985,11 +1003,11 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr<ov::N
OPENVINO_THROW("CPU: " + errorMessage);
}

-    const auto node = std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op);
-    if (node) {
+    if (const auto node = std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op)) {
         m_config.config.is_causal = node->get_causal();
-    } else {
-        const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(op);
+    } else if (const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(op)) {
         m_config.config = node->get_config();
+    } else if (const auto node = std::dynamic_pointer_cast<const SDPAWithTransposeReshape>(op)) {
+        m_config.config = node->get_config();
}
}
@@ -1142,17 +1160,28 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) {

bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
try {
+        auto sdpaWithTransposeReshapeOp = std::dynamic_pointer_cast<const SDPAWithTransposeReshape>(op);
         if (!std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op) &&
-            !std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(op)) {
-            errorMessage = "Only ScaledDotProductAttention or ScaledDotProductAttentionWithKVCache operation are supported";
+            !std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(op) && !sdpaWithTransposeReshapeOp) {
+            errorMessage = "Only ScaledDotProductAttention, ScaledDotProductAttentionWithKVCache or "
+                           "SDPAWithTransposeReshape operations are supported";
             return false;
         }
-        // expect shape of q: [B, H, L, S]
         auto inRank = op->get_input_partial_shape(0).size();
-        if (inRank != 4u) {
-            errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
-            return false;
-        }
+        if (sdpaWithTransposeReshapeOp) {
+            // expected shape of q: [B, L, H*S]
+            if (inRank != 3u) {
+                errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
+                return false;
+            }
+        } else {
+            // expected shape of q: [B, H, L, S]
+            if (inRank != 4u) {
+                errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
+                return false;
+            }
+        }

int orgSDPAInput = static_cast<int>(op->get_input_size());
const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(op);
if (node) {
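The effect of input_BLHxS on the executor boils down to the shape computation in get_reshape_shape above. A standalone sketch (plain std::vector instead of the plugin's VectorDims type, with hypothetical values):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// [B, L, H*S] -> [B, L, H, S], where order_HS = {H, S} comes from the fused op's config.
std::vector<size_t> blhxs_to_blhs(const std::vector<size_t>& in_shape,
                                  const std::vector<size_t>& order_HS) {
    assert(in_shape.size() == 3 && order_HS.size() == 2);
    return {in_shape[0], in_shape[1], order_HS[0], order_HS[1]};
}

// Example: {1, 1024, 512} with order_HS = {8, 64} yields {1, 1024, 8, 64}.
```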
src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp
@@ -99,4 +99,46 @@ bool ov::intel_cpu::ScaledDotProductAttentionWithKVCache::visit_attributes(ov::A
visitor.on_attribute("permute_axes", m_config.permute_axes);
visitor.finish_structure();
return true;
}

ov::intel_cpu::SDPAWithTransposeReshape::SDPAWithTransposeReshape(const OutputVector& args, const Config& cfg)
: Op(args),
m_config(cfg) {}

std::shared_ptr<ov::Node> ov::intel_cpu::SDPAWithTransposeReshape::clone_with_new_inputs(
const ov::OutputVector& new_args) const {
INTERNAL_OP_SCOPE(SDPAWithTransposeReshape_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<ov::intel_cpu::SDPAWithTransposeReshape>(new_args, m_config);
}

void ov::intel_cpu::SDPAWithTransposeReshape::validate_and_infer_types() {
INTERNAL_OP_SCOPE(SDPAWithTransposeReshape_validate_and_infer_types);
// [B,L,H*S]
auto q_ps = get_input_partial_shape(0);
auto output_ps = q_ps;
NODE_VALIDATION_CHECK(this, m_config.output_BLHxS == true);
NODE_VALIDATION_CHECK(this, m_config.input_BLHxS == true);
NODE_VALIDATION_CHECK(this, q_ps.size() == 3u);

// permute_axes should be [B, H, L, S]
const auto& permute_axes = this->m_config.permute_axes;
NODE_VALIDATION_CHECK(this, permute_axes.size() == 4u);

// order_HS should be [H,S]
const auto& order_HS = this->m_config.order_HS;
NODE_VALIDATION_CHECK(this, order_HS.size() == 2u);

set_output_type(0, get_input_element_type(0), output_ps);
}

bool ov::intel_cpu::SDPAWithTransposeReshape::visit_attributes(ov::AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(SDPAWithTransposeReshape_visit_attributes);
visitor.start_structure("config");
visitor.on_attribute("input_BLHxS", m_config.input_BLHxS);
visitor.on_attribute("output_BLHxS", m_config.output_BLHxS);
visitor.on_attribute("permute_axes", m_config.permute_axes);
visitor.on_attribute("order_HS", m_config.order_HS);
visitor.finish_structure();
return true;
}
src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp
@@ -21,13 +21,15 @@ class ScaledDotProductAttentionWithKVCache : public ov::op::Op {
ScaledDotProductAttentionWithKVCache() = default;

struct Config {
-        bool output_BLHxS = false;  // true implies that output is [B,L,H*S]
+        bool input_BLHxS = false;          // true implies that input is [B,L,H*S]
+        bool output_BLHxS = false;         // true implies that output is [B,L,H*S]

-        bool fuse_causal_attn = false;  // fuse causal mask and attn mask into attn_mask
-        bool is_causal = false;  // apply causal mask internally
-        bool fuse_concat = false;  // fuse (concat->sdp) ==> sdp
-        std::vector<size_t> permute_axes;  // not empty means input has transpose. output of permutation is [B,H,L,S]
-                                           // e.g. [L,B,H,S] -> permute[1, 2, 0, 3] -> [B, H, L, S]
+        bool fuse_causal_attn = false;     // fuse causal mask and attn mask into attn_mask
+        bool is_causal = false;            // apply causal mask internally
+        bool fuse_concat = false;          // fuse (concat->sdp) ==> sdp
+        std::vector<size_t> permute_axes;  // not empty means input has transpose; output of permutation is [B,H,L,S]
+                                           // e.g. [L,B,H,S] -> permute[1, 2, 0, 3] -> [B, H, L, S]
+        std::vector<size_t> order_HS;      // Reshape [B,L,H*S] -> [B,L,H,S]; H and S are fixed values when input_BLHxS is true
};

ScaledDotProductAttentionWithKVCache(const OutputVector& args, const Config& cfg);
@@ -48,5 +50,30 @@ class ScaledDotProductAttentionWithKVCache : public ov::op::Op {
Config m_config;
};

class SDPAWithTransposeReshape : public ov::op::Op {
public:
OPENVINO_OP("SDPAWithTransposeReshape", "cpu_plugin_opset");
using Config = ScaledDotProductAttentionWithKVCache::Config;

SDPAWithTransposeReshape() = default;

SDPAWithTransposeReshape(const OutputVector& args, const Config& cfg);

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;

const Config& get_config() const {
return m_config;
}

Config& get_config() {
return m_config;
}

private:
Config m_config;
};

} // namespace intel_cpu
} // namespace ov
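
For context, a hedged sketch of constructing the new op directly with a populated Config (the field values below are hypothetical; in practice the fusion pass shown next derives them from the matched Reshape/Transpose constants):

```cpp
#include <memory>

#include "transformations/cpu_opset/common/op/sdpa.hpp"

// args: Q, K, V outputs, each still laid out as [B, L, H*S].
std::shared_ptr<ov::Node> make_sdpa_with_transpose_reshape(const ov::OutputVector& args) {
    ov::intel_cpu::SDPAWithTransposeReshape::Config cfg;
    cfg.input_BLHxS = true;           // inputs arrive as [B, L, H*S]
    cfg.output_BLHxS = true;          // output is produced as [B, L, H*S]
    cfg.is_causal = false;
    cfg.fuse_concat = false;
    cfg.permute_axes = {0, 2, 1, 3};  // the transpose order the fusion absorbed
    cfg.order_HS = {8, 64};           // hypothetical H and S
    return std::make_shared<ov::intel_cpu::SDPAWithTransposeReshape>(args, cfg);
}
```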
sdpa_fuse_transpose_reshape.cpp (new file)
@@ -0,0 +1,188 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "sdpa_fuse_transpose_reshape.hpp"

#include <transformations/utils/utils.hpp>

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/cpu_opset/common/op/sdpa.hpp"

/*
 * Description: Fuse the Reshape+Transpose nodes around SDPA into a single SDPA node.
 *
 *          Original pattern                                    Fused pattern
 *
 *    input1    input2    input3
 *       |         |         |
 *  q_reshape  k_reshape  v_reshape
 *       |         |         |          (qkv transpose and reshape orders)
 * q_transpose k_transpose v_transpose               |
 *        \        |        /           input1   input2   input3    |
 *         \       |       /                \       |       /      /
 *     ScaledDotProductAttention  ------>  SDPAWithTransposeReshape
 *                 |                                 |
 *           out_transpose                         output
 *                 |
 *            out_reshape
 *                 |
 *              output
 */

using namespace ov;
using namespace ov::pass::pattern;

intel_cpu::SDPAFuseTransposeReshape::SDPAFuseTransposeReshape() {
MATCHER_SCOPE(SDPAFuseTransposeReshape);

auto q_reshape_node = wrap_type<op::v1::Reshape>({any_input(), any_input()});
auto k_reshape_node = wrap_type<op::v1::Reshape>({any_input(), any_input()});
auto v_reshape_node = wrap_type<op::v1::Reshape>({any_input(), any_input()});

auto q_transpose_order_node = wrap_type<op::v0::Constant>();
auto k_transpose_order_node = wrap_type<op::v0::Constant>();
auto v_transpose_order_node = wrap_type<op::v0::Constant>();
auto q_transpose_node = wrap_type<op::v1::Transpose>({q_reshape_node, q_transpose_order_node});
auto k_transpose_node = wrap_type<op::v1::Transpose>({k_reshape_node, k_transpose_order_node});
auto v_transpose_node = wrap_type<op::v1::Transpose>({v_reshape_node, v_transpose_order_node});

auto sdpa_node =
wrap_type<op::v13::ScaledDotProductAttention>({q_transpose_node, k_transpose_node, v_transpose_node});

auto out_transpose_order_node = wrap_type<op::v0::Constant>();
auto out_transpose_node = wrap_type<op::v1::Transpose>({sdpa_node, out_transpose_order_node});
auto out_reshape_node = wrap_type<op::v1::Reshape>({out_transpose_node, wrap_type<op::v0::Constant>()});

matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pass::pattern::Matcher& m) {
auto& pattern_map = m.get_pattern_value_map();
auto sdpa = as_type_ptr<op::v13::ScaledDotProductAttention>(pattern_map.at(sdpa_node).get_node_shared_ptr());
if (sdpa == nullptr || transformation_callback(sdpa)) {
return false;
}

// Order=[0, 2, 1, 3]
auto is_expected_transpose = [&](std::shared_ptr<op::v1::Transpose>& transpose) {
if (transpose) {
const auto orders = as_type_ptr<op::v0::Constant>(transpose->get_input_node_shared_ptr(1));
return orders && (std::vector<int32_t>({0, 2, 1, 3}) == orders->cast_vector<int32_t>());
}
return false;
};

// Reshape [B,L,H*S] -> [B,L,H,S]
auto is_expected_reshape = [&](std::shared_ptr<op::v1::Reshape>& reshape_node, bool reverse = false) {
if (reshape_node) {
auto inp_shape = reshape_node->get_input_partial_shape(0);
auto outp_shape = reshape_node->get_output_partial_shape(0);
// Expect shape: [?, ?, val]
auto check_dim_3 = [](ov::PartialShape shape) {
return shape.rank().is_static() && shape.rank() == 3 && shape[2].is_static();
};
// Expect shape: [?, ?, val, val]
auto check_dim_4 = [](ov::PartialShape shape) {
return shape.rank().is_static() && shape.rank() == 4 && shape[2].is_static() &&
shape[3].is_static();
};

if (reverse) {
return check_dim_4(inp_shape) && check_dim_3(outp_shape) &&
(outp_shape[2] == inp_shape[2] * inp_shape[3]);
} else {
return check_dim_3(inp_shape) && check_dim_4(outp_shape) &&
(inp_shape[2] == outp_shape[2] * outp_shape[3]);
}
}
return false;
};

// Pattern: Reshape->Transpose->SDPA
auto q_reshape = as_type_ptr<op::v1::Reshape>(pattern_map.at(q_reshape_node).get_node_shared_ptr());
auto k_reshape = as_type_ptr<op::v1::Reshape>(pattern_map.at(k_reshape_node).get_node_shared_ptr());
auto v_reshape = as_type_ptr<op::v1::Reshape>(pattern_map.at(v_reshape_node).get_node_shared_ptr());

if (!(is_expected_reshape(q_reshape) && is_expected_reshape(k_reshape) && is_expected_reshape(v_reshape))) {
return false;
}
// K and V reshapes must share the same target-shape node, or constants with equal values.
auto k_reshape_order = as_type_ptr<op::v0::Constant>(k_reshape->get_input_node_shared_ptr(1));
auto v_reshape_order = as_type_ptr<op::v0::Constant>(v_reshape->get_input_node_shared_ptr(1));
if (k_reshape_order && v_reshape_order) {
if (k_reshape_order->cast_vector<int32_t>() != v_reshape_order->cast_vector<int32_t>()) {
return false;
}
} else if (k_reshape->get_input_node_shared_ptr(1) != v_reshape->get_input_node_shared_ptr(1)) {
return false;
}

std::shared_ptr<op::v1::Transpose> qkv_transpose[3] = {};
std::shared_ptr<op::v0::Constant> qkv_transpose_order[3] = {};
qkv_transpose[0] = as_type_ptr<op::v1::Transpose>(pattern_map.at(q_transpose_node).get_node_shared_ptr());
qkv_transpose[1] = as_type_ptr<op::v1::Transpose>(pattern_map.at(k_transpose_node).get_node_shared_ptr());
qkv_transpose[2] = as_type_ptr<op::v1::Transpose>(pattern_map.at(v_transpose_node).get_node_shared_ptr());
qkv_transpose_order[0] = as_type_ptr<op::v0::Constant>(pattern_map.at(q_transpose_order_node).get_node_shared_ptr());
qkv_transpose_order[1] = as_type_ptr<op::v0::Constant>(pattern_map.at(k_transpose_order_node).get_node_shared_ptr());
qkv_transpose_order[2] = as_type_ptr<op::v0::Constant>(pattern_map.at(v_transpose_order_node).get_node_shared_ptr());
auto out_transpose = as_type_ptr<op::v1::Transpose>(pattern_map.at(out_transpose_node).get_node_shared_ptr());
auto out_transpose_order = as_type_ptr<op::v0::Constant>(pattern_map.at(out_transpose_order_node).get_node_shared_ptr());

if (!(is_expected_transpose(qkv_transpose[0]) && is_expected_transpose(qkv_transpose[1]) &&
is_expected_transpose(qkv_transpose[2]))) {
return false;
}
if (!is_expected_transpose(out_transpose)) {
return false;
}

auto out_reshape = as_type_ptr<op::v1::Reshape>(pattern_map.at(out_reshape_node).get_node_shared_ptr());
if (!is_expected_reshape(out_reshape, true)) {
return false;
}

OutputVector args = {q_reshape->get_input_node_shared_ptr(0),
k_reshape->get_input_node_shared_ptr(0),
v_reshape->get_input_node_shared_ptr(0)};

// Config
intel_cpu::SDPAWithTransposeReshape::Config config;
config.is_causal = sdpa->get_causal();
config.fuse_concat = false;
config.output_BLHxS = true;

// Config::permute_axes
const auto& permute_q = qkv_transpose_order[0]->cast_vector<int32_t>();
config.permute_axes.resize(permute_q.size());
for (size_t i = 0; i < permute_q.size(); i++) {
config.permute_axes[i] = static_cast<size_t>(permute_q[i]);
}

// Config::order_HS
config.order_HS.resize(2);
auto reshape_out_shape = q_reshape->get_output_partial_shape(0).get_min_shape(); // [?,?,H,S]
config.order_HS[0] = reshape_out_shape[2];
config.order_HS[1] = reshape_out_shape[3];
config.input_BLHxS = true;

auto new_sdpa = std::make_shared<intel_cpu::SDPAWithTransposeReshape>(args, config);
new_sdpa->set_friendly_name(sdpa->get_friendly_name() + "/fused_reshape_transpose");
NodeVector replaced_nodes = {q_reshape,
k_reshape,
v_reshape,
qkv_transpose[0],
qkv_transpose[1],
qkv_transpose[2],
sdpa,
out_transpose,
out_reshape};
copy_runtime_info(replaced_nodes, new_sdpa);
ov::replace_node(out_reshape, new_sdpa);
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(out_reshape_node, matcher_name);
register_matcher(m, callback);
}
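
Finally, a usage sketch of how a matcher pass like SDPAFuseTransposeReshape is typically applied through OpenVINO's pass manager; inside the plugin this registration happens in the CPU transformation pipeline, so the standalone function below is illustrative only:

```cpp
#include <memory>

#include "openvino/pass/manager.hpp"
#include "sdpa_fuse_transpose_reshape.hpp"  // header introduced by this PR

void fuse_sdpa_transpose_reshape(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::intel_cpu::SDPAFuseTransposeReshape>();
    manager.run_passes(model);
}
```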