diff --git a/MParT/RectifiedMultivariateExpansion.h b/MParT/RectifiedMultivariateExpansion.h index 5e41e819..a471fcea 100644 --- a/MParT/RectifiedMultivariateExpansion.h +++ b/MParT/RectifiedMultivariateExpansion.h @@ -42,10 +42,9 @@ namespace mpart{ >; - RectifiedMultivariateExpansion(OffdiagWorker_T const& worker_off_, + RectifiedMultivariateExpansion(OffdiagWorker_T const& unused_worker_, Worker_T const& worker_diag_): ConditionalMapBase(worker_diag_.InputSize(), 1, worker_diag_.NumCoeffs()), - setSize(worker_diag_.NumCoeffs()), worker(worker_diag_) { //throw std::invalid_argument( "calling old constructor" ); @@ -53,7 +52,6 @@ namespace mpart{ RectifiedMultivariateExpansion(Worker_T const& worker_): ConditionalMapBase(worker_.InputSize(), 1, worker_.NumCoeffs()), - setSize(worker_.NumCoeffs()), worker(worker_) {}; @@ -67,7 +65,7 @@ namespace mpart{ // Take first dim-1 dimensions of pts and evaluate expansion_off // Add that to the evaluation of expansion_diag on pts StridedVector output_slice = Kokkos::subview(output, 0, Kokkos::ALL()); - StridedVector coeff = Coeff(); + StridedVector coeff = this->savedCoeffs; const unsigned int numPts = pts.extent(1); @@ -113,7 +111,7 @@ namespace mpart{ StridedVector sens_slice = Kokkos::subview(sens, 0, Kokkos::ALL()); unsigned int inDim = pts.extent(0); - StridedVector coeff = Coeff(); + StridedVector coeff = this->savedCoeffs; const unsigned int numPts = pts.extent(1); @@ -137,7 +135,7 @@ namespace mpart{ // Evaluate the expansion - // Fill in the entries in the cache dependent on x_d + // Fill in the entries in the cache worker.FillCache1(cache.data(), pt, DerivativeFlags::Input); worker.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Input); @@ -187,7 +185,7 @@ namespace mpart{ Kokkos::View cache(team_member.thread_scratch(1), cacheSize); auto grad = Kokkos::subview(output, Kokkos::ALL(), ptInd); - // Fill in entries in the cache that are dependent on x_d. + // Fill in entries in the cache worker.FillCache1(cache.data(), pt, DerivativeFlags::Parameters); worker.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Parameters); @@ -255,7 +253,7 @@ namespace mpart{ StridedVector out_slice = Kokkos::subview(output, 0, Kokkos::ALL()); StridedVector r_slice = Kokkos::subview(r, 0, Kokkos::ALL()); - StridedVector coeff = Coeff(); + StridedVector coeff = this->savedCoeffs; const unsigned int numPts = x1.extent(1); @@ -311,7 +309,7 @@ namespace mpart{ { unsigned int numPts = pts.extent(1); - StridedVector coeff = Coeff(); + StridedVector coeff = this->savedCoeffs; unsigned int cacheSize = worker.CacheSize(); // Take logdet of diagonal expansion @@ -359,7 +357,7 @@ namespace mpart{ // Take logdetcoeffgrad of diagonal expansion, output to bottom block - StridedVector coeff = Coeff(); + StridedVector coeff = this->savedCoeffs; unsigned int numPts = pts.extent(1); unsigned int cacheSize = worker.CacheSize(); @@ -423,9 +421,6 @@ namespace mpart{ Worker_T worker; - const unsigned int setSize; - StridedVector Coeff() const { return this->savedCoeffs; } - }; // class RectifiedMultivariateExpansion } diff --git a/tests/Test_RectifiedMultivariateExpansion.cpp b/tests/Test_RectifiedMultivariateExpansion.cpp index 12179a7d..52059e94 100644 --- a/tests/Test_RectifiedMultivariateExpansion.cpp +++ b/tests/Test_RectifiedMultivariateExpansion.cpp @@ -15,22 +15,13 @@ TEST_CASE("RectifiedMultivariateExpansion, Unrectified", "[RMVE_NoRect]") { unsigned int dim = 3; unsigned int maxOrder = 2; using T = ProbabilistHermite; - // Create a rectified MVE equivalent to a simple Hermite expansion - // using OffdiagEval_T = BasisEvaluator; - // using DiagEval_T = BasisEvaluator, Identity>; + using Eval_T = BasisEvaluator, Identity>; using RectExpansion_T = RectifiedMultivariateExpansion; - // BasisEvaluator basis_eval_offdiag; - // BasisEvaluator, Identity> basis_eval_diag{dim}; + Eval_T basis_eval{dim}; - //FixedMultiIndexSet fmset_offdiag(dim-1, maxOrder); - // auto limiter = MultiIndexLimiter::NonzeroDiag(); - // FixedMultiIndexSet fmset_diag = MultiIndexSet::CreateTotalOrder(dim, maxOrder, limiter).Fix(true); - // MultivariateExpansionWorker worker_off(fmset_offdiag, basis_eval_offdiag); - // MultivariateExpansionWorker worker_diag(fmset_diag, basis_eval_diag); - FixedMultiIndexSet fmset(dim, maxOrder); MultivariateExpansionWorker worker(fmset, basis_eval);