From 1ae526b6c53a3a70d9ee9e9bd636afd4e3fca584 Mon Sep 17 00:00:00 2001 From: Daniel <43151183+dannys4@users.noreply.github.com> Date: Tue, 2 Apr 2024 07:02:46 -0400 Subject: [PATCH] Fix issue 286 (#403) * Fix issue 286 * Add binding change --- bindings/python/src/MultiIndex.cpp | 32 +++++++++++++++++++++ bindings/python/tests/test_MultiIndexSet.py | 8 ++++++ 2 files changed, 40 insertions(+) diff --git a/bindings/python/src/MultiIndex.cpp b/bindings/python/src/MultiIndex.cpp index d97c9614..2e01bcd0 100644 --- a/bindings/python/src/MultiIndex.cpp +++ b/bindings/python/src/MultiIndex.cpp @@ -24,6 +24,37 @@ namespace py = pybind11; using namespace mpart::binding; +template +using Matrix_Map_T = Eigen::Map, 0, Eigen::Stride>; + +mpart::MultiIndexSet MultiIndexSet_PyBuffer(py::buffer x){ + constexpr bool rowMajor = Eigen::Matrix::Flags & Eigen::RowMajorBit; + + py::buffer_info info = x.request(); + + // Check for int32, int64 + bool is_int32 = info.format == py::format_descriptor::format(); + bool is_int64 = info.format == "l"; // This is based on a pybind bug; numpy int64 buffer is l, not q + if (!(is_int32 || is_int64)) + throw std::runtime_error("Incompatible format: expected an array of either int32 or int64!"); + + if (info.ndim != 2) + throw std::runtime_error("Expected array with ndims = 2"); + + int stride_size = is_int32 ? sizeof(int32_t) : sizeof(int64_t); + Eigen::Stride strides( + info.strides[rowMajor ? 0 : 1] / (py::ssize_t)stride_size, + info.strides[rowMajor ? 1 : 0] / (py::ssize_t)stride_size + ); + + if(is_int64) { // Is int64 + Matrix_Map_T map_64 (static_cast(info.ptr), info.shape[0], info.shape[1], strides); + return mpart::MultiIndexSet {map_64.cast()}; + } else { // Is int32 + Matrix_Map_T map (static_cast(info.ptr), info.shape[0], info.shape[1], strides); + return mpart::MultiIndexSet {map}; + } +} void mpart::binding::MultiIndexWrapper(py::module &m) { @@ -111,6 +142,7 @@ void mpart::binding::MultiIndexWrapper(py::module &m) // MultiIndexSet py::class_>(m, "MultiIndexSet") .def(py::init()) + .def(py::init>(&MultiIndexSet_PyBuffer)) .def(py::init const&>()) .def("fix", &MultiIndexSet::Fix) .def("__len__", &MultiIndexSet::Length, "Retrieves the length of _each_ multiindex within this set (i.e. the dimension of the input)") diff --git a/bindings/python/tests/test_MultiIndexSet.py b/bindings/python/tests/test_MultiIndexSet.py index 9b70ec4f..85fa4f3d 100644 --- a/bindings/python/tests/test_MultiIndexSet.py +++ b/bindings/python/tests/test_MultiIndexSet.py @@ -10,6 +10,14 @@ msetTensorProduct = mpart.MultiIndexSet.CreateTensorProduct(3,4,noneLim) msetTotalOrder = mpart.MultiIndexSet.CreateTotalOrder(3,4,noneLim) +def test_create(): + mset_one = mpart.MultiIndexSet([[2]]) + assert mset_one.Size() == 1 + assert len(mset_one) == 1 + mset_one = mpart.MultiIndexSet(np.array([[2]])) + assert mset_one.Size() == 1 + assert len(mset_one) == 1 + def test_max_degrees(): assert np.all(msetFromArray.MaxOrders() == [2,1])