-
-
Notifications
You must be signed in to change notification settings - Fork 188
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2929 from stan-dev/fvar-support
Framework for generic fvar<T> support through finite-differences
- Loading branch information
Showing
11 changed files
with
291 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
#ifndef STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP | ||
#define STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/functor/apply_scalar_binary.hpp> | ||
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp> | ||
#include <stan/math/prim/fun/value_of.hpp> | ||
#include <stan/math/prim/fun/sum.hpp> | ||
#include <stan/math/prim/fun/serializer.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
namespace internal { | ||
/** | ||
* Helper function for aggregating tangents if the respective input argument | ||
* was an fvar<T> type. | ||
* | ||
* Overload for when the input is not an fvar<T> and no tangents are needed. | ||
* | ||
* @tparam FuncTangent Type of tangent calculated by finite-differences | ||
* @tparam InputArg Type of the function input argument | ||
* @param tangent Calculated tangent | ||
* @param arg Input argument | ||
*/ | ||
template <typename FuncTangent, typename InputArg, | ||
require_not_st_fvar<InputArg>* = nullptr> | ||
inline constexpr double aggregate_tangent(const FuncTangent& tangent, | ||
const InputArg& arg) { | ||
return 0; | ||
} | ||
|
||
/** | ||
* Helper function for aggregating tangents if the respective input argument | ||
* was an fvar<T> type. | ||
* | ||
* Overload for when the input is an fvar<T> and its tangent needs to be | ||
* aggregated. | ||
* | ||
* @tparam FuncTangent Type of tangent calculated by finite-differences | ||
* @tparam InputArg Type of the function input argument | ||
* @param tangent Calculated tangent | ||
* @param arg Input argument | ||
*/ | ||
template <typename FuncTangent, typename InputArg, | ||
require_st_fvar<InputArg>* = nullptr> | ||
inline auto aggregate_tangent(const FuncTangent& tangent, const InputArg& arg) { | ||
return sum(apply_scalar_binary( | ||
tangent, arg, [](const auto& x, const auto& y) { return x * y.d_; })); | ||
} | ||
} // namespace internal | ||
|
||
/** | ||
* Construct an fvar<T> where the tangent is calculated by finite-differencing. | ||
* Finite-differencing is only perfomed where the scalar type to be evaluated is | ||
* `fvar<T>. | ||
* | ||
* Higher-order inputs (i.e., fvar<var> & fvar<fvar<T>>) are also implicitly | ||
* supported through auto-diffing the finite-differencing process. | ||
* | ||
* @tparam F Type of functor for which fvar<T> support is needed | ||
* @tparam TArgs Template parameter pack of the types passed in the `operator()` | ||
* of the functor type `F`. Must contain at least on type whose | ||
* scalar type is `fvar<T>` | ||
* @param func Functor for which fvar<T> support is needed | ||
* @param args Parameter pack of arguments to be passed to functor. | ||
*/ | ||
template <typename F, typename... TArgs, | ||
require_any_st_fvar<TArgs...>* = nullptr> | ||
inline auto finite_diff(const F& func, const TArgs&... args) { | ||
using FvarT = return_type_t<TArgs...>; | ||
using FvarInnerT = typename FvarT::Scalar; | ||
|
||
std::vector<FvarInnerT> serialised_args | ||
= serialize<FvarInnerT>(value_of(args)...); | ||
|
||
auto serial_functor = [&](const auto& v) { | ||
auto v_deserializer = to_deserializer(v); | ||
return func(v_deserializer.read(args)...); | ||
}; | ||
|
||
FvarInnerT rtn_value; | ||
std::vector<FvarInnerT> grad; | ||
finite_diff_gradient_auto(serial_functor, serialised_args, rtn_value, grad); | ||
|
||
FvarInnerT rtn_grad = 0; | ||
auto grad_deserializer = to_deserializer(grad); | ||
// Use a fold-expression to aggregate tangents for input arguments | ||
static_cast<void>( | ||
std::initializer_list<int>{(rtn_grad += internal::aggregate_tangent( | ||
grad_deserializer.read(args), args), | ||
0)...}); | ||
|
||
return FvarT(rtn_value, rtn_grad); | ||
} | ||
|
||
/** | ||
* Construct an fvar<T> where the tangent is calculated by finite-differencing. | ||
* Finite-differencing is only perfomed where the scalar type to be evaluated is | ||
* `fvar<T>. | ||
* | ||
* This overload is used when no fvar<T> arguments are passed and simply | ||
* evaluates the functor with the provided arguments. | ||
* | ||
* @tparam F Type of functor | ||
* @tparam TArgs Template parameter pack of the types passed in the `operator()` | ||
* of the functor type `F`. Must contain no type whose | ||
* scalar type is `fvar<T>` | ||
* @param func Functor | ||
* @param args... Parameter pack of arguments to be passed to functor. | ||
*/ | ||
template <typename F, typename... TArgs, | ||
require_all_not_st_fvar<TArgs...>* = nullptr> | ||
inline auto finite_diff(const F& func, const TArgs&... args) { | ||
return func(args...); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
#ifndef STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP | ||
#define STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP | ||
|
||
#include <stan/math/fwd/meta.hpp> | ||
#include <stan/math/prim/functor/integrate_1d.hpp> | ||
#include <stan/math/prim/fun/value_of.hpp> | ||
#include <stan/math/prim/meta/forward_as.hpp> | ||
#include <stan/math/prim/functor/apply.hpp> | ||
#include <stan/math/fwd/functor/finite_diff.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
/** | ||
* Return the integral of f from a to b to the given relative tolerance | ||
* | ||
* @tparam F Type of f | ||
* @tparam T_a type of first limit | ||
* @tparam T_b type of second limit | ||
* @tparam Args types of parameter pack arguments | ||
* | ||
* @param f the functor to integrate | ||
* @param a lower limit of integration | ||
* @param b upper limit of integration | ||
* @param relative_tolerance relative tolerance passed to Boost quadrature | ||
* @param[in, out] msgs the print stream for warning messages | ||
* @param args additional arguments to pass to f | ||
* @return numeric integral of function f | ||
*/ | ||
template <typename F, typename T_a, typename T_b, typename... Args, | ||
require_any_st_fvar<T_a, T_b, Args...> * = nullptr> | ||
inline return_type_t<T_a, T_b, Args...> integrate_1d_impl( | ||
const F &f, const T_a &a, const T_b &b, double relative_tolerance, | ||
std::ostream *msgs, const Args &... args) { | ||
using FvarT = scalar_type_t<return_type_t<T_a, T_b, Args...>>; | ||
|
||
// Wrap integrate_1d call in a functor where the input arguments are only | ||
// for which tangents are needed | ||
auto a_val = value_of(a); | ||
auto b_val = value_of(b); | ||
auto func | ||
= [f, msgs, relative_tolerance, a_val, b_val](const auto &... args_var) { | ||
return integrate_1d_impl(f, a_val, b_val, relative_tolerance, msgs, | ||
args_var...); | ||
}; | ||
FvarT ret = finite_diff(func, args...); | ||
|
||
// Calculate tangents w.r.t. integration bounds if needed | ||
if (is_fvar<T_a>::value || is_fvar<T_b>::value) { | ||
auto val_args = std::make_tuple(value_of(args)...); | ||
if (is_fvar<T_a>::value) { | ||
ret.d_ += math::forward_as<FvarT>(a).d_ | ||
* math::apply( | ||
[&](auto &&... tuple_args) { | ||
return -f(a_val, 0.0, msgs, tuple_args...); | ||
}, | ||
val_args); | ||
} | ||
if (is_fvar<T_b>::value) { | ||
ret.d_ += math::forward_as<FvarT>(b).d_ | ||
* math::apply( | ||
[&](auto &&... tuple_args) { | ||
return f(b_val, 0.0, msgs, tuple_args...); | ||
}, | ||
val_args); | ||
} | ||
} | ||
return ret; | ||
} | ||
|
||
/** | ||
* Compute the integral of the single variable function f from a to b to within | ||
* a specified relative tolerance. a and b can be finite or infinite. | ||
* | ||
* @tparam T_a type of first limit | ||
* @tparam T_b type of second limit | ||
* @tparam T_theta type of parameters | ||
* @tparam T Type of f | ||
* | ||
* @param f the functor to integrate | ||
* @param a lower limit of integration | ||
* @param b upper limit of integration | ||
* @param theta additional parameters to be passed to f | ||
* @param x_r additional data to be passed to f | ||
* @param x_i additional integer data to be passed to f | ||
* @param[in, out] msgs the print stream for warning messages | ||
* @param relative_tolerance relative tolerance passed to Boost quadrature | ||
* @return numeric integral of function f | ||
*/ | ||
template <typename F, typename T_a, typename T_b, typename T_theta, | ||
require_any_fvar_t<T_a, T_b, T_theta> * = nullptr> | ||
inline return_type_t<T_a, T_b, T_theta> integrate_1d( | ||
const F &f, const T_a &a, const T_b &b, const std::vector<T_theta> &theta, | ||
const std::vector<double> &x_r, const std::vector<int> &x_i, | ||
std::ostream *msgs, const double relative_tolerance) { | ||
return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b, relative_tolerance, | ||
msgs, theta, x_r, x_i); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.