Skip to content

Commit

Permalink
accelerated MVEW
Browse files Browse the repository at this point in the history
  • Loading branch information
MMRROOO committed Mar 20, 2024
1 parent e6a1806 commit a66b885
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
7 changes: 4 additions & 3 deletions MParT/MultivariateExpansionWorker.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,13 +479,14 @@ class MultivariateExpansionWorker
lastVal = polyCache[startPos_(d) + multiSet_.nzOrders(end_idx)];
}
if constexpr (!std::is_same_v<Rectifier, Identity>) {
bool isRectified = d == dim_ - 1;
if(wrt == dim_ - 1) { // Diagonal deriv
termVal = Rectifier::Evaluate(termVal)*wrtDeriv; // if wrt != d, wrtDeriv = 0
} else if (wrt == -1) { // No deriv
termVal = (d == dim_ - 1) ? Rectifier::Evaluate(termVal)*lastVal : termVal*lastVal;
termVal = (isRectified) ? Rectifier::Evaluate(termVal)*lastVal : termVal*lastVal;
} else { // Offdiag deriv
if(d != dim_ - 1) {
termVal *= wrtDeriv;
if(!isRectified) {
termVal *= lastVal*wrtDeriv;
}
else{
termVal = Rectifier::Derivative(termVal*wrtVal)*termVal*wrtDeriv*lastVal;
Expand Down
3 changes: 1 addition & 2 deletions tests/Test_MultivariateExpansionWorker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ TEMPLATE_TEST_CASE( "Testing multivariate expansion worker", "[MultivariateExpan
expansion.FillCache2(&cache[0], pt, pt(dim-1), DerivativeFlags::None);

eval2 = expansion.Evaluate(&cache[0], coeffs);
printf("%f, %f, %d, %d\n", inGrad(wrt), (eval2-eval)/fdStep, dim, wrt);
REQUIRE_THAT(inGrad(wrt), Matchers::WithinAbs((eval2-eval)/fdStep, fdStep*10));
CHECK_THAT(inGrad(wrt), Matchers::WithinAbs((eval2-eval)/fdStep, fdStep*10));
pt(wrt) -= fdStep;
}
}
Expand Down

0 comments on commit a66b885

Please sign in to comment.