Skip to content

Commit

Permalink
Fix issue 286 (MeasureTransport#403)
Browse files Browse the repository at this point in the history
* Fix issue 286

* Add binding change
  • Loading branch information
dannys4 authored Apr 2, 2024
1 parent 66dbc3d commit 1ae526b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
32 changes: 32 additions & 0 deletions bindings/python/src/MultiIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,37 @@

namespace py = pybind11;
using namespace mpart::binding;
template<typename Scalar_T>
using Matrix_Map_T = Eigen::Map<Eigen::Matrix<Scalar_T,Eigen::Dynamic,Eigen::Dynamic>, 0, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>>;

mpart::MultiIndexSet MultiIndexSet_PyBuffer(py::buffer x){
constexpr bool rowMajor = Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic>::Flags & Eigen::RowMajorBit;

py::buffer_info info = x.request();

// Check for int32, int64
bool is_int32 = info.format == py::format_descriptor<int32_t>::format();
bool is_int64 = info.format == "l"; // This is based on a pybind bug; numpy int64 buffer is l, not q
if (!(is_int32 || is_int64))
throw std::runtime_error("Incompatible format: expected an array of either int32 or int64!");

if (info.ndim != 2)
throw std::runtime_error("Expected array with ndims = 2");

int stride_size = is_int32 ? sizeof(int32_t) : sizeof(int64_t);
Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> strides(
info.strides[rowMajor ? 0 : 1] / (py::ssize_t)stride_size,
info.strides[rowMajor ? 1 : 0] / (py::ssize_t)stride_size
);

if(is_int64) { // Is int64
Matrix_Map_T<int64_t> map_64 (static_cast<int64_t*>(info.ptr), info.shape[0], info.shape[1], strides);
return mpart::MultiIndexSet {map_64.cast<int32_t>()};
} else { // Is int32
Matrix_Map_T<int32_t> map (static_cast<int32_t*>(info.ptr), info.shape[0], info.shape[1], strides);
return mpart::MultiIndexSet {map};
}
}

void mpart::binding::MultiIndexWrapper(py::module &m)
{
Expand Down Expand Up @@ -111,6 +142,7 @@ void mpart::binding::MultiIndexWrapper(py::module &m)
// MultiIndexSet
py::class_<MultiIndexSet, std::shared_ptr<MultiIndexSet>>(m, "MultiIndexSet")
.def(py::init<const unsigned int>())
.def(py::init<std::function<MultiIndexSet(py::buffer)>>(&MultiIndexSet_PyBuffer))
.def(py::init<Eigen::Ref<const Eigen::MatrixXi> const&>())
.def("fix", &MultiIndexSet::Fix)
.def("__len__", &MultiIndexSet::Length, "Retrieves the length of _each_ multiindex within this set (i.e. the dimension of the input)")
Expand Down
8 changes: 8 additions & 0 deletions bindings/python/tests/test_MultiIndexSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@
msetTensorProduct = mpart.MultiIndexSet.CreateTensorProduct(3,4,noneLim)
msetTotalOrder = mpart.MultiIndexSet.CreateTotalOrder(3,4,noneLim)

def test_create():
mset_one = mpart.MultiIndexSet([[2]])
assert mset_one.Size() == 1
assert len(mset_one) == 1
mset_one = mpart.MultiIndexSet(np.array([[2]]))
assert mset_one.Size() == 1
assert len(mset_one) == 1

def test_max_degrees():

assert np.all(msetFromArray.MaxOrders() == [2,1])
Expand Down

0 comments on commit 1ae526b

Please sign in to comment.