Skip to content

Commit

Permalink
working RMVE tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MMRROOO committed Apr 3, 2024
1 parent cd89b54 commit 77c02e9
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 102 deletions.
142 changes: 58 additions & 84 deletions MParT/RectifiedMultivariateExpansion.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,25 @@ namespace mpart{
MemorySpace
>;
using Worker_T = MultivariateExpansionWorker<
BasisEvaluator<BasisHomogeneity::Homogeneous, OffdiagEval>, //for now pass in full eval as OffDiagEval
MemorySpace
BasisEvaluator<BasisHomogeneity::OffdiagHomogeneous,
Kokkos::pair<OffdiagEval, DiagEval>, //OffDiag and DiagEval are the same
Rectifier>, MemorySpace
>;


RectifiedMultivariateExpansion(OffdiagWorker_T const& worker_off_,
DiagWorker_T const& worker_diag_):
ConditionalMapBase<MemorySpace>(worker_diag_.InputSize(), 1, worker_off_.NumCoeffs() + worker_diag_.NumCoeffs()),
setSize_off(worker_off_.NumCoeffs()),
setSize_diag(worker_diag_.NumCoeffs()),
setSize(0),
worker_off(worker_off_),
worker_diag(worker_diag_)
{throw std::invalid_argument( "calling old constructor" );};
Worker_T const& worker_diag_):
ConditionalMapBase<MemorySpace>(worker_diag_.InputSize(), 1, worker_diag_.NumCoeffs()),
setSize(worker_diag_.NumCoeffs()),
worker(worker_diag_)
{
//throw std::invalid_argument( "calling old constructor" );
};

RectifiedMultivariateExpansion(Worker_T const& worker_):
ConditionalMapBase<MemorySpace>(worker_.InputSize(), 1, worker_.NumCoeffs()),
setSize(worker_.NumCoeffs()),
worker(worker_),
setSize_off(0),
setSize_diag(0)
worker(worker_)
{};


Expand Down Expand Up @@ -134,7 +132,7 @@ namespace mpart{

// Get a pointer to the shared memory that Kokkos set up for this team
Kokkos::View<double*,MemorySpace> cache(team_member.thread_scratch(1), cacheSize);
Kokkos::View<double*,MemorySpace> grad(team_member.thread_scratch(1), inDim-1);
StridedVector<double, MemorySpace> grad = Kokkos::subview(output, Kokkos::ALL(), ptInd);


// Evaluate the expansion
Expand All @@ -144,10 +142,12 @@ namespace mpart{
worker.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Input);

// Evaluate the expansion
worker_diag.InputDerivative(cache.data(), coeff, grad);
worker.InputDerivative(cache.data(), coeff, grad);

for(unsigned int i=0; i<inDim; ++i) {
grad(i) *= sens_slice(ptInd);
}

grad(inDim - 1) = sens_slice(ptInd) * grad(inDim - 1);
}
};

