Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/release'
Browse files Browse the repository at this point in the history
  • Loading branch information
dannys4 committed Apr 2, 2024
2 parents de521c1 + 1ae526b commit df06632
Show file tree
Hide file tree
Showing 14 changed files with 197 additions and 127 deletions.
1 change: 1 addition & 0 deletions .docker/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ dependencies:
- gcc
- gxx
- make
- dill
2 changes: 1 addition & 1 deletion .github/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ dependencies:
- nlopt >= 2.7
- pytorch
- cxx-compiler

- dill
1 change: 1 addition & 0 deletions .github/workflows/build-bindings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ name: binding-tests
on:
push:
branches:
- release
- main
pull_request: {}

Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release

jobs:
build-docs:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-external-lib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release
pull_request: {}

env:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-push-docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release

jobs:
docker:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release
pull_request: {}

jobs:
Expand Down
121 changes: 5 additions & 116 deletions bindings/python/package/torch.py
Original file line number Diff line number Diff line change
@@ -1,124 +1,13 @@
import torch

def ExtractTorchTensorData(tensor):
""" Extracts the pointer, shape, and stride from a pytorch tensor and returns a tuple
that can be passed to MParT functions that have been overloaded to accept
(double*, std::tuple<int,int>, std::tuple<int,int>) instead of a Kokkos::View.
Arguments:
------------
tensor: pytorch.Tensor
The pytorch tensor we want to eventually wrap with a Kokkos view.
from .torch_helpers import ExtractTorchTensorData, MpartTorchAutograd

Returns:
------------
Tuple[int, Tuple[int,int], Tuple[int,int]]
A python tuple that contains all information needed to construct a Kokkos::View.
After casting to c++ types using pybind, this output can be passed to the
mpart::ConstructViewFromPointer function.
"""

# Make sure the tensor has double data type
if tensor.dtype != torch.float64:
raise ValueError(f'Currently only tensors with float64 datatype can be converted. Current dtype is {tensor.dtype}')

if len(tensor.shape)==1:
return tensor.data_ptr(), tensor.shape[0], tensor.stride()[0]
elif len(tensor.shape)==2:
return tensor.data_ptr(), tuple(tensor.shape), tuple(tensor.stride())
else:
raise ValueError(f'Currently only 1d and 2d tensors can be converted.')


class MpartTorchAutograd(torch.autograd.Function):

@staticmethod
def forward(ctx, input, coeffs, f, return_logdet):
ctx.save_for_backward(input, coeffs)
ctx.f = f

coeffs_dbl = None
if coeffs is not None:
coeffs_dbl = coeffs.double()
f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
input_dbl = input.double()

output = torch.zeros(f.outputDim, input.shape[1], dtype=torch.double)
f.EvaluateImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(output))

if return_logdet:
logdet = torch.zeros(input.shape[1], dtype=torch.double)
f.LogDeterminantImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(logdet))
return output.type(input.dtype), logdet.type(input.dtype)
else:
return output.type(input.dtype)

@staticmethod
def backward(ctx, output_sens, logdet_sens=None):
input, coeffs = ctx.saved_tensors
f = ctx.f

coeffs_dbl = None
if coeffs is not None:
coeffs_dbl = coeffs.double()
f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
input_dbl = input.double()
output_sens_dbl = output_sens.double()

logdet_sens_dbl = None
if logdet_sens is not None:
logdet_sens_dbl = logdet_sens.double()

# Get the gradient wrt input
grad = None
if input.requires_grad:
grad = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

f.GradientImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(output_sens_dbl),
ExtractTorchTensorData(grad))

if logdet_sens is not None:
grad2 = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

f.LogDeterminantInputGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(grad2))
grad += grad2*logdet_sens_dbl[None,:]

coeff_grad = None
if coeffs is not None:
if coeffs.requires_grad:
coeff_grad = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)
f.CoeffGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(output_sens_dbl),
ExtractTorchTensorData(coeff_grad))

coeff_grad = coeff_grad.sum(axis=1) # pytorch expects total gradient not per-sample gradient

if logdet_sens is not None:
grad2 = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)

f.LogDeterminantCoeffGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(grad2))

coeff_grad += torch.sum(grad2*logdet_sens[None,:],axis=1)

