Skip to content

Commit

Permalink
Simplify fvar framework
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Aug 8, 2023
1 parent 3435d1b commit 77e4812
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 21 deletions.
22 changes: 6 additions & 16 deletions stan/math/fwd/functor/fvar_finite_diff.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#define STAN_MATH_FWD_FUNCTOR_FVAR_FINITE_DIFF_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/apply.hpp>
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/sum.hpp>
Expand Down Expand Up @@ -66,34 +65,25 @@ auto fvar_finite_diff(const F& func, const TArgs&... args) {
using FvarT = return_type_t<TArgs...>;
using FvarInnerT = typename FvarT::Scalar;

auto val_args = std::make_tuple(stan::math::value_of(args)...);

auto serialised_args = stan::math::apply(
[&](auto&&... tuple_args) {
return math::to_vector(
stan::test::serialize<FvarInnerT>(tuple_args...));
},
val_args);
auto serialised_args = test::serialize<FvarInnerT>(value_of(args)...);

// Create a 'wrapper' functor which will take the flattened column-vector
// and transform it to individual arguments which are passed to the
// user-provided functor
auto serial_functor = [&](const auto& v) {
auto ds = stan::test::to_deserializer(v);
return stan::math::apply(
[&](auto&&... tuple_args) { return func(ds.read(tuple_args)...); },
val_args);
return func(test::to_deserializer(v).read(args)...);
};

FvarInnerT rtn_value;
Eigen::Matrix<FvarInnerT, -1, 1> grad;
finite_diff_gradient_auto(serial_functor, serialised_args, rtn_value, grad);
finite_diff_gradient_auto(serial_functor, to_vector(serialised_args),
rtn_value, grad);

auto ds_grad = stan::test::to_deserializer(grad);
FvarInnerT rtn_grad = 0;
// Use a fold-expression to aggregate tangents for input arguments
(void)std::initializer_list<int>{(
rtn_grad += internal::aggregate_tangent(ds_grad.read(args), args), 0)...};
rtn_grad += internal::aggregate_tangent(
test::to_deserializer(grad).read(args), args), 0)...};

return FvarT(rtn_value, rtn_grad);
}
Expand Down
10 changes: 5 additions & 5 deletions test/unit/math/mix/functor/integrate_1d_test.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
#include <test/unit/math/test_ad.hpp>

TEST(mixFunctor, integrate1D) {
auto f = [&](const auto& x_input) {
auto f = [&](const auto& x_input, const auto& lb, const auto& ub) {
auto func = [](const auto& x, const auto& xc, std::ostream* msgs,
const auto& theta) {
return stan::math::exp(theta * stan::math::cos(2 * 3.141593 * x)) + theta;
};
const double relative_tolerance = std::sqrt(stan::math::EPSILON);
std::ostringstream* msgs = nullptr;
return stan::math::integrate_1d_impl(func, 0, 1, relative_tolerance, msgs,
return stan::math::integrate_1d_impl(func, lb, ub, relative_tolerance, msgs,
x_input);
};
stan::test::expect_ad(f, 0.75);
stan::test::expect_ad(f, 0.2);
stan::test::expect_ad(f, stan::math::NOT_A_NUMBER);
stan::test::expect_ad(f, 0.75, 0, 1);
stan::test::expect_ad(f, 0.2, 0.2, 0.7);
stan::test::expect_ad(f, stan::math::NOT_A_NUMBER, 0, 1);
}

0 comments on commit 77e4812

Please sign in to comment.