Skip to content

Commit

Permalink
Merge pull request #2929 from stan-dev/fvar-support
Browse files Browse the repository at this point in the history
Framework for generic fvar<T> support through finite-differences
  • Loading branch information
andrjohns authored Oct 4, 2023
2 parents 11b3aff + ad303aa commit efbc688
Show file tree
Hide file tree
Showing 11 changed files with 291 additions and 38 deletions.
2 changes: 2 additions & 0 deletions stan/math/fwd/functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

#include <stan/math/fwd/functor/apply_scalar_unary.hpp>
#include <stan/math/fwd/functor/gradient.hpp>
#include <stan/math/fwd/functor/finite_diff.hpp>
#include <stan/math/fwd/functor/hessian.hpp>
#include <stan/math/fwd/functor/integrate_1d.hpp>
#include <stan/math/fwd/functor/jacobian.hpp>
#include <stan/math/fwd/functor/operands_and_partials.hpp>
#include <stan/math/fwd/functor/partials_propagator.hpp>
Expand Down
120 changes: 120 additions & 0 deletions stan/math/fwd/functor/finite_diff.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#ifndef STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP
#define STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/serializer.hpp>

namespace stan {
namespace math {
namespace internal {
/**
 * Helper for accumulating tangents from finite-difference gradients.
 *
 * This overload is selected when the input argument is not an fvar<T>;
 * such arguments carry no tangent, so their contribution is zero.
 *
 * @tparam FuncTangent Type of tangent calculated by finite-differences
 * @tparam InputArg Type of the function input argument
 * @param tangent Calculated tangent (unused for non-fvar inputs)
 * @param arg Input argument (unused for non-fvar inputs)
 * @return Zero, since a non-fvar argument contributes no tangent
 */
template <typename FuncTangent, typename InputArg,
          require_not_st_fvar<InputArg>* = nullptr>
inline constexpr double aggregate_tangent(const FuncTangent& tangent,
                                          const InputArg& arg) {
  return 0.0;
}

/**
 * Helper for accumulating tangents from finite-difference gradients.
 *
 * This overload is selected when the input argument is an fvar<T>:
 * the finite-difference tangent is multiplied elementwise with the
 * argument's tangent(s) and the products are summed.
 *
 * @tparam FuncTangent Type of tangent calculated by finite-differences
 * @tparam InputArg Type of the function input argument
 * @param tangent Calculated tangent
 * @param arg Input argument
 * @return Sum of the elementwise products of the finite-difference
 *         tangent with the argument's tangents
 */
template <typename FuncTangent, typename InputArg,
          require_st_fvar<InputArg>* = nullptr>
inline auto aggregate_tangent(const FuncTangent& tangent, const InputArg& arg) {
  return sum(apply_scalar_binary(
      tangent, arg, [](const auto& fd_tangent, const auto& fvar_arg) {
        return fd_tangent * fvar_arg.d_;
      }));
}
} // namespace internal

/**
 * Construct an fvar<T> where the tangent is calculated by finite-differencing.
 * Finite-differencing is only performed where the scalar type to be evaluated
 * is `fvar<T>`.
 *
 * Higher-order inputs (i.e., fvar<var> & fvar<fvar<T>>) are also implicitly
 * supported through auto-diffing the finite-differencing process.
 *
 * @tparam F Type of functor for which fvar<T> support is needed
 * @tparam TArgs Template parameter pack of the types passed in the `operator()`
 *                of the functor type `F`. Must contain at least one type whose
 *                scalar type is `fvar<T>`
 * @param func Functor for which fvar<T> support is needed
 * @param args Parameter pack of arguments to be passed to functor.
 * @return An fvar<T> whose value is `func` evaluated at the values of `args`
 *         and whose tangent is the finite-difference gradient dotted with
 *         the input tangents.
 */
template <typename F, typename... TArgs,
          require_any_st_fvar<TArgs...>* = nullptr>
inline auto finite_diff(const F& func, const TArgs&... args) {
  using FvarT = return_type_t<TArgs...>;
  using FvarInnerT = typename FvarT::Scalar;

  // Flatten the values of all inputs into one vector so the gradient can
  // be computed in a single finite-difference pass.
  std::vector<FvarInnerT> serialised_args
      = serialize<FvarInnerT>(value_of(args)...);

  // Adapter that restores the original argument shapes from the flat
  // vector before invoking the user-supplied functor.
  auto serial_functor = [&](const auto& v) {
    auto v_deserializer = to_deserializer(v);
    return func(v_deserializer.read(args)...);
  };

  FvarInnerT rtn_value;
  std::vector<FvarInnerT> grad;
  finite_diff_gradient_auto(serial_functor, serialised_args, rtn_value, grad);

  FvarInnerT rtn_grad = 0;
  auto grad_deserializer = to_deserializer(grad);
  // Use a fold-expression to aggregate tangents for input arguments
  // (initializer_list expansion emulates a C++17 fold over the pack;
  // only fvar arguments contribute — see internal::aggregate_tangent).
  static_cast<void>(
      std::initializer_list<int>{(rtn_grad += internal::aggregate_tangent(
                                      grad_deserializer.read(args), args),
                                  0)...});

  return FvarT(rtn_value, rtn_grad);
}

/**
 * Construct an fvar<T> where the tangent is calculated by finite-differencing.
 * Finite-differencing is only performed where the scalar type to be evaluated
 * is `fvar<T>`.
 *
 * This overload is used when no fvar<T> arguments are passed and simply
 * evaluates the functor with the provided arguments.
 *
 * @tparam F Type of functor
 * @tparam TArgs Template parameter pack of the types passed in the `operator()`
 *                of the functor type `F`. Must contain no type whose
 *                scalar type is `fvar<T>`
 * @param func Functor
 * @param args Parameter pack of arguments to be passed to functor.
 * @return The result of evaluating `func` with `args` directly.
 */
template <typename F, typename... TArgs,
          require_all_not_st_fvar<TArgs...>* = nullptr>
inline auto finite_diff(const F& func, const TArgs&... args) {
  return func(args...);
}

} // namespace math
} // namespace stan

