[cpu] Integrate IStaticShapeInfer with IShapeInfer (#27770)
### Details:
 - The `IStaticShapeInfer` interface extends `IShapeInfer`.
- Remove the `NgraphShapeInfer` class, as its functionality is replaced by `IStaticShapeInfer` (see the sketch after this list).
- Refactor the shape inference unit tests to avoid name clashes with CPU plugin types:
   - use `ov::Shape` to avoid interpretation as `intel_cpu::Shape`.
   - rename test type `ShapeVector` to `StaticShapeVector`.
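
For illustration, here is a minimal before/after sketch of the factory pattern this change simplifies (the factory name and `PortMask(1)` are placeholders, not code from this PR):

```cpp
#include "shape_inference/shape_inference.hpp"      // replaces shape_inference_ngraph.hpp
#include "shape_inference/shape_inference_cpu.hpp"  // ShapeInferFactory, ShapeInferPtr

// Hypothetical factory illustrating the simplified pattern.
class ExampleShapeInferFactory : public ShapeInferFactory {
public:
    ExampleShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(std::move(op)) {}

    ShapeInferPtr makeShapeInfer() const override {
        // Before: wrap the result in an adapter to obtain an IShapeInfer:
        //   return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), PortMask(1));
        // After: make_shape_inference() returns a ShapeInferPtr directly,
        // taking the port mask as an optional second argument:
        return make_shape_inference(m_op, PortMask(1));
    }

private:
    std::shared_ptr<ov::Node> m_op;
};
```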

### Tickets:
 - CVS-118704

---------

Signed-off-by: Pawel Raasz <[email protected]>
Signed-off-by: Raasz, Pawel <[email protected]>
praasz authored Dec 4, 2024
1 parent 78a6ad8 commit c975788
Showing 97 changed files with 1,077 additions and 1,214 deletions.
11 changes: 5 additions & 6 deletions src/plugins/intel_cpu/src/nodes/deconv.cpp
@@ -12,7 +12,7 @@
#include <common/primitive_desc.hpp>
#include <common/primitive_desc_iface.hpp>
#include "cpu/x64/cpu_isa_traits.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/shape_inference.hpp"

#include "eltwise.h"
#include "fake_quantize.h"
@@ -128,12 +128,11 @@ bool DeconvKey::operator==(const DeconvKey &rhs) const {
*/
class DeconfolutionShapeInferFactory : public ShapeInferFactory {
public:
DeconfolutionShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
DeconfolutionShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(std::move(op)) {}

ShapeInferPtr makeShapeInfer() const override {
if (m_op->get_input_size() > 2) {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), PortMask(2));
}
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), EMPTY_PORT_MASK);
const auto port_mask = (m_op->get_input_size() > 2) ? PortMask(2) : EMPTY_PORT_MASK;
return make_shape_inference(m_op, port_mask);
}
private:
std::shared_ptr<ov::Node> m_op;
11 changes: 3 additions & 8 deletions src/plugins/intel_cpu/src/nodes/eye.cpp
@@ -6,7 +6,7 @@
#include "openvino/op/eye.hpp"
#include <utils/bfloat16.hpp>
#include "openvino/core/parallel.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/shape_inference.hpp"
#include "utils/bfloat16.hpp"

#define THROW_ERROR(...) OPENVINO_THROW(NameFromType(getType()), " node with name '", getName(), "' ", __VA_ARGS__)
@@ -33,13 +33,8 @@ class EyeShapeInferFactory : public ShapeInferFactory {
public:
EyeShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
ShapeInferPtr makeShapeInfer() const override {
IShapeInfer::port_mask_t port_mask = EMPTY_PORT_MASK;
if (m_op->get_input_size() == 4) {
port_mask = PortMask(Eye::ROWS_NUM, Eye::COLS_NUM, Eye::DIAGONAL_INDEX, Eye::BATCH_SHAPE);
} else {
port_mask = PortMask(Eye::ROWS_NUM, Eye::COLS_NUM, Eye::DIAGONAL_INDEX);
}
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
return (m_op->get_input_size() == 4) ? make_shape_inference(m_op)
: make_shape_inference(m_op, PortMask(Eye::ROWS_NUM, Eye::COLS_NUM));
}
private:
std::shared_ptr<ov::Node> m_op;
17 changes: 5 additions & 12 deletions src/plugins/intel_cpu/src/nodes/interpolate.cpp
@@ -21,7 +21,6 @@
#include "openvino/opsets/opset11.hpp"
#include "openvino/opsets/opset4.hpp"
#include "shape_inference/shape_inference.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/static_shape.hpp"
#include "utils/bfloat16.hpp"
#include "utils/cpu_utils.hpp"
@@ -1763,27 +1762,21 @@ class InterpolateShapeInferFactory : public ShapeInferFactory {
public:
InterpolateShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
ShapeInferPtr makeShapeInfer() const override {
IShapeInfer::port_mask_t port_mask = 0x00;
if (auto interp4 = ov::as_type_ptr<ov::opset4::Interpolate>(m_op)) {
const auto &attr = interp4->get_attrs();

if (attr.shape_calculation_mode == ngInterpShapeCalcMode::SCALES) {
port_mask = PortMask(Interpolate::SCALES_ID, Interpolate::AXES_ID);
} else if (attr.shape_calculation_mode == ngInterpShapeCalcMode::SIZES) {
port_mask = PortMask(Interpolate::TARGET_SHAPE_ID, Interpolate::AXES_ID);
} else {
OPENVINO_ASSERT(false, "Unsupported interpolate shape calculation mode");
}
const auto is_supported_mode = (attr.shape_calculation_mode == ngInterpShapeCalcMode::SCALES) ||
(attr.shape_calculation_mode == ngInterpShapeCalcMode::SIZES);
OPENVINO_ASSERT(is_supported_mode, "Unsupported interpolate shape calculation mode");
return make_shape_inference(m_op);
} else if (auto interp11 = ov::as_type_ptr<ov::op::v11::Interpolate>(m_op)) {
port_mask = PortMask(Interpolate::SIZE_OR_SCALE_ID_V11, Interpolate::AXES_ID_V11);
return make_shape_inference(m_op);
} else {
OPENVINO_THROW("Shape infer factory cannot be created for ",
m_op->get_type_name(),
" node with name: ",
m_op->get_friendly_name(),
", only versions 4 and 11 are supported.");
}
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
}

private:
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/nodes/reference.cpp
@@ -4,7 +4,7 @@

#include "reference.h"
#include "common/cpu_memcpy.h"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/shape_inference.hpp"

namespace ov {
namespace intel_cpu {
@@ -14,7 +14,7 @@ class ReferenceShapeInferFactory : public ShapeInferFactory {
ReferenceShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op{std::move(op)} {}

ShapeInferPtr makeShapeInfer() const override {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), FULL_PORT_MASK);
return make_shape_inference(m_op, FULL_PORT_MASK);
}

private:
42 changes: 27 additions & 15 deletions src/plugins/intel_cpu/src/nodes/rnn.cpp
@@ -11,22 +11,22 @@
#include "nodes/input.h"
#include "nodes/reorder.h"
#include "openvino/core/parallel.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "transformations/utils/utils.hpp"

#include "ov_ops/augru_cell.hpp"
#include "ov_ops/augru_sequence.hpp"
#include "openvino/op/gru_cell.hpp"
#include "openvino/op/gru_sequence.hpp"
#include "openvino/op/lstm_sequence.hpp"
#include "openvino/op/rnn_cell.hpp"
#include "openvino/op/rnn_sequence.hpp"
#include "ov_ops/augru_cell.hpp"
#include "ov_ops/augru_sequence.hpp"
#include "shape_inference/shape_inference.hpp"
#include "transformations/utils/utils.hpp"

using namespace dnnl;


namespace ov {
namespace intel_cpu {

namespace node {

static rnn_direction ieDirection2dnnl(const std::shared_ptr<const ov::Node>& op) {
@@ -356,19 +356,17 @@ namespace {
 * dimensions permutation, necessary due to the mismatch between the ngraph and the oneDNN RNN node descriptions.
*
*/
class RnnShapeInfer : public NgraphShapeInfer {
class RnnShapeInfer : public IShapeInfer {
public:
RnnShapeInfer(std::shared_ptr<ov::Node> op) :
NgraphShapeInfer(make_shape_inference(op), EMPTY_PORT_MASK) {
is_sequence = !(RNN::isCell(op));

native_order = RNN::testNativeOrder(op);
}
RnnShapeInfer(std::shared_ptr<ov::Node> op)
: is_sequence(!(RNN::isCell(op))),
native_order(RNN::testNativeOrder(op)),
m_shape_infer(make_shape_inference(std::move(op))) {}

Result infer(
const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
auto result = NgraphShapeInfer::infer(input_shapes, data_dependency);
auto result = m_shape_infer->infer(input_shapes, data_dependency);
if (ShapeInferStatus::success != result.status) {
OPENVINO_THROW("Unexpected: Unexpected shape inference result status");
}
@@ -382,10 +380,24 @@ class RnnShapeInfer : public NgraphShapeInfer {
return {std::move(originOutputShapes), result.status};
}

const ov::CoordinateDiff& get_pads_begin() override {
return m_shape_infer->get_pads_begin();
}

const ov::CoordinateDiff& get_pads_end() override {
return m_shape_infer->get_pads_end();
}

port_mask_t get_port_mask() const override {
return m_shape_infer->get_port_mask();
}

private:
bool is_sequence = false;
bool native_order = true;
bool is_sequence;
bool native_order;
ShapeInferPtr m_shape_infer;
};

class RnnShapeInferFactory final : public ShapeInferFactory {
public:
RnnShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
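The `RnnShapeInfer` rewrite above shows the pattern for classes that previously derived from `NgraphShapeInfer`: implement `IShapeInfer` directly, hold the generic inference as a member, and forward the interface methods to it. A condensed sketch of that delegation (the post-processing of `result` is elided; this is a simplification, not code from the PR):

```cpp
// Condensed sketch of the delegation pattern used by RnnShapeInfer above.
class DelegatingShapeInfer : public IShapeInfer {
public:
    explicit DelegatingShapeInfer(std::shared_ptr<ov::Node> op)
        : m_shape_infer(make_shape_inference(std::move(op))) {}

    Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
                 const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
        auto result = m_shape_infer->infer(input_shapes, data_dependency);
        // ...post-process result here, as RnnShapeInfer does for the cell case...
        return result;
    }

    // Pads and port mask are forwarded to the wrapped inference unchanged.
    const ov::CoordinateDiff& get_pads_begin() override { return m_shape_infer->get_pads_begin(); }
    const ov::CoordinateDiff& get_pads_end() override { return m_shape_infer->get_pads_end(); }
    port_mask_t get_port_mask() const override { return m_shape_infer->get_port_mask(); }

private:
    ShapeInferPtr m_shape_infer;
};
```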
1 change: 0 additions & 1 deletion src/plugins/intel_cpu/src/nodes/strided_slice.cpp
@@ -8,7 +8,6 @@
#include "common/cpu_memcpy.h"
#include "input.h"
#include "openvino/opsets/opset1.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "slice_shape_inference_utils.hpp"
#include "shape_inference/custom/strided_slice.hpp"

11 changes: 6 additions & 5 deletions src/plugins/intel_cpu/src/shape_inference/custom/matmul.cpp
@@ -5,6 +5,7 @@
#include "matmul.hpp"
#include "utils.hpp"
#include "openvino/opsets/opset1.hpp"
#include "shape_inference/shape_inference.hpp"

namespace ov {
namespace intel_cpu {
@@ -64,17 +65,17 @@ Result MMShapeInfer::infer(

ShapeInferPtr MMShapeInferFactory::makeShapeInfer() const {
if (const auto matmul = ov::as_type_ptr<const ov::opset1::MatMul>(m_op)) {
const auto output_rank = matmul->get_output_partial_shape(0).rank().get_length();
const bool transpose_a = matmul->get_transpose_a();
const bool transpose_b = matmul->get_transpose_b();
const auto input_rank0 = matmul->get_input_partial_shape(0).rank().get_length();
const auto input_rank1 = matmul->get_input_partial_shape(1).rank().get_length();

if (input_rank0 == input_rank1) {
const auto output_rank = matmul->get_output_partial_shape(0).rank().get_length();
const bool transpose_a = matmul->get_transpose_a();
const bool transpose_b = matmul->get_transpose_b();
return std::make_shared<MMShapeInfer>(output_rank, transpose_a, transpose_b);
} else {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), EMPTY_PORT_MASK);
return make_shape_inference(m_op);
}

} else {
OPENVINO_THROW("Unexpected operation type in the MatMul shape inference factory");
}
3 changes: 1 addition & 2 deletions src/plugins/intel_cpu/src/shape_inference/custom/matmul.hpp
Original file line number Diff line number Diff line change
@@ -3,8 +3,8 @@
//

#include <node.h>

#include "shape_inference/shape_inference_cpu.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"

#pragma once
namespace ov {
@@ -42,4 +42,3 @@ class MMShapeInferFactory : public ShapeInferFactory {
} // namespace node
} // namespace intel_cpu
} // namespace ov

src/plugins/intel_cpu/src/shape_inference/custom/scaled_attn.cpp
@@ -4,8 +4,7 @@

#include "scaled_attn.hpp"

#include "shape_inference/shape_inference_cpu.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/shape_inference.hpp"
#include "transformations/cpu_opset/common/op/sdpa.hpp"
#include "utils.hpp"

Expand Down Expand Up @@ -78,7 +77,7 @@ ShapeInferPtr SDPAShapeInferFactory::makeShapeInfer() const {
return std::make_shared<SDPAShapeInfer>(config);
}
// fallback to ngraph shape infer on non-perf-critical case
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), EMPTY_PORT_MASK);
return make_shape_inference(m_op);
}

} // namespace node
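The SDPA factory above, like the MatMul and StridedSlice factories in this commit, keeps a hand-written shape inference for the performance-critical case and now falls back to `make_shape_inference()` directly instead of wrapping it in `NgraphShapeInfer`. A minimal sketch of that dispatch (`supports_fast_path` and `CustomShapeInfer` are placeholders, not code from this PR):

```cpp
// Hypothetical factory: custom inference on the hot path,
// generic make_shape_inference() as the fallback.
class ExampleFallbackFactory : public ShapeInferFactory {
public:
    ExampleFallbackFactory(std::shared_ptr<ov::Node> op) : m_op(std::move(op)) {}

    ShapeInferPtr makeShapeInfer() const override {
        if (supports_fast_path(m_op)) {                       // placeholder predicate
            return std::make_shared<CustomShapeInfer>(m_op);  // placeholder type
        }
        // Non-performance-critical case: generic shape inference.
        return make_shape_inference(m_op);
    }

private:
    std::shared_ptr<ov::Node> m_op;
};
```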
src/plugins/intel_cpu/src/shape_inference/custom/strided_slice.cpp
@@ -5,7 +5,7 @@
#include "strided_slice.hpp"
#include "utils.hpp"
#include "slice_shape_inference.hpp"
#include "shape_inference/shape_inference_ngraph.hpp"
#include "shape_inference/shape_inference.hpp"

namespace ov {
namespace intel_cpu {
@@ -75,13 +75,13 @@ Result StridedSliceShapeInfer::infer(

ShapeInferPtr StridedSliceShapeInferFactory::makeShapeInfer() const {
if (const auto Slice_op = ov::as_type_ptr<const ov::op::v8::Slice>(m_op)) {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
return make_shape_inference(m_op);
} else if (const auto SliceScatter_op = ov::as_type_ptr<const ov::op::v15::SliceScatter>(m_op)) {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), PortMask(2, 3, 4, 5));
return make_shape_inference(m_op);
} else if (const auto StridedSlice_op = ov::as_type_ptr<const ov::op::v1::StridedSlice>(m_op)) {
const auto& ellipsis_mask = StridedSlice_op->get_ellipsis_mask();
if (std::any_of(ellipsis_mask.begin(), ellipsis_mask.end(), [](int64_t x){ return x == 1; })) {
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
return make_shape_inference(m_op);
} else {
auto vec_to_set = [](const std::vector<int64_t>& vec){
std::unordered_set<int64_t> to_set;