Skip to content

Commit

Permalink
Refactor GLU shape_infer
Browse files Browse the repository at this point in the history
  • Loading branch information
mitruska committed Nov 28, 2024
1 parent 5c70d69 commit 63f1e00
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions src/core/shape_inference/include/glu_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,22 @@ namespace ov {
namespace op {
namespace internal {
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> shape_infer(const GLU* op,
const std::vector<TShape>& input_shapes,
const ITensorAccessor& tensor_accessor = make_tensor_accessor()) {
std::vector<TRShape> shape_infer(const GLU* op, const std::vector<TShape>& input_shapes) {
const auto inputs_count = input_shapes.size();
NODE_SHAPE_INFER_CHECK(op, input_shapes, inputs_count == 1);

int64_t axis = op->get_axis();
std::vector<int64_t> split_lengths = {op->get_split_lengths(), -1};
std::vector<TShape> variadic_split_input_shapes = {input_shapes[0], TShape{}, TShape{2}};

std::unordered_map<size_t, ov::Tensor> const_data;
const_data.emplace(1, ov::Tensor(ov::element::i64, ov::Shape{}, &axis));
const_data.emplace(2, ov::Tensor(ov::element::i64, ov::Shape{split_lengths.size()}, split_lengths.data()));

return ov::op::v1::variadic_split::shape_infer(op,
variadic_split_input_shapes,
ov::make_tensor_accessor(const_data));
const ov::Shape split_len_size{split_lengths.size()};
const ov::Shape scalar{};
std::vector<TShape> variadic_split_input_shapes{input_shapes[0], scalar, split_len_size};

return {
ov::op::variadic_split::shape_infer(op, variadic_split_input_shapes, ov::make_tensor_accessor(const_data))[0]};
}
} // namespace internal
} // namespace op
Expand Down

0 comments on commit 63f1e00

Please sign in to comment.