Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Bindings for RMVE #8

Merged
merged 3 commits into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bindings/matlab/include/MexWrapperTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class ConditionalMapMex { // The class
map_ptr = MapFactory::CreateSigmoidComponent<MemorySpace>(inputDim,totalOrder,centers,opts);
}

ConditionalMapMex(FixedMultiIndexSet<MemorySpace> mset_offdiag, FixedMultiIndexSet<MemorySpace> mset_diag, StridedVector<const double, MemorySpace> centers, MapOptions opts){
map_ptr = MapFactory::CreateSigmoidComponent<MemorySpace>(mset_offdiag,mset_diag,centers,opts);
ConditionalMapMex(FixedMultiIndexSet<MemorySpace> mset, StridedVector<const double, MemorySpace> centers, MapOptions opts){
map_ptr = MapFactory::CreateSigmoidComponent<MemorySpace>(mset,centers,opts);
}

ConditionalMapMex(unsigned int inputDim, unsigned int outputDim, unsigned int totalOrder, StridedMatrix<const double, MemorySpace> centers, MapOptions opts){
Expand Down
21 changes: 14 additions & 7 deletions bindings/matlab/mat/ConditionalMap.m
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
methods

function this = ConditionalMap(varargin)

if(nargin==2)
if(isstring(varargin{2}))
if(varargin{2}=="id")
Expand Down Expand Up @@ -60,7 +59,19 @@
end
elseif(nargin==3)
if(isstring(varargin{3}) && varargin{3}=="Ab")
this.id_=MParT_('ConditionalMap_newAffineMapAb', varargin{1},varargin{2});
this.id_=MParT_('ConditionalMap_newAffineMapAb',varargin{1},varargin{2});
elseif(isa(varargin{1},'FixedMultiIndexSet'))
mset = varargin{1};
centers = varargin{2};
mexOptions = varargin{3}.getMexOptions;
input_str = ['MParT_(',char(39),'ConditionalMap_newSigmoidCompFromMset',char(39)];
input_str = [input_str,',mset.get_id(),centers'];
for o=1:length(mexOptions)
input_o=[',mexOptions{',num2str(o),'}'];
input_str=[input_str,input_o];
end
input_str=[input_str,')'];
this.id_ = eval(input_str);
else
error("Wrong input arguments");
end
Expand All @@ -71,11 +82,7 @@
opts = varargin{4};

mexOptions = opts.getMexOptions;
if isa(inputDim, 'FixedMultiIndexSet') % If the first arguments are multi-index sets, we call CreateSigmoidComponent from msets
fcn_name = 'SigmoidCompFromMsets'; % arguments (mset_offdiag, mset_diag, centers, opts)
inputDim = inputDim.get_id(); % Need to get the IDs, these are multi-index sets
outputDim = outputDim.get_id();
elseif numel(totalOrder)==1 % if totalOrder is a scalar, this is calling CreateTriangular
if numel(totalOrder)==1 % if totalOrder is a scalar, this is calling CreateTriangular
fcn_name = 'TotalTriMap';
else % otherwise, we call CreateSigmoidComponent, args (inputDim, totalOrder, centers, opts)
fcn_name = 'SigmoidComp';
Expand Down
15 changes: 13 additions & 2 deletions bindings/matlab/mat/CreateSigmoidComponent.m
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
function map = CreateSigmoidComponent(inputDim, totalOrder, centers, options)
map = ConditionalMap(inputDim, totalOrder, centers, options);
function map = CreateSigmoidComponent(varargin)
if nargin == 3
mset = varargin{1};
centers = varargin{2};
options = varargin{3};
map = ConditionalMap(mset, centers, options);
elseif nargin == 4
inputDim = varargin{1};
totalOrder = varargin{2};
centers = varargin{3};
options = varargin{4};
map = ConditionalMap(inputDim, totalOrder, centers, options);
end
end
13 changes: 6 additions & 7 deletions bindings/matlab/src/ConditionalMap_mex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,14 @@ MEX_DEFINE(ConditionalMap_newSigmoidComp) (int nlhs, mxArray* plhs[],
output.set(0, Session<ConditionalMapMex>::create(new ConditionalMapMex(inputDim,totalOrder,centers,opts)));
}

MEX_DEFINE(ConditionalMap_newSigmoidCompFromMsets) (int nlhs, mxArray* plhs[],
MEX_DEFINE(ConditionalMap_newSigmoidCompFromMset) (int nlhs, mxArray* plhs[],
int nrhs, const mxArray* prhs[]) {
InputArguments input(nrhs, prhs, 3 + MPART_MEX_MAPOPTIONS_ARGCOUNT);
InputArguments input(nrhs, prhs, 2 + MPART_MEX_MAPOPTIONS_ARGCOUNT);
OutputArguments output(nlhs, plhs, 1);
const FixedMultiIndexSet<MemorySpace>& mset_offdiag = Session<FixedMultiIndexSet<MemorySpace>>::getConst(input.get(0));
const FixedMultiIndexSet<MemorySpace>& mset_diag = Session<FixedMultiIndexSet<MemorySpace>>::getConst(input.get(1));
StridedVector<const double, Kokkos::HostSpace> centers = MexToKokkos1d(prhs[2]);
MapOptions opts = binding::MapOptionsFromMatlab(input, 3);
output.set(0, Session<ConditionalMapMex>::create(new ConditionalMapMex(mset_offdiag, mset_diag, centers, opts)));
const FixedMultiIndexSet<MemorySpace>& mset = Session<FixedMultiIndexSet<MemorySpace>>::getConst(input.get(0));
StridedVector<const double, Kokkos::HostSpace> centers = MexToKokkos1d(prhs[1]);
MapOptions opts = binding::MapOptionsFromMatlab(input, 2);
output.set(0, Session<ConditionalMapMex>::create(new ConditionalMapMex(mset, centers, opts)));
}

MEX_DEFINE(ConditionalMap_newMap) (int nlhs, mxArray* plhs[],
Expand Down
1 change: 1 addition & 0 deletions bindings/matlab/tests/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.xml
7 changes: 3 additions & 4 deletions bindings/matlab/tests/SigmoidTest.m
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ function SigmoidComponentMset( testCase )
max_order = num_sigmoids_order+2;
opts = MapOptions;
opts.basisType = BasisTypes.HermiteFunctions;
mset_off = FixedMultiIndexSet(input_dim-1, max_order);
mset_diag = MultiIndexSet.CreateNonzeroDiagTotalOrder(input_dim, max_order).Fix();
comp = CreateSigmoidComponent(mset_off, mset_diag, centers, opts);
expected_coeffs = mset_off.Size() + mset_diag.Size();
mset = FixedMultiIndexSet(input_dim, max_order);
comp = CreateSigmoidComponent(mset, centers, opts);
expected_coeffs = mset.Size();
testCase.verifyEqual( comp.numCoeffs, uint32(expected_coeffs) );
end

Expand Down
7 changes: 3 additions & 4 deletions bindings/python/tests/test_MapFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,9 @@ def test_CreateSigmoidMaps():
sig = mpart.CreateSigmoidComponent(input_dim, max_degree, centers, opts)
expected_num_coeffs = math.comb(input_dim+max_degree, input_dim)
assert sig.numCoeffs == expected_num_coeffs
mset_diag = mpart.MultiIndexSet.CreateNonzeroDiagTotalOrder(input_dim, max_degree).fix(True)
mset_off = mpart.FixedMultiIndexSet(input_dim-1, max_degree)
sig_mset = mpart.CreateSigmoidComponent(mset_off, mset_diag, centers, opts)
assert sig_mset.numCoeffs == mset_diag.Size() + mset_off.Size()
mset = mpart.FixedMultiIndexSet(input_dim, max_degree)
sig_mset = mpart.CreateSigmoidComponent(mset, centers, opts)
assert sig_mset.numCoeffs == mset.Size()
output_dim = input_dim
centers_total = np.column_stack([centers for _ in range(output_dim)])
sig_trimap = mpart.CreateSigmoidTriangular(input_dim, output_dim, max_degree, centers_total, opts)
Expand Down
2 changes: 1 addition & 1 deletion src/MapFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ std::shared_ptr<ConditionalMapBase<MemorySpace>> CreateSigmoidExpansionTemplate(
return CreateSigmoidExpansionTemplate<MemorySpace, OffdiagEval, Rectifier, SigmoidType, EdgeType>(
mset, centers, edgeWidth);
}
MultiIndexSet mset = MultiIndexSet::CreateTotalOrder(inputDim, totalOrder, MultiIndexLimiter::NonzeroDiag());
MultiIndexSet mset = MultiIndexSet::CreateTotalOrder(inputDim, totalOrder);
FixedMultiIndexSet<MemorySpace> fmset_diag_d = mset.Fix(true).ToDevice<MemorySpace>();
return CreateSigmoidExpansionTemplate<MemorySpace, OffdiagEval, Rectifier, SigmoidType, EdgeType>(
fmset_diag_d, centers, edgeWidth);
Expand Down
Loading