Expand All @@ -165,17 +165,14 @@ namespace mpart{
StridedMatrix<const double, MemorySpace> const& sens,
StridedMatrix<double, MemorySpace> output) override
{
StridedVector<const double, MemorySpace> coeff_off = CoeffOff();
StridedVector<const double, MemorySpace> coeff_diag = CoeffDiag();
StridedVector<const double, MemorySpace> coeff = this->savedCoeffs;
StridedVector<const double, MemorySpace> sens_slice = Kokkos::subview(sens, 0, Kokkos::ALL());

const unsigned int numPts = pts.extent(1);

// Figure out how much memory we'll need in the cache
unsigned int cacheSize = std::max(worker_diag.CacheSize(), worker_off.CacheSize());
unsigned int maxParams = coeff_off.size() + coeff_diag.size();
Kokkos::pair<unsigned int, unsigned int> coeff_off_idx {0u,(unsigned int)coeff_off.size()};
Kokkos::pair<unsigned int, unsigned int> coeff_diag_idx {(unsigned int) coeff_off.size(), maxParams};
unsigned int cacheSize = worker.CacheSize();
unsigned int maxParams = coeff.size();

auto functor = KOKKOS_CLASS_LAMBDA (typename Kokkos::TeamPolicy<ExecutionSpace>::member_type team_member) {

Expand All @@ -185,26 +182,17 @@ namespace mpart{

// Create a subview containing only the current point
auto pt = Kokkos::subview(pts, Kokkos::ALL(), ptInd);
auto pt_off = Kokkos::subview(pt, std::pair<int,int>(0,pt.extent(0)-1));

// Get a pointer to the shared memory that Kokkos set up for this team
Kokkos::View<double*,MemorySpace> cache(team_member.thread_scratch(1), cacheSize);
auto grad_off = Kokkos::subview(output, coeff_off_idx, ptInd);
auto grad_diag = Kokkos::subview(output, coeff_diag_idx, ptInd);

// Fill in entries in the cache that are independent of x_d.
worker_off.FillCache1(cache.data(), pt_off, DerivativeFlags::Parameters);
worker_off.FillCache2(cache.data(), pt_off, pt_off(pt_off.size()-1), DerivativeFlags::Parameters);

// Evaluate the expansion
worker_off.CoeffDerivative(cache.data(), coeff_off, grad_off);
auto grad = Kokkos::subview(output, Kokkos::ALL(), ptInd);

// Fill in entries in the cache that are dependent on x_d.
worker_diag.FillCache1(cache.data(), pt, DerivativeFlags::Parameters);
worker_diag.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Parameters);
worker.FillCache1(cache.data(), pt, DerivativeFlags::Parameters);
worker.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Parameters);

// Evaluate the expansion
worker_diag.CoeffDerivative(cache.data(), coeff_diag, grad_diag);
worker.CoeffDerivative(cache.data(), coeff, grad);

// TODO: Move this into own kernel?
for(unsigned int i=0; i<maxParams; ++i)
Expand All @@ -225,8 +213,8 @@ namespace mpart{
{
// Take logdet of diagonal expansion
unsigned int numPts = pts.extent(1);
StridedVector<const double, MemorySpace> coeff_diag = CoeffDiag();
unsigned int cacheSize = worker_diag.CacheSize();
StridedVector<const double, MemorySpace> coeff = this->savedCoeffs;
unsigned int cacheSize = worker.CacheSize();

// Take logdet of diagonal expansion
auto functor = KOKKOS_LAMBDA (typename Kokkos::TeamPolicy<ExecutionSpace>::member_type team_member) {
Expand All @@ -242,10 +230,10 @@ namespace mpart{
Kokkos::View<double*,MemorySpace> cache(team_member.thread_scratch(1), cacheSize);

// Fill in entries in the cache that are independent of x_d. By passing DerivativeFlags::None, we are telling the expansion that no derivatives with wrt x_1,...x_{d-1} will be needed.
worker_diag.FillCache1(cache.data(), pt, DerivativeFlags::None);
worker_diag.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Diagonal);
worker.FillCache1(cache.data(), pt, DerivativeFlags::None);
worker.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Diagonal);
// Evaluate the expansion
output(ptInd) = Kokkos::log(worker_diag.DiagonalDerivative(cache.data(), coeff_diag, 1));
output(ptInd) = Kokkos::log(worker.DiagonalDerivative(cache.data(), coeff, 1));
}
};

Expand All @@ -267,13 +255,12 @@ namespace mpart{
StridedVector<double, MemorySpace> out_slice = Kokkos::subview(output, 0, Kokkos::ALL());
StridedVector<const double, MemorySpace> r_slice = Kokkos::subview(r, 0, Kokkos::ALL());

StridedVector<const double, MemorySpace> coeff_off = CoeffOff();
StridedVector<const double, MemorySpace> coeff_diag = CoeffDiag();
StridedVector<const double, MemorySpace> coeff = Coeff();

const unsigned int numPts = x1.extent(1);

// Figure out how much memory we'll need in the cache
unsigned int cacheSize = std::max(worker_diag.CacheSize(), worker_off.CacheSize());
unsigned int cacheSize = worker.CacheSize();

// Options for root finding
const double xtol = 1e-6, ytol = 1e-6;
Expand All @@ -299,15 +286,13 @@ namespace mpart{
Kokkos::View<double*,MemorySpace> cache(team_member.thread_scratch(1), cacheSize);

// Fill in entries in the cache that are independent of x_d.
worker_off.FillCache1(cache.data(), pt, DerivativeFlags::None);
worker_off.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::None);


// Note r = g(x) + f(x,y) --> y = f(x,.)^{-1}(r - g(x))
double yd = r_slice(ptInd) - worker_off.Evaluate(cache.data(), coeff_off);
double yd = r_slice(ptInd);

// Fill in entries in the cache that are independent on x_d.
worker_diag.FillCache1(cache.data(), pt, DerivativeFlags::None);
SingleWorkerEvaluator<decltype(pt), decltype(coeff_diag)> evaluator {cache.data(), pt, coeff_diag, worker_diag};
worker.FillCache1(cache.data(), pt, DerivativeFlags::None);
SingleWorkerEvaluator<decltype(pt), decltype(coeff)> evaluator {cache.data(), pt, coeff, worker};
out_slice(ptInd) = RootFinding::InverseSingleBracket<MemorySpace>(yd, evaluator, pt(pt.size()-1), xtol, ytol, info);
}
};
Expand All @@ -326,8 +311,8 @@ namespace mpart{
{

unsigned int numPts = pts.extent(1);
StridedVector<const double, MemorySpace> coeff_diag = CoeffDiag();
unsigned int cacheSize = worker_diag.CacheSize();
StridedVector<const double, MemorySpace> coeff = Coeff();
unsigned int cacheSize = worker.CacheSize();

// Take logdet of diagonal expansion
auto functor = KOKKOS_LAMBDA (typename Kokkos::TeamPolicy<ExecutionSpace>::member_type team_member) {
Expand All @@ -344,16 +329,16 @@ namespace mpart{
Kokkos::View<double*,MemorySpace> cache(team_member.thread_scratch(1), cacheSize);

// Fill in the mixed entries grad_{x,y}d_y
worker_diag.FillCache1(cache.data(), pt, DerivativeFlags::MixedInput);
worker_diag.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::MixedInput);
worker.FillCache1(cache.data(), pt, DerivativeFlags::MixedInput);
worker.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::MixedInput);

// Evaluate the expansion for mixed derivative
worker_diag.MixedInputDerivative(cache.data(), coeff_diag, out);
worker.MixedInputDerivative(cache.data(), coeff, out);

// Find diagonal derivative d_y
worker_diag.FillCache1(cache.data(), pt, DerivativeFlags::Diagonal);
worker_diag.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Diagonal);
double diag_deriv = worker_diag.DiagonalDerivative(cache.data(), coeff_diag, 1);
worker.FillCache1(cache.data(), pt, DerivativeFlags::Diagonal);
worker.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Diagonal);
double diag_deriv = worker.DiagonalDerivative(cache.data(), coeff, 1);

