Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cpu] Integrate IStaticShapeInfer wirth IShapeInfer #27770

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/plugins/intel_cpu/src/nodes/deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include <common/primitive_desc.hpp>
#include <common/primitive_desc_iface.hpp>
#include "cpu/x64/cpu_isa_traits.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/shape_inference.hpp"

#include "eltwise.h"
#include "fake_quantize.h"
Expand Down Expand Up @@ -128,12 +128,11 @@ bool DeconvKey::operator==(const DeconvKey &rhs) const {
*/
class DeconfolutionShapeInferFactory : public ShapeInferFactory {
public:
DeconfolutionShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
DeconfolutionShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(std::move(op)) {}

ShapeInferPtr makeShapeInfer() const override {
if (m_op->get_input_size() > 2) {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), PortMask(2));
}
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), EMPTY_PORT_MASK);
const auto port_mask = (m_op->get_input_size() > 2) ? PortMask(2) : EMPTY_PORT_MASK;
return make_shape_inference(m_op, port_mask);
}
private:
std::shared_ptr<ov::Node> m_op;
Expand Down
11 changes: 3 additions & 8 deletions src/plugins/intel_cpu/src/nodes/eye.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "openvino/op/eye.hpp"
#include <utils/bfloat16.hpp>
#include "openvino/core/parallel.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/shape_inference.hpp"
#include "utils/bfloat16.hpp"

