Skip to content

Commit

Permalink
Rectified Multivariate Expansion Updated
Browse files Browse the repository at this point in the history
  • Loading branch information
MMRROOO committed Mar 20, 2024
1 parent c6628cd commit 6e852ff
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion MParT/MultivariateExpansion.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,4 +221,4 @@ namespace mpart{
}


#endif
#endif
15 changes: 10 additions & 5 deletions MParT/MultivariateExpansionWorker.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ class MultivariateExpansionWorker
if(multiSet_.nzDims(end_idx)==dim_-1){
termVal = Rectifier::Evaluate(termVal)*lastVal;
} else {
termVal = Rectifier::Evaluate(termVal*lastVal);
termVal = termVal*lastVal;
}
} else {
termVal *= lastVal;
Expand Down Expand Up @@ -484,13 +484,18 @@ class MultivariateExpansionWorker
lastVal = polyCache[startPos_(d) + multiSet_.nzOrders(end_idx)];
}
if constexpr (!std::is_same_v<Rectifier, Identity>) {
bool isRectified = d == dim_ - 1;
if(wrt == dim_ - 1) { // Diagonal deriv
termVal = Rectifier::Evaluate(termVal)*wrtDeriv; // if wrt != d, wrtDeriv = 0
} else if (wrt == -1) { // No deriv
termVal = (d == dim_ - 1) ? Rectifier::Evaluate(termVal)*lastVal : Rectifier::Evaluate(termVal*lastVal);
termVal = (isRectified) ? Rectifier::Evaluate(termVal)*lastVal : termVal*lastVal;
} else { // Offdiag deriv
if(d != dim_ - 1) wrtVal *= lastVal; // lastVal belongs on inside and outside
termVal = Rectifier::Derivative(termVal*wrtVal)*termVal*wrtDeriv*lastVal;
if(!isRectified) {
termVal *= lastVal*wrtDeriv;
}
else{
termVal = Rectifier::Derivative(termVal*wrtVal)*termVal*wrtDeriv*lastVal;
}
}
} else { // Reduce to loop body
termVal *= lastVal*wrtDeriv;
Expand Down Expand Up @@ -568,4 +573,4 @@ class MultivariateExpansionWorker



#endif // #ifndef MPART_MULTIVARIATEEXPANSION_H
#endif // #ifndef MPART_MULTIVARIATEEXPANSION_H
3 changes: 1 addition & 2 deletions tests/Test_MultivariateExpansionWorker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ TEMPLATE_TEST_CASE( "Testing multivariate expansion worker", "[MultivariateExpan
expansion.FillCache2(&cache[0], pt, pt(dim-1), DerivativeFlags::None);

eval2 = expansion.Evaluate(&cache[0], coeffs);

REQUIRE_THAT(inGrad(wrt), Matchers::WithinAbs((eval2-eval)/fdStep, fdStep*10));
CHECK_THAT(inGrad(wrt), Matchers::WithinAbs((eval2-eval)/fdStep, fdStep*10));
pt(wrt) -= fdStep;
}
}
Expand Down

0 comments on commit 6e852ff

Please sign in to comment.