#endif
101 changes: 101 additions & 0 deletions stan/math/fwd/functor/integrate_1d.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#ifndef STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP
#define STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP

#include <stan/math/fwd/meta.hpp>
#include <stan/math/prim/functor/integrate_1d.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/meta/forward_as.hpp>
#include <stan/math/prim/functor/apply.hpp>
#include <stan/math/fwd/functor/finite_diff.hpp>

namespace stan {
namespace math {
/**
 * Return the integral of f from a to b to the given relative tolerance.
 *
 * Overload for when any of the bounds or additional arguments are fvar<T>
 * types. Tangents w.r.t. the additional arguments are obtained by
 * finite-differencing the double-valued integral; tangents w.r.t. the
 * integration bounds follow from the fundamental theorem of calculus
 * (the integrand evaluated at the respective bound, negated for the
 * lower bound).
 *
 * @tparam F Type of f
 * @tparam T_a type of first limit
 * @tparam T_b type of second limit
 * @tparam Args types of parameter pack arguments
 *
 * @param f the functor to integrate
 * @param a lower limit of integration
 * @param b upper limit of integration
 * @param relative_tolerance relative tolerance passed to Boost quadrature
 * @param[in, out] msgs the print stream for warning messages
 * @param args additional arguments to pass to f
 * @return numeric integral of function f
 */
template <typename F, typename T_a, typename T_b, typename... Args,
          require_any_st_fvar<T_a, T_b, Args...> * = nullptr>
inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
    const F &f, const T_a &a, const T_b &b, double relative_tolerance,
    std::ostream *msgs, const Args &... args) {
  using FvarT = scalar_type_t<return_type_t<T_a, T_b, Args...>>;

  // Wrap integrate_1d call in a functor where the input arguments are only
  // for which tangents are needed
  auto a_val = value_of(a);
  auto b_val = value_of(b);
  auto func
      = [f, msgs, relative_tolerance, a_val, b_val](const auto &... args_var) {
          return integrate_1d_impl(f, a_val, b_val, relative_tolerance, msgs,
                                   args_var...);
        };
  // Value plus tangents w.r.t. any fvar arguments in args...
  FvarT ret = finite_diff(func, args...);

  // Calculate tangents w.r.t. integration bounds if needed
  if (is_fvar<T_a>::value || is_fvar<T_b>::value) {
    auto val_args = std::make_tuple(value_of(args)...);
    if (is_fvar<T_a>::value) {
      // forward_as is required because the branch is only runtime-guarded:
      // it compiles for non-fvar T_a but is never executed in that case.
      // d/da of the integral is -f(a) (lower bound, hence negated).
      ret.d_ += math::forward_as<FvarT>(a).d_
                * math::apply(
                    [&](auto &&... tuple_args) {
                      return -f(a_val, 0.0, msgs, tuple_args...);
                    },
                    val_args);
    }
    if (is_fvar<T_b>::value) {
      // d/db of the integral is +f(b) (upper bound).
      ret.d_ += math::forward_as<FvarT>(b).d_
                * math::apply(
                    [&](auto &&... tuple_args) {
                      return f(b_val, 0.0, msgs, tuple_args...);
                    },
                    val_args);
    }
  }
  return ret;
}

/**
 * Compute the integral of the single variable function f from a to b to
 * within a specified relative tolerance. a and b can be finite or infinite.
 *
 * Thin adapter over integrate_1d_impl for the legacy
 * (theta, x_r, x_i) calling convention.
 *
 * @tparam T_a type of first limit
 * @tparam T_b type of second limit
 * @tparam T_theta type of parameters
 * @tparam T Type of f
 *
 * @param f the functor to integrate
 * @param a lower limit of integration
 * @param b upper limit of integration
 * @param theta additional parameters to be passed to f
 * @param x_r additional data to be passed to f
 * @param x_i additional integer data to be passed to f
 * @param[in, out] msgs the print stream for warning messages
 * @param relative_tolerance relative tolerance passed to Boost quadrature
 * @return numeric integral of function f
 */
template <typename F, typename T_a, typename T_b, typename T_theta,
          require_any_fvar_t<T_a, T_b, T_theta> * = nullptr>
inline return_type_t<T_a, T_b, T_theta> integrate_1d(
    const F &f, const T_a &a, const T_b &b, const std::vector<T_theta> &theta,
    const std::vector<double> &x_r, const std::vector<int> &x_i,
    std::ostream *msgs, const double relative_tolerance) {
  // Adapt f to the variadic-argument interface expected by the impl.
  integrate_1d_adapter<F> adapted_f(f);
  return integrate_1d_impl(adapted_f, a, b, relative_tolerance, msgs, theta,
                           x_r, x_i);
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@
#include <stan/math/prim/fun/scaled_add.hpp>
#include <stan/math/prim/fun/sd.hpp>
#include <stan/math/prim/fun/segment.hpp>
#include <stan/math/prim/fun/serializer.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/fun/signbit.hpp>
Expand Down
31 changes: 15 additions & 16 deletions test/unit/math/serializer.hpp → stan/math/prim/fun/serializer.hpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#ifndef TEST_UNIT_MATH_SERIALIZER_HPP
#define TEST_UNIT_MATH_SERIALIZER_HPP
#ifndef STAN_MATH_PRIM_FUN_SERIALIZER_HPP
#define STAN_MATH_PRIM_FUN_SERIALIZER_HPP

#include <stan/math.hpp>
#include <stan/math/prim/meta/promote_scalar_type.hpp>
#include <stan/math/prim/fun/to_vector.hpp>
#include <stan/math/prim/fun/to_array_1d.hpp>
#include <complex>
#include <string>
#include <vector>

namespace stan {
namespace test {
namespace math {

/**
* A class to store a sequence of values which can be deserialized
Expand Down Expand Up @@ -44,10 +46,10 @@ struct deserializer {
/**
* Construct a deserializer from the specified sequence of values.
*
* @param vals values to deserialize
* @param v_vals values to deserialize
*/
explicit deserializer(const Eigen::Matrix<T, -1, 1>& v_vals)
: position_(0), vals_(math::to_array_1d(v_vals)) {}
: position_(0), vals_(to_array_1d(v_vals)) {}

/**
* Read a scalar conforming to the shape of the specified argument,
Expand Down Expand Up @@ -94,8 +96,8 @@ struct deserializer {
*/
template <typename U, require_std_vector_t<U>* = nullptr,
require_not_st_complex<U>* = nullptr>
typename stan::math::promote_scalar_type<T, U>::type read(const U& x) {
typename stan::math::promote_scalar_type<T, U>::type y;
promote_scalar_t<T, U> read(const U& x) {
promote_scalar_t<T, U> y;
y.reserve(x.size());
for (size_t i = 0; i < x.size(); ++i)
y.push_back(read(x[i]));
Expand All @@ -113,9 +115,8 @@ struct deserializer {
* @return deserialized value with shape and size matching argument
*/
template <typename U, require_std_vector_st<is_complex, U>* = nullptr>
typename stan::math::promote_scalar_type<std::complex<T>, U>::type read(
const U& x) {
typename stan::math::promote_scalar_type<std::complex<T>, U>::type y;
promote_scalar_t<std::complex<T>, U> read(const U& x) {
promote_scalar_t<std::complex<T>, U> y;
y.reserve(x.size());
for (size_t i = 0; i < x.size(); ++i)
y.push_back(read(x[i]));
Expand Down Expand Up @@ -257,9 +258,7 @@ struct serializer {
*
* @return serialized values
*/
const Eigen::Matrix<T, -1, 1>& vector_vals() {
return math::to_vector(vals_);
}
const Eigen::Matrix<T, -1, 1>& vector_vals() { return to_vector(vals_); }
};

/**
Expand Down Expand Up @@ -338,10 +337,10 @@ std::vector<real_return_t<T>> serialize_return(const T& x) {
*/
template <typename... Ts>
Eigen::VectorXd serialize_args(const Ts... xs) {
return math::to_vector(serialize<double>(xs...));
return to_vector(serialize<double>(xs...));
}

} // namespace test
} // namespace math
} // namespace stan

#endif
29 changes: 15 additions & 14 deletions stan/math/prim/functor/finite_diff_gradient_auto.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,39 +46,40 @@ namespace math {
* @param[out] fx function applied to argument
* @param[out] grad_fx gradient of function at argument
*/
template <typename F>
void finite_diff_gradient_auto(const F& f, const Eigen::VectorXd& x, double& fx,
Eigen::VectorXd& grad_fx) {
Eigen::VectorXd x_temp(x);
template <typename F, typename VectorT,
typename ScalarT = return_type_t<VectorT>>
void finite_diff_gradient_auto(const F& f, const VectorT& x, ScalarT& fx,
VectorT& grad_fx) {
VectorT x_temp(x);
fx = f(x);
grad_fx.resize(x.size());
for (int i = 0; i < x.size(); ++i) {
double h = finite_diff_stepsize(x(i));
double h = finite_diff_stepsize(value_of_rec(x[i]));

double delta_f = 0;
ScalarT delta_f = 0;

x_temp(i) = x(i) + 3 * h;
x_temp[i] = x[i] + 3 * h;
delta_f += f(x_temp);

x_temp(i) = x(i) + 2 * h;
x_temp[i] = x[i] + 2 * h;
delta_f -= 9 * f(x_temp);

x_temp(i) = x(i) + h;
x_temp[i] = x[i] + h;
delta_f += 45 * f(x_temp);

x_temp(i) = x(i) + -3 * h;
x_temp[i] = x[i] + -3 * h;
delta_f -= f(x_temp);

x_temp(i) = x(i) + -2 * h;
x_temp[i] = x[i] + -2 * h;
delta_f += 9 * f(x_temp);

x_temp(i) = x(i) - h;
x_temp[i] = x[i] - h;
delta_f -= 45 * f(x_temp);

delta_f /= 60 * h;

x_temp(i) = x(i);
grad_fx(i) = delta_f;
x_temp[i] = x[i];
grad_fx[i] = delta_f;
}
}

Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/functor/integrate_1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ inline double integrate(const F& f, double a, double b,
* @return numeric integral of function f
*/
template <typename F, typename... Args,
require_all_not_st_var<Args...>* = nullptr>
require_all_st_arithmetic<Args...>* = nullptr>
inline double integrate_1d_impl(const F& f, double a, double b,
double relative_tolerance, std::ostream* msgs,
const Args&... args) {
Expand Down
Loading

0 comments on commit efbc688

Please sign in to comment.