Skip to content

Commit

Permalink
updated MultivariateExpansionWorker, tests not working though
Browse files Browse the repository at this point in the history
  • Loading branch information
MMRROOO committed Mar 17, 2024
1 parent 9c5c0b1 commit e6a1806
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion MParT/MultivariateExpansion.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,4 @@ namespace mpart{
}


#endif
#endif
14 changes: 9 additions & 5 deletions MParT/MultivariateExpansionWorker.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ class MultivariateExpansionWorker
if(multiSet_.nzDims(end_idx)==dim_-1){
termVal = Rectifier::Evaluate(termVal)*lastVal;
} else {
termVal = Rectifier::Evaluate(termVal*lastVal);
termVal = termVal*lastVal;
}
} else {
termVal *= lastVal;
Expand Down Expand Up @@ -482,10 +482,14 @@ class MultivariateExpansionWorker
if(wrt == dim_ - 1) { // Diagonal deriv
termVal = Rectifier::Evaluate(termVal)*wrtDeriv; // if wrt != d, wrtDeriv = 0
} else if (wrt == -1) { // No deriv
termVal = (d == dim_ - 1) ? Rectifier::Evaluate(termVal)*lastVal : Rectifier::Evaluate(termVal*lastVal);
termVal = (d == dim_ - 1) ? Rectifier::Evaluate(termVal)*lastVal : termVal*lastVal;
} else { // Offdiag deriv
if(d != dim_ - 1) wrtVal *= lastVal; // lastVal belongs on inside and outside
termVal = Rectifier::Derivative(termVal*wrtVal)*termVal*wrtDeriv*lastVal;
if(d != dim_ - 1) {
termVal *= wrtDeriv;
}
else{
termVal = Rectifier::Derivative(termVal*wrtVal)*termVal*wrtDeriv*lastVal;
}
}
} else { // Reduce to loop body
termVal *= lastVal*wrtDeriv;
Expand Down Expand Up @@ -562,4 +566,4 @@ class MultivariateExpansionWorker



#endif // #ifndef MPART_MULTIVARIATEEXPANSION_H
#endif // #ifndef MPART_MULTIVARIATEEXPANSION_H
2 changes: 1 addition & 1 deletion tests/Test_MultivariateExpansionWorker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ TEMPLATE_TEST_CASE( "Testing multivariate expansion worker", "[MultivariateExpan
expansion.FillCache2(&cache[0], pt, pt(dim-1), DerivativeFlags::None);

eval2 = expansion.Evaluate(&cache[0], coeffs);

printf("%f, %f, %d, %d\n", inGrad(wrt), (eval2-eval)/fdStep, dim, wrt);
REQUIRE_THAT(inGrad(wrt), Matchers::WithinAbs((eval2-eval)/fdStep, fdStep*10));
pt(wrt) -= fdStep;
}
Expand Down

0 comments on commit e6a1806

Please sign in to comment.