From 87e01971e49c7f03bd598aede1f8f9729d7e3211 Mon Sep 17 00:00:00 2001 From: MMRROOO <59205909+MMRROOO@users.noreply.github.com> Date: Wed, 20 Mar 2024 16:13:58 -0400 Subject: [PATCH] Rectified Multivariate Expansion Updated (#395) --- MParT/MultivariateExpansion.h | 2 +- MParT/MultivariateExpansionWorker.h | 15 ++++++++++----- tests/Test_MultivariateExpansionWorker.cpp | 3 +-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/MParT/MultivariateExpansion.h b/MParT/MultivariateExpansion.h index 3b240aed..2a38facf 100644 --- a/MParT/MultivariateExpansion.h +++ b/MParT/MultivariateExpansion.h @@ -221,4 +221,4 @@ namespace mpart{ } -#endif \ No newline at end of file +#endif diff --git a/MParT/MultivariateExpansionWorker.h b/MParT/MultivariateExpansionWorker.h index 80527250..be0a4ca0 100644 --- a/MParT/MultivariateExpansionWorker.h +++ b/MParT/MultivariateExpansionWorker.h @@ -432,7 +432,7 @@ class MultivariateExpansionWorker if(multiSet_.nzDims(end_idx)==dim_-1){ termVal = Rectifier::Evaluate(termVal)*lastVal; } else { - termVal = Rectifier::Evaluate(termVal*lastVal); + termVal = termVal*lastVal; } } else { termVal *= lastVal; @@ -484,13 +484,18 @@ class MultivariateExpansionWorker lastVal = polyCache[startPos_(d) + multiSet_.nzOrders(end_idx)]; } if constexpr (!std::is_same_v) { + bool isRectified = d == dim_ - 1; if(wrt == dim_ - 1) { // Diagonal deriv termVal = Rectifier::Evaluate(termVal)*wrtDeriv; // if wrt != d, wrtDeriv = 0 } else if (wrt == -1) { // No deriv - termVal = (d == dim_ - 1) ? Rectifier::Evaluate(termVal)*lastVal : Rectifier::Evaluate(termVal*lastVal); + termVal = (isRectified) ? Rectifier::Evaluate(termVal)*lastVal : termVal*lastVal; } else { // Offdiag deriv - if(d != dim_ - 1) wrtVal *= lastVal; // lastVal belongs on inside and outside - termVal = Rectifier::Derivative(termVal*wrtVal)*termVal*wrtDeriv*lastVal; + if(!isRectified) { + termVal *= lastVal*wrtDeriv; + } + else{ + termVal = Rectifier::Derivative(termVal*wrtVal)*termVal*wrtDeriv*lastVal; + } } } else { // Reduce to loop body termVal *= lastVal*wrtDeriv; @@ -568,4 +573,4 @@ class MultivariateExpansionWorker -#endif // #ifndef MPART_MULTIVARIATEEXPANSION_H \ No newline at end of file +#endif // #ifndef MPART_MULTIVARIATEEXPANSION_H diff --git a/tests/Test_MultivariateExpansionWorker.cpp b/tests/Test_MultivariateExpansionWorker.cpp index aba4519c..4eecfccd 100644 --- a/tests/Test_MultivariateExpansionWorker.cpp +++ b/tests/Test_MultivariateExpansionWorker.cpp @@ -179,8 +179,7 @@ TEMPLATE_TEST_CASE( "Testing multivariate expansion worker", "[MultivariateExpan expansion.FillCache2(&cache[0], pt, pt(dim-1), DerivativeFlags::None); eval2 = expansion.Evaluate(&cache[0], coeffs); - - REQUIRE_THAT(inGrad(wrt), Matchers::WithinAbs((eval2-eval)/fdStep, fdStep*10)); + CHECK_THAT(inGrad(wrt), Matchers::WithinAbs((eval2-eval)/fdStep, fdStep*10)); pt(wrt) -= fdStep; } }