diff --git a/MParT/MapFactory.h b/MParT/MapFactory.h index 162e1c39..73b90957 100644 --- a/MParT/MapFactory.h +++ b/MParT/MapFactory.h @@ -145,22 +145,22 @@ namespace mpart{ unsigned int inputDim, unsigned int totalOrder, Eigen::Ref centers, MapOptions opts) { StridedVector centersVec = ConstVecToKokkos(centers); - Kokkos::View centers_d = Kokkos::create_mirror_view_and_copy(MemorySpace(), centersVec); - return CreateSigmoidComponent(inputDim, totalOrder, centers_d, opts); + Kokkos::View centers = Kokkos::create_mirror_view_and_copy(MemorySpace(), centersVec); + return CreateSigmoidComponent(inputDim, totalOrder, centers, opts); } template std::shared_ptr> CreateSigmoidComponent( - FixedMultiIndexSet mset_diag, + FixedMultiIndexSet mset, StridedVector centers, MapOptions opts); template std::shared_ptr> CreateSigmoidComponent( - FixedMultiIndexSet mset_diag, + FixedMultiIndexSet mset, Eigen::Ref centers, MapOptions opts) { StridedVector centersVec = ConstVecToKokkos(centers); - Kokkos::View centers_d = Kokkos::create_mirror_view_and_copy(MemorySpace(), centersVec); - return CreateSigmoidComponent(mset_diag, centers_d, opts); + Kokkos::View centers = Kokkos::create_mirror_view_and_copy(MemorySpace(), centersVec); + return CreateSigmoidComponent(mset, centers, opts); } template diff --git a/bindings/julia/src/MapFactory.cpp b/bindings/julia/src/MapFactory.cpp index 4d6203c8..016900e2 100644 --- a/bindings/julia/src/MapFactory.cpp +++ b/bindings/julia/src/MapFactory.cpp @@ -19,9 +19,9 @@ void mpart::binding::MapFactoryWrapper(jlcxx::Module &mod) { }); // CreateSigmoidComponent - mod.method("CreateSigmoidComponent", [](FixedMultiIndexSet mset_diag, jlcxx::ArrayRef centers, MapOptions opts){ + mod.method("CreateSigmoidComponent", [](FixedMultiIndexSet mset, jlcxx::ArrayRef centers, MapOptions opts){ StridedVector centersVec = JuliaToKokkos(centers); - return MapFactory::CreateSigmoidComponent(mset_diag, centersVec, opts); + return MapFactory::CreateSigmoidComponent(mset, centersVec, opts); }); // CreateSigmoidTriangular diff --git a/src/MapFactory.cpp b/src/MapFactory.cpp index 23792f85..c648693d 100644 --- a/src/MapFactory.cpp +++ b/src/MapFactory.cpp @@ -232,13 +232,6 @@ std::shared_ptr> CreateSigmoidExpansionTemplate( StridedVector centers, double edgeWidth) { unsigned int inputDim = mset_diag.Length(); - // if(inputDim != 1 && inputDim != mset_offdiag.Length() + 1) { - // std::stringstream ss; - // ss << "Mismatched input dimensions for offdiag and diag multiindex sets\n" - // << "offdiag: " << mset_offdiag.Length() << "\n" - // << "diag: " << mset_diag.Length(); - // ProcAgnosticError(ss.str().c_str()); - // } using Sigmoid_T = Sigmoid1d; using Eval_T = BasisEvaluator, Rectifier>; auto sigmoid = CreateSigmoid(inputDim, centers, edgeWidth); diff --git a/tests/Test_MapFactory.cpp b/tests/Test_MapFactory.cpp index 6830a1f8..a0253af2 100644 --- a/tests/Test_MapFactory.cpp +++ b/tests/Test_MapFactory.cpp @@ -260,15 +260,14 @@ TEST_CASE( "Testing factory method for Sigmoid Component", "[MapFactorySigmoidCo } } SECTION("Create Sigmoid Component from fixed msets") { - // FixedMultiIndexSet mset_offdiag(inputDim-1, maxDegree); // Make some arbitrary limiter auto limiter = [maxDegree](MultiIndex const& index){ return index.Get(index.Length()-1) != 0 && index.Sum() == maxDegree; }; - FixedMultiIndexSet mset_diag = MultiIndexSet::CreateTotalOrder(inputDim, maxDegree, limiter).Fix(true); - std::shared_ptr> map = MapFactory::CreateSigmoidComponent(mset_diag, centers, options); + FixedMultiIndexSet mset= MultiIndexSet::CreateTotalOrder(inputDim, maxDegree, limiter).Fix(true); + std::shared_ptr> map = MapFactory::CreateSigmoidComponent(mset, centers, options); REQUIRE(map != nullptr); - REQUIRE(map->numCoeffs == mset_diag.Size()); + REQUIRE(map->numCoeffs == mset.Size()); } SECTION("Create Triangular Sigmoid Map From Components") { std::vector>> maps;