Skip to content

Commit

Permalink
removed unnecessary code
Browse files Browse the repository at this point in the history
  • Loading branch information
MMRROOO committed Apr 3, 2024
1 parent 77c02e9 commit 5d65e99
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 24 deletions.
21 changes: 8 additions & 13 deletions MParT/RectifiedMultivariateExpansion.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,16 @@ namespace mpart{
>;


RectifiedMultivariateExpansion(OffdiagWorker_T const& worker_off_,
RectifiedMultivariateExpansion(OffdiagWorker_T const& unused_worker_,
Worker_T const& worker_diag_):
ConditionalMapBase<MemorySpace>(worker_diag_.InputSize(), 1, worker_diag_.NumCoeffs()),
setSize(worker_diag_.NumCoeffs()),
worker(worker_diag_)
{
//throw std::invalid_argument( "calling old constructor" );
};

RectifiedMultivariateExpansion(Worker_T const& worker_):
ConditionalMapBase<MemorySpace>(worker_.InputSize(), 1, worker_.NumCoeffs()),
setSize(worker_.NumCoeffs()),
worker(worker_)
{};

Expand All @@ -67,7 +65,7 @@ namespace mpart{
// Take first dim-1 dimensions of pts and evaluate expansion_off
// Add that to the evaluation of expansion_diag on pts
StridedVector<double, MemorySpace> output_slice = Kokkos::subview(output, 0, Kokkos::ALL());
StridedVector<const double, MemorySpace> coeff = Coeff();
StridedVector<const double, MemorySpace> coeff = this->savedCoeffs;

const unsigned int numPts = pts.extent(1);

Expand Down Expand Up @@ -113,7 +111,7 @@ namespace mpart{
StridedVector<const double, MemorySpace> sens_slice = Kokkos::subview(sens, 0, Kokkos::ALL());
unsigned int inDim = pts.extent(0);

StridedVector<const double, MemorySpace> coeff = Coeff();
StridedVector<const double, MemorySpace> coeff = this->savedCoeffs;

const unsigned int numPts = pts.extent(1);

Expand All @@ -137,7 +135,7 @@ namespace mpart{

// Evaluate the expansion

// Fill in the entries in the cache dependent on x_d
// Fill in the entries in the cache
worker.FillCache1(cache.data(), pt, DerivativeFlags::Input);
worker.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Input);

Expand Down Expand Up @@ -187,7 +185,7 @@ namespace mpart{
Kokkos::View<double*,MemorySpace> cache(team_member.thread_scratch(1), cacheSize);
auto grad = Kokkos::subview(output, Kokkos::ALL(), ptInd);

// Fill in entries in the cache that are dependent on x_d.
// Fill in entries in the cache
worker.FillCache1(cache.data(), pt, DerivativeFlags::Parameters);
worker.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Parameters);

Expand Down Expand Up @@ -255,7 +253,7 @@ namespace mpart{
StridedVector<double, MemorySpace> out_slice = Kokkos::subview(output, 0, Kokkos::ALL());
StridedVector<const double, MemorySpace> r_slice = Kokkos::subview(r, 0, Kokkos::ALL());

StridedVector<const double, MemorySpace> coeff = Coeff();
StridedVector<const double, MemorySpace> coeff = this->savedCoeffs;

const unsigned int numPts = x1.extent(1);

Expand Down Expand Up @@ -311,7 +309,7 @@ namespace mpart{
{

unsigned int numPts = pts.extent(1);
StridedVector<const double, MemorySpace> coeff = Coeff();
StridedVector<const double, MemorySpace> coeff = this->savedCoeffs;
unsigned int cacheSize = worker.CacheSize();

// Take logdet of diagonal expansion
Expand Down Expand Up @@ -359,7 +357,7 @@ namespace mpart{
// Take logdetcoeffgrad of diagonal expansion, output to bottom block


StridedVector<const double, MemorySpace> coeff = Coeff();
StridedVector<const double, MemorySpace> coeff = this->savedCoeffs;
unsigned int numPts = pts.extent(1);
unsigned int cacheSize = worker.CacheSize();

Expand Down Expand Up @@ -423,9 +421,6 @@ namespace mpart{


Worker_T worker;
const unsigned int setSize;
StridedVector<const double, MemorySpace> Coeff() const { return this->savedCoeffs; }

}; // class RectifiedMultivariateExpansion
}

Expand Down
13 changes: 2 additions & 11 deletions tests/Test_RectifiedMultivariateExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,13 @@ TEST_CASE("RectifiedMultivariateExpansion, Unrectified", "[RMVE_NoRect]") {
unsigned int dim = 3;
unsigned int maxOrder = 2;
using T = ProbabilistHermite;
// Create a rectified MVE equivalent to a simple Hermite expansion
// using OffdiagEval_T = BasisEvaluator<BasisHomogeneity::Homogeneous, T>;
// using DiagEval_T = BasisEvaluator<BasisHomogeneity::OffdiagHomogeneous, Kokkos::pair<T, T>, Identity>;

using Eval_T = BasisEvaluator<BasisHomogeneity::OffdiagHomogeneous, Kokkos::pair<T, T>, Identity>;

using RectExpansion_T = RectifiedMultivariateExpansion<MemorySpace, T, T, Identity>;
// BasisEvaluator<BasisHomogeneity::Homogeneous, T> basis_eval_offdiag;
// BasisEvaluator<BasisHomogeneity::OffdiagHomogeneous, Kokkos::pair<T, T>, Identity> basis_eval_diag{dim};

Eval_T basis_eval{dim};

//FixedMultiIndexSet<MemorySpace> fmset_offdiag(dim-1, maxOrder);
// auto limiter = MultiIndexLimiter::NonzeroDiag();
// FixedMultiIndexSet<MemorySpace> fmset_diag = MultiIndexSet::CreateTotalOrder(dim, maxOrder, limiter).Fix(true);
// MultivariateExpansionWorker<OffdiagEval_T, MemorySpace> worker_off(fmset_offdiag, basis_eval_offdiag);
// MultivariateExpansionWorker<DiagEval_T, MemorySpace> worker_diag(fmset_diag, basis_eval_diag);

FixedMultiIndexSet<MemorySpace> fmset(dim, maxOrder);
MultivariateExpansionWorker<Eval_T, MemorySpace> worker(fmset, basis_eval);

Expand Down

0 comments on commit 5d65e99

Please sign in to comment.