#define THROW_ERROR(...) OPENVINO_THROW(NameFromType(getType()), " node with name '", getName(), "' ", __VA_ARGS__)
Expand All @@ -33,13 +33,8 @@ class EyeShapeInferFactory : public ShapeInferFactory {
public:
EyeShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
ShapeInferPtr makeShapeInfer() const override {
IShapeInfer::port_mask_t port_mask = EMPTY_PORT_MASK;
if (m_op->get_input_size() == 4) {
port_mask = PortMask(Eye::ROWS_NUM, Eye::COLS_NUM, Eye::DIAGONAL_INDEX, Eye::BATCH_SHAPE);
} else {
port_mask = PortMask(Eye::ROWS_NUM, Eye::COLS_NUM, Eye::DIAGONAL_INDEX);
}
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
return (m_op->get_input_size() == 4) ? make_shape_inference(m_op)
: make_shape_inference(m_op, PortMask(Eye::ROWS_NUM, Eye::COLS_NUM));
}
private:
std::shared_ptr<ov::Node> m_op;
Expand Down
17 changes: 5 additions & 12 deletions src/plugins/intel_cpu/src/nodes/interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "openvino/opsets/opset11.hpp"
#include "openvino/opsets/opset4.hpp"
#include "shape_inference/shape_inference.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/static_shape.hpp"
#include "utils/bfloat16.hpp"
#include "utils/cpu_utils.hpp"
Expand Down Expand Up @@ -1763,27 +1762,21 @@ class InterpolateShapeInferFactory : public ShapeInferFactory {
public:
InterpolateShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
ShapeInferPtr makeShapeInfer() const override {
IShapeInfer::port_mask_t port_mask = 0x00;
if (auto interp4 = ov::as_type_ptr<ov::opset4::Interpolate>(m_op)) {
const auto &attr = interp4->get_attrs();

if (attr.shape_calculation_mode == ngInterpShapeCalcMode::SCALES) {
port_mask = PortMask(Interpolate::SCALES_ID, Interpolate::AXES_ID);
} else if (attr.shape_calculation_mode == ngInterpShapeCalcMode::SIZES) {
port_mask = PortMask(Interpolate::TARGET_SHAPE_ID, Interpolate::AXES_ID);
} else {
OPENVINO_ASSERT(false, "Unsupported interpolate shape calculation mode");
}
const auto is_supported_mode = (attr.shape_calculation_mode == ngInterpShapeCalcMode::SCALES) ||
(attr.shape_calculation_mode == ngInterpShapeCalcMode::SIZES);
OPENVINO_ASSERT(is_supported_mode, "Unsupported interpolate shape calculation mode");
return make_shape_inference(m_op);
praasz marked this conversation as resolved.
Show resolved Hide resolved
} else if (auto interp11 = ov::as_type_ptr<ov::op::v11::Interpolate>(m_op)) {
port_mask = PortMask(Interpolate::SIZE_OR_SCALE_ID_V11, Interpolate::AXES_ID_V11);
return make_shape_inference(m_op);
} else {
OPENVINO_THROW("Shape infer factory cannot be created for ",
m_op->get_type_name(),
" node with name: ",
m_op->get_friendly_name(),
", only versions 4 and 11 are supported.");
}
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
}

private:
Expand Down
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/nodes/reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#include "reference.h"
#include "common/cpu_memcpy.h"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/shape_inference.hpp"

namespace ov {
namespace intel_cpu {
Expand All @@ -14,7 +14,7 @@ class ReferenceShapeInferFactory : public ShapeInferFactory {
ReferenceShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op{std::move(op)} {}

ShapeInferPtr makeShapeInfer() const override {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), FULL_PORT_MASK);
return make_shape_inference(m_op, FULL_PORT_MASK);
}

private:
Expand Down
44 changes: 29 additions & 15 deletions src/plugins/intel_cpu/src/nodes/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,24 @@
#include "nodes/input.h"
#include "nodes/reorder.h"
#include "openvino/core/parallel.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "transformations/utils/utils.hpp"

#include "ov_ops/augru_cell.hpp"
#include "ov_ops/augru_sequence.hpp"
#include "openvino/op/gru_cell.hpp"
#include "openvino/op/gru_sequence.hpp"
#include "openvino/op/lstm_sequence.hpp"
#include "openvino/op/rnn_cell.hpp"
#include "openvino/op/rnn_sequence.hpp"
#include "ov_ops/augru_cell.hpp"
#include "ov_ops/augru_sequence.hpp"
#include "shape_inference/shape_inference.hpp"
#include "transformations/utils/utils.hpp"

using namespace dnnl;


namespace ov {
namespace intel_cpu {

class ShapeInferBase;
class ShapeInferCustomMask;
praasz marked this conversation as resolved.
Show resolved Hide resolved
namespace node {

static rnn_direction ieDirection2dnnl(const std::shared_ptr<const ov::Node>& op) {
Expand Down Expand Up @@ -356,19 +358,17 @@ namespace {
* dimentions permutation, necessary due to the mismatch between the ngrpah and the oneDNN RNN node descriptions.
*
*/
class RnnShapeInfer : public NgraphShapeInfer {
class RnnShapeInfer : public IShapeInfer {
public:
RnnShapeInfer(std::shared_ptr<ov::Node> op) :
NgraphShapeInfer(make_shape_inference(op), EMPTY_PORT_MASK) {
is_sequence = !(RNN::isCell(op));

native_order = RNN::testNativeOrder(op);
}
RnnShapeInfer(std::shared_ptr<ov::Node> op)
: is_sequence(!(RNN::isCell(op))),
native_order(RNN::testNativeOrder(op)),
m_shape_infer(make_shape_inference(std::move(op))) {}

Result infer(
const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
auto result = NgraphShapeInfer::infer(input_shapes, data_dependency);
auto result = m_shape_infer->infer(input_shapes, data_dependency);
if (ShapeInferStatus::success != result.status) {
OPENVINO_THROW("Unexpected: Unexpected shape inference result status");
}
Expand All @@ -382,10 +382,24 @@ class RnnShapeInfer : public NgraphShapeInfer {
return {std::move(originOutputShapes), result.status};
}

const ov::CoordinateDiff& get_pads_begin() override {
return m_shape_infer->get_pads_begin();
}

const ov::CoordinateDiff& get_pads_end() override {
return m_shape_infer->get_pads_end();
}

port_mask_t get_port_mask() const override {
return m_shape_infer->get_port_mask();
}

private:
bool is_sequence = false;
bool native_order = true;
bool is_sequence;
bool native_order;
ShapeInferPtr m_shape_infer;
};

class RnnShapeInferFactory final : public ShapeInferFactory {
public:
RnnShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
Expand Down
1 change: 0 additions & 1 deletion src/plugins/intel_cpu/src/nodes/strided_slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "common/cpu_memcpy.h"
#include "input.h"
#include "openvino/opsets/opset1.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "slice_shape_inference_utils.hpp"
#include "shape_inference/custom/strided_slice.hpp"

Expand Down
11 changes: 6 additions & 5 deletions src/plugins/intel_cpu/src/shape_inference/custom/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "matmul.hpp"
#include "utils.hpp"
#include "openvino/opsets/opset1.hpp"
#include "shape_inference/shape_inference.hpp"

namespace ov {
namespace intel_cpu {
Expand Down Expand Up @@ -64,17 +65,17 @@ Result MMShapeInfer::infer(

ShapeInferPtr MMShapeInferFactory::makeShapeInfer() const {
if (const auto matmul = ov::as_type_ptr<const ov::opset1::MatMul>(m_op)) {
const auto output_rank = matmul->get_output_partial_shape(0).rank().get_length();
const bool transpose_a = matmul->get_transpose_a();
const bool transpose_b = matmul->get_transpose_b();
const auto input_rank0 = matmul->get_input_partial_shape(0).rank().get_length();
const auto input_rank1 = matmul->get_input_partial_shape(1).rank().get_length();

if (input_rank0 == input_rank1) {
const auto output_rank = matmul->get_output_partial_shape(0).rank().get_length();
const bool transpose_a = matmul->get_transpose_a();
const bool transpose_b = matmul->get_transpose_b();
return std::make_shared<MMShapeInfer>(output_rank, transpose_a, transpose_b);
} else {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), EMPTY_PORT_MASK);
return make_shape_inference(m_op);
}

} else {
OPENVINO_THROW("Unexpected operation type in the MatMul shape inference factory");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
//

#include <node.h>

#include "shape_inference/shape_inference_cpu.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"

#pragma once
namespace ov {
Expand Down Expand Up @@ -42,4 +42,3 @@ class MMShapeInferFactory : public ShapeInferFactory {
} // namespace node
} // namespace intel_cpu
} // namespace ov

Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

#include "scaled_attn.hpp"

#include "shape_inference/shape_inference_cpu.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/shape_inference.hpp"
#include "transformations/cpu_opset/common/op/sdpa.hpp"
#include "utils.hpp"

Expand Down Expand Up @@ -78,7 +77,7 @@ ShapeInferPtr SDPAShapeInferFactory::makeShapeInfer() const {
return std::make_shared<SDPAShapeInfer>(config);
}
// fallback to ngraph shape infer on non-perf-critical case
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), EMPTY_PORT_MASK);
return make_shape_inference(m_op);
}

} // namespace node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "strided_slice.hpp"
#include "utils.hpp"
#include "slice_shape_inference.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/shape_inference.hpp"

namespace ov {
namespace intel_cpu {
Expand Down Expand Up @@ -75,13 +75,13 @@ Result StridedSliceShapeInfer::infer(

ShapeInferPtr StridedSliceShapeInferFactory::makeShapeInfer() const {
if (const auto Slice_op = ov::as_type_ptr<const ov::op::v8::Slice>(m_op)) {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
return make_shape_inference(m_op);
} else if (const auto SliceScatter_op = ov::as_type_ptr<const ov::op::v15::SliceScatter>(m_op)) {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), PortMask(2, 3, 4, 5));
return make_shape_inference(m_op);
} else if (const auto StridedSlice_op = ov::as_type_ptr<const ov::op::v1::StridedSlice>(m_op)) {
const auto& ellipsis_mask = StridedSlice_op->get_ellipsis_mask();
if (std::any_of(ellipsis_mask.begin(), ellipsis_mask.end(), [](int64_t x){ return x == 1; })) {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
return make_shape_inference(m_op);
} else {
auto vec_to_set = [](const std::vector<int64_t>& vec){
std::unordered_set<int64_t> to_set;
Expand Down
Loading
Loading