Skip to content

Commit

Permalink
[PyOV] new methods for Constant (#27732)
Browse files Browse the repository at this point in the history
### Details:
 - Add `get_tensor_view` and `get_strides` methods to for Constant
 - Add `__eq__` overloadings for Strides
### Tickets:
 - *ticket-id*
  • Loading branch information
akuporos authored Nov 26, 2024
1 parent 17e17b2 commit 09927d8
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
32 changes: 32 additions & 0 deletions src/bindings/python/src/pyopenvino/graph/ops/constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,22 @@ void regclass_graph_op_Constant(py::module m) {
}
});

constant.def("get_tensor_view",
&ov::op::v0::Constant::get_tensor_view,
R"(
Get view on constant data as tensor.
:rtype: openvino.Tensor
)");

constant.def("get_strides",
&ov::op::v0::Constant::get_strides,
R"(
Constant's strides in bytes.
:rtype: openvino.Strides
)");

// TODO: Remove in future and re-use `get_data`
// Provide buffer access
constant.def_buffer([](const ov::op::v0::Constant& self) -> py::buffer_info {
Expand Down Expand Up @@ -236,6 +252,22 @@ void regclass_graph_op_Constant(py::module m) {
:rtype: numpy.array
)");

constant.def_property_readonly("tensor_view",
&ov::op::v0::Constant::get_tensor_view,
R"(
Get view on constant data as tensor.
:rtype: openvino.Tensor
)");

constant.def_property_readonly("strides",
&ov::op::v0::Constant::get_strides,
R"(
Constant's strides in bytes.
:rtype: openvino.Strides
)");

constant.def("__repr__", [](const ov::op::v0::Constant& self) {
std::stringstream shapes_ss;
for (size_t i = 0; i < self.get_output_size(); ++i) {
Expand Down
29 changes: 29 additions & 0 deletions src/bindings/python/src/pyopenvino/graph/strides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@

namespace py = pybind11;

template <typename T>
bool compare_strides(const ov::Strides& a, const T& b) {
return a.size() == b.size() &&
std::equal(a.begin(), a.end(), b.begin(), [](const size_t& elem_a, const py::handle& elem_b) {
return elem_a == elem_b.cast<size_t>();
});
}

void regclass_graph_Strides(py::module m) {
py::class_<ov::Strides, std::shared_ptr<ov::Strides>> strides(m, "Strides");
strides.doc() = "openvino.runtime.Strides wraps ov::Strides";
Expand Down Expand Up @@ -48,6 +56,27 @@ void regclass_graph_Strides(py::module m) {
return self.size();
});

strides.def(
"__eq__",
[](const ov::Strides& a, const ov::Strides& b) {
return a == b;
},
py::is_operator());

strides.def(
"__eq__",
[](const ov::Strides& a, const py::tuple& b) {
return compare_strides<py::tuple>(a, b);
},
py::is_operator());

strides.def(
"__eq__",
[](const ov::Strides& a, const py::list& b) {
return compare_strides<py::list>(a, b);
},
py::is_operator());

strides.def(
"__iter__",
[](const ov::Strides& self) {
Expand Down
7 changes: 5 additions & 2 deletions src/bindings/python/tests/test_graph/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import openvino as ov

import openvino.runtime.opset13 as ops
from openvino import Type, PartialShape, Model, Tensor, compile_model
from openvino import Type, PartialShape, Model, Strides, Tensor, compile_model
from openvino.runtime.op import Constant
from openvino.helpers import pack_data, unpack_data

Expand Down Expand Up @@ -756,7 +756,7 @@ def test_get_data_casting_packed(src_dtype, ov_type, dst_dtype, copy_flag):
],
)
def test_const_from_tensor(shared_flag):
shape = [1, 3, 32, 32]
shape = [1, 2, 3, 3]
arr = np.ones(shape).astype(np.float32)
ov_tensor = Tensor(arr, shape, Type.f32)
ov_const = ops.constant(tensor=ov_tensor, shared_memory=shared_flag)
Expand All @@ -771,3 +771,6 @@ def test_const_from_tensor(shared_flag):
else:
assert not np.array_equal(ov_const.data, arr)
assert not np.shares_memory(arr, ov_const.data)

assert ov_const.strides == [72, 36, 12, 4]
assert ov_const.get_tensor_view().get_strides() == Strides([72, 36, 12, 4])

0 comments on commit 09927d8

Please sign in to comment.