if coeff_grad is not None:
coeff_grad = coeff_grad.type(input.dtype)

if grad is not None:
grad = grad.type(input.dtype)

return grad, coeff_grad, None, None



class TorchParameterizedFunctionBase(torch.nn.Module):
""" Defines a wrapper around the MParT ParameterizedFunctionBase class that
can be used with pytorch.
"""

def __init__(self, f, store_coeffs=True, dtype=torch.double):
def __init__(self, f=None, store_coeffs=True, dtype=torch.double):
super().__init__()

self.f = f
Expand All @@ -129,7 +18,7 @@ def __init__(self, f, store_coeffs=True, dtype=torch.double):
self.coeffs = torch.nn.Parameter(coeff_tensor)
else:
self.coeffs = None

def forward(self, x, coeffs=None):

if coeffs is None:
Expand All @@ -148,7 +37,7 @@ class TorchConditionalMapBase(torch.nn.Module):
This can be done either in the constructor or afterwards.
"""

def __init__(self, f, store_coeffs=True, return_logdet=False, dtype=torch.double):
def __init__(self, f=None, store_coeffs=True, return_logdet=False, dtype=torch.double):
super().__init__()

self.return_logdet = return_logdet
Expand All @@ -159,7 +48,7 @@ def __init__(self, f, store_coeffs=True, return_logdet=False, dtype=torch.double
self.coeffs = torch.nn.Parameter(coeff_tensor)
else:
self.coeffs = None

def forward(self, x, coeffs=None):

if coeffs is None:
Expand Down
115 changes: 115 additions & 0 deletions bindings/python/package/torch_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import torch

def ExtractTorchTensorData(tensor):
    """ Packs the data pointer, shape, and stride of a float64 pytorch tensor into a tuple
        that can be passed to MParT functions that have been overloaded to accept
        (double*, std::tuple<int,int>, std::tuple<int,int>) instead of a Kokkos::View.

        Arguments:
        ------------
        tensor: pytorch.Tensor
            The 1d or 2d float64 tensor we want to eventually wrap with a Kokkos view.

        Returns:
        ------------
        Tuple[int, int, int] for a 1d tensor, or
        Tuple[int, Tuple[int,int], Tuple[int,int]] for a 2d tensor
            The address returned by tensor.data_ptr() followed by the shape and the stride
            (scalars in the 1d case, tuples in the 2d case). No data is copied, so the
            tensor must stay alive for as long as the pointer is in use.

        Raises:
        ------------
        ValueError
            If the dtype is not torch.float64 or the tensor is neither 1d nor 2d.
    """

    # Only double precision data can be reinterpreted as double* on the c++ side
    if tensor.dtype != torch.float64:
        raise ValueError(f'Currently only tensors with float64 datatype can be converted. Current dtype is {tensor.dtype}')

    rank = tensor.dim()

    if rank == 1:
        return tensor.data_ptr(), tensor.shape[0], tensor.stride()[0]

    if rank == 2:
        return tensor.data_ptr(), tuple(tensor.shape), tuple(tensor.stride())

    raise ValueError('Currently only 1d and 2d tensors can be converted.')


class MpartTorchAutograd(torch.autograd.Function):
    """ Autograd function that evaluates an MParT object and hands its gradients to pytorch.

        All numerical work is delegated to the object `f` given to `forward`. Tensors are
        cast to float64 before being passed on through ExtractTorchTensorData, and results
        are cast back to the dtype of `input`.
    """

    def __reduce__(self):
        # Pickling hook: an instance is rebuilt by calling the class with a single None argument.
        return (self.__class__, (None,))

    @staticmethod
    def forward(ctx, input, coeffs, f, return_logdet):
        """ Evaluates `f` at `input` and, if requested, its log determinant.

            Arguments:
            ------------
            ctx:
                Autograd context. `input` and `coeffs` are saved on it and `f` is stored as ctx.f.
            input: torch.Tensor
                2d tensor; input.shape[1] is used as the number of columns of every result.
            coeffs: torch.Tensor or None
                If not None, cast to float64 and passed to f.WrapCoeffs before evaluation.
            f:
                Object providing outputDim, WrapCoeffs, EvaluateImpl and LogDeterminantImpl.
            return_logdet: bool
                Whether to also compute and return the log determinant.

            Returns:
            ------------
            torch.Tensor or Tuple[torch.Tensor, torch.Tensor]
                The (f.outputDim, input.shape[1]) output, followed by the (input.shape[1],)
                log determinant when return_logdet is true. Both have the dtype of `input`.
        """
        ctx.save_for_backward(input, coeffs)
        ctx.f = f

        # The double copy is held in a local for the rest of the call.
        # NOTE(review): presumably f keeps referring to this buffer after WrapCoeffs -- confirm.
        coeffs_dbl = None
        if coeffs is not None:
            coeffs_dbl = coeffs.double()
            f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
        input_dbl = input.double()

        output = torch.zeros(f.outputDim, input.shape[1], dtype=torch.double)
        f.EvaluateImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(output))

        if return_logdet:
            logdet = torch.zeros(input.shape[1], dtype=torch.double)
            f.LogDeterminantImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(logdet))
            return output.type(input.dtype), logdet.type(input.dtype)
        else:
            return output.type(input.dtype)

    @staticmethod
    def backward(ctx, output_sens, logdet_sens=None):
        """ Computes the gradients of the forward outputs wrt `input` and `coeffs`.

            Arguments:
            ------------
            ctx:
                Autograd context filled in by `forward`.
            output_sens: torch.Tensor
                Sensitivity of the loss wrt the evaluation output.
            logdet_sens: torch.Tensor or None
                Sensitivity of the loss wrt the log determinant. None when forward returned
                only the evaluation output.

            Returns:
            ------------
            Tuple
                (grad, coeff_grad, None, None), one entry per argument of `forward`. `grad`
                is None unless input.requires_grad; `coeff_grad` is None unless coeffs is
                given and coeffs.requires_grad. Non-None entries have the dtype of `input`.
        """
        input, coeffs = ctx.saved_tensors
        f = ctx.f

        coeffs_dbl = None
        if coeffs is not None:
            coeffs_dbl = coeffs.double()
            f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
        input_dbl = input.double()
        output_sens_dbl = output_sens.double()

        logdet_sens_dbl = None
        if logdet_sens is not None:
            logdet_sens_dbl = logdet_sens.double()

        # Get the gradient wrt input
        grad = None
        if input.requires_grad:
            grad = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

            f.GradientImpl(ExtractTorchTensorData(input_dbl),
                           ExtractTorchTensorData(output_sens_dbl),
                           ExtractTorchTensorData(grad))

            if logdet_sens is not None:
                grad2 = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

                f.LogDeterminantInputGradImpl(ExtractTorchTensorData(input_dbl),
                                              ExtractTorchTensorData(grad2))
                grad += grad2*logdet_sens_dbl[None,:]

        # Get the gradient wrt the coefficients
        coeff_grad = None
        if coeffs is not None:
            if coeffs.requires_grad:
                coeff_grad = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)
                f.CoeffGradImpl(ExtractTorchTensorData(input_dbl),
                                ExtractTorchTensorData(output_sens_dbl),
                                ExtractTorchTensorData(coeff_grad))

                coeff_grad = coeff_grad.sum(axis=1) # pytorch expects total gradient not per-sample gradient

                if logdet_sens is not None:
                    grad2 = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)

                    f.LogDeterminantCoeffGradImpl(ExtractTorchTensorData(input_dbl),
                                                  ExtractTorchTensorData(grad2))

                    # Use the float64 copy, as the input-gradient branch does, instead of
                    # relying on implicit type promotion of the raw sensitivity.
                    coeff_grad += torch.sum(grad2*logdet_sens_dbl[None,:],axis=1)

        if coeff_grad is not None:
            coeff_grad = coeff_grad.type(input.dtype)

        if grad is not None:
            grad = grad.type(input.dtype)

        # No gradients for the `f` and `return_logdet` arguments of forward
        return grad, coeff_grad, None, None
32 changes: 32 additions & 0 deletions bindings/python/src/MultiIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,37 @@

namespace py = pybind11;
using namespace mpart::binding;
template<typename Scalar_T>
using Matrix_Map_T = Eigen::Map<Eigen::Matrix<Scalar_T,Eigen::Dynamic,Eigen::Dynamic>, 0, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>>;

/**
 * @brief Construct a MultiIndexSet from a 2-D Python buffer of signed 32- or 64-bit integers.
 *
 * The buffer is viewed in place through an Eigen::Map using the strides reported by the
 * buffer protocol, so no particular memory layout is required of the caller.
 *
 * The element width is taken from the buffer's item size instead of from the format
 * character alone: the code "l" (C long) is 8 bytes on some platforms and 4 on others,
 * and 64-bit data may be reported as either "l" or "q".
 *
 * @param x Object exposing the Python buffer protocol (e.g. a numpy array).
 * @return A MultiIndexSet built from the buffer contents. 64-bit entries are narrowed to
 *         int32_t without a range check.
 * @throws std::runtime_error If the element type is not a 4- or 8-byte signed integer, or
 *         if the buffer is not two-dimensional.
 */
mpart::MultiIndexSet MultiIndexSet_PyBuffer(py::buffer x){
    constexpr bool rowMajor = (Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic>::Flags & Eigen::RowMajorBit) != 0;

    py::buffer_info info = x.request();

    // Accept the signed integer codes pybind11 uses for int32/int64 plus "l", which numpy
    // reports for C long; the actual width is then decided by the item size.
    const bool is_signed_int = info.format == py::format_descriptor<int32_t>::format()
                            || info.format == py::format_descriptor<int64_t>::format()
                            || info.format == "l";
    const bool is_int32 = is_signed_int && info.itemsize == static_cast<py::ssize_t>(sizeof(int32_t));
    const bool is_int64 = is_signed_int && info.itemsize == static_cast<py::ssize_t>(sizeof(int64_t));
    if (!(is_int32 || is_int64))
        throw std::runtime_error("Incompatible format: expected an array of either int32 or int64!");

    if (info.ndim != 2)
        throw std::runtime_error("Expected array with ndims = 2");

    // Buffer strides are in bytes; Eigen expects them in elements, as (outer, inner).
    const py::ssize_t item_size = info.itemsize;
    Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> strides(
        info.strides[rowMajor ? 0 : 1] / item_size,
        info.strides[rowMajor ? 1 : 0] / item_size
    );

    if (is_int64) {
        Matrix_Map_T<int64_t> map_64(static_cast<int64_t*>(info.ptr), info.shape[0], info.shape[1], strides);
        return mpart::MultiIndexSet {map_64.cast<int32_t>()};
    }

    Matrix_Map_T<int32_t> map(static_cast<int32_t*>(info.ptr), info.shape[0], info.shape[1], strides);
    return mpart::MultiIndexSet {map};
}

void mpart::binding::MultiIndexWrapper(py::module &m)
{
Expand Down Expand Up @@ -112,6 +143,7 @@ void mpart::binding::MultiIndexWrapper(py::module &m)
// MultiIndexSet
py::class_<MultiIndexSet, std::shared_ptr<MultiIndexSet>>(m, "MultiIndexSet")
.def(py::init<const unsigned int>())
.def(py::init<std::function<MultiIndexSet(py::buffer)>>(&MultiIndexSet_PyBuffer))
.def(py::init<Eigen::Ref<const Eigen::MatrixXi> const&>())
.def("fix", &MultiIndexSet::Fix)
.def("__len__", &MultiIndexSet::Length, "Retrieves the length of _each_ multiindex within this set (i.e. the dimension of the input)")
Expand Down
8 changes: 8 additions & 0 deletions bindings/python/tests/test_MultiIndexSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
msetTensorProduct = mpart.MultiIndexSet.CreateTensorProduct(dim,power,noneLim)
msetTotalOrder = mpart.MultiIndexSet.CreateTotalOrder(dim,power,noneLim)

def test_create():
    """ A single one-dimensional multiindex can be built from a nested list or a numpy array. """
    for raw_indices in ([[2]], np.array([[2]])):
        mset = mpart.MultiIndexSet(raw_indices)
        assert mset.Size() == 1
        assert len(mset) == 1

def test_max_degrees():

assert np.all(msetFromArray.MaxOrders() == [2,1])
Expand Down
Loading

0 comments on commit df06632

Please sign in to comment.