// grad_{x,y} log(d_y T(x,y)) = [grad_x d_y T(x,y), d_y^2 T(x,y)] / d_y T(x,y)
for(unsigned int ii=0; ii<out.size(); ++ii) out(ii) /= diag_deriv;
Expand All @@ -372,15 +357,11 @@ namespace mpart{
StridedMatrix<double, MemorySpace> output) override
{
// Take logdetcoeffgrad of diagonal expansion, output to bottom block
StridedMatrix<double, MemorySpace> output_off = Kokkos::subview(output,
std::make_pair(0u,worker_off.NumCoeffs()), Kokkos::ALL());
Kokkos::deep_copy(output_off, 0.0);
StridedMatrix<double, MemorySpace> output_diag = Kokkos::subview(output,
std::make_pair(worker_off.NumCoeffs(),worker_off.NumCoeffs()+worker_diag.NumCoeffs()),
Kokkos::ALL());
StridedVector<const double, MemorySpace> coeff_diag = CoeffDiag();


StridedVector<const double, MemorySpace> coeff = Coeff();
unsigned int numPts = pts.extent(1);
unsigned int cacheSize = worker_diag.CacheSize();
unsigned int cacheSize = worker.CacheSize();

// Take logdet of diagonal expansion
auto functor = KOKKOS_LAMBDA (typename Kokkos::TeamPolicy<ExecutionSpace>::member_type team_member) {
Expand All @@ -391,22 +372,22 @@ namespace mpart{

// Create a subview containing only the current point
auto pt = Kokkos::subview(pts, Kokkos::ALL(), ptInd);
auto out = Kokkos::subview(output_diag, Kokkos::ALL(), ptInd);
auto out = Kokkos::subview(output, Kokkos::ALL(), ptInd);

// Get a pointer to the shared memory that Kokkos set up for this team
Kokkos::View<double*,MemorySpace> cache(team_member.thread_scratch(1), cacheSize);

// Fill in cache with mixed entries grad_c d_y T(x,y; c)
worker_diag.FillCache1(cache.data(), pt, DerivativeFlags::MixedCoeff);
worker_diag.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::MixedCoeff);
worker.FillCache1(cache.data(), pt, DerivativeFlags::MixedCoeff);
worker.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::MixedCoeff);

// Evaluate the Mixed coeff derivatives
worker_diag.MixedCoeffDerivative(cache.data(), coeff_diag, 1, out);
worker.MixedCoeffDerivative(cache.data(), coeff, 1, out);

// Find diagonal derivative d_y
worker_diag.FillCache1(cache.data(), pt, DerivativeFlags::Diagonal);
worker_diag.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Diagonal);
double diag_deriv = worker_diag.DiagonalDerivative(cache.data(), coeff_diag, 1);
worker.FillCache1(cache.data(), pt, DerivativeFlags::Diagonal);
worker.FillCache2(cache.data(), pt, pt(pt.size()-1), DerivativeFlags::Diagonal);
double diag_deriv = worker.DiagonalDerivative(cache.data(), coeff, 1);
for(unsigned int ii=0; ii<out.size(); ++ii) out(ii) /= diag_deriv;
}
};
Expand All @@ -421,9 +402,7 @@ namespace mpart{

std::vector<unsigned int> DiagonalCoeffIndices() const
{
std::vector<unsigned int> diagIndices(setSize_diag);
std::iota(diagIndices.begin(), diagIndices.end(), setSize_off);
return diagIndices;
return worker.NonzeroDiagonalEntries();
}

private:
Expand All @@ -432,25 +411,20 @@ namespace mpart{
double* cache;
PointType pt;
CoeffType coeffs;
DiagWorker_T worker;
Worker_T worker;

SingleWorkerEvaluator(double* cache_, PointType pt_, CoeffType coeffs_, DiagWorker_T worker_):
SingleWorkerEvaluator(double* cache_, PointType pt_, CoeffType coeffs_, Worker_T worker_):
cache(cache_), pt(pt_), coeffs(coeffs_), worker(worker_) {}
double operator()(double x) {
worker.FillCache2(cache, pt, x, DerivativeFlags::None);
return worker.Evaluate(cache, coeffs);
}
};

OffdiagWorker_T worker_off;
DiagWorker_T worker_diag;

Worker_T worker;
const unsigned int setSize_off;
const unsigned int setSize_diag;
const unsigned int setSize;
StridedVector<const double, MemorySpace> CoeffOff() const { return Kokkos::subview(this->savedCoeffs, std::make_pair(0u, setSize_off)); }
StridedVector<const double, MemorySpace> CoeffDiag() const { return Kokkos::subview(this->savedCoeffs, std::make_pair(setSize_off, setSize_off+setSize_diag)); }
StridedVector<const double, MemorySpace> Coeff() const { return Kokkos::subview(this->savedCoeffs, std::make_pair(0u, setSize)); }
StridedVector<const double, MemorySpace> Coeff() const { return this->savedCoeffs; }

}; // class RectifiedMultivariateExpansion
}
Expand Down
2 changes: 1 addition & 1 deletion tests/Test_MapFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ TEST_CASE( "Testing factory method for Sigmoid Component", "[MapFactorySigmoidCo
FixedMultiIndexSet<MemorySpace> mset_diag = MultiIndexSet::CreateTotalOrder(inputDim, maxDegree, limiter).Fix(true);
std::shared_ptr<ConditionalMapBase<MemorySpace>> map = MapFactory::CreateSigmoidComponent<MemorySpace>(mset_offdiag, mset_diag, centers, options);
REQUIRE(map != nullptr);
REQUIRE(map->numCoeffs == mset_diag.Size()+mset_offdiag.Size());
REQUIRE(map->numCoeffs == mset_diag.Size());
}
SECTION("Create Triangular Sigmoid Map From Components") {
std::vector<std::shared_ptr<ConditionalMapBase<MemorySpace>>> maps;
Expand Down
Loading

0 comments on commit 77c02e9

Please sign in to comment.