diff --git a/include/ode_solvers/oracle_functors.hpp b/include/ode_solvers/oracle_functors.hpp index 3aea450b2..35fc8887b 100644 --- a/include/ode_solvers/oracle_functors.hpp +++ b/include/ode_solvers/oracle_functors.hpp @@ -243,7 +243,7 @@ struct IsotropicLinearFunctor { }; -struct LinearProgramFunctor { +struct ExponentialFunctor { // Sample from linear program c^T x (exponential density) template < @@ -256,8 +256,10 @@ struct LinearProgramFunctor { NT m; // Strong convexity constant NT kappa; // Condition number Point c; // Coefficients of LP objective + NT a; // Inverse variance - parameters(Point c_) : order(2), L(1), m(1), kappa(1), c(c_) {}; + parameters(Point c_) : order(2), L(1), m(1), kappa(1), c(c_), a(1.0) {}; + parameters(Point c_, NT a_) : order(2), L(1), m(1), kappa(1), c(c_), a(a_) {}; }; @@ -277,7 +279,7 @@ struct LinearProgramFunctor { Point operator() (unsigned int const& i, pts const& xs, NT const& t) const { if (i == params.order - 1) { Point y(params.c); - return (-1.0) * y; + return (-params.a) * y; } else { return xs[i + 1]; // returns derivative } @@ -298,7 +300,7 @@ struct LinearProgramFunctor { // The index i represents the state vector index NT operator() (Point const& x) const { - return x.dot(params.c); + return params.a * x.dot(params.c); } };