Skip to content

Commit

Permalink
Merge pull request #3114 from lingium/feature/issue-3113-beta-neg-binomial-lccdf
Browse files Browse the repository at this point in the history

add beta_neg_binomial_lccdf
  • Loading branch information
lingium authored Oct 27, 2024
2 parents 2fdd3ed + f2bebaf commit 65fc8e6
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 83 deletions.
135 changes: 80 additions & 55 deletions stan/math/prim/fun/grad_F32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,20 @@ namespace math {
* This power-series representation converges for all gradients
* under the same conditions as the 3F2 function itself.
*
* @tparam T type of arguments and result
* @tparam grad_a1 boolean indicating if gradient with respect to a1 is required
* @tparam grad_a2 boolean indicating if gradient with respect to a2 is required
* @tparam grad_a3 boolean indicating if gradient with respect to a3 is required
* @tparam grad_b1 boolean indicating if gradient with respect to b1 is required
* @tparam grad_b2 boolean indicating if gradient with respect to b2 is required
* @tparam grad_z boolean indicating if gradient with respect to z is required
* @tparam T1 a scalar type
* @tparam T2 a scalar type
* @tparam T3 a scalar type
* @tparam T4 a scalar type
* @tparam T5 a scalar type
* @tparam T6 a scalar type
* @tparam T7 a scalar type
* @tparam T8 a scalar type
 * @param[out] g pointer to array of six values of type T, result.
 * @param[in] a1 see generalized hypergeometric function definition.
 * @param[in] a2 see generalized hypergeometric function definition.
Expand All @@ -35,84 +48,96 @@ namespace math {
* @param[in] precision precision of the infinite sum
* @param[in] max_steps number of steps to take
*/
template <typename T>
void grad_F32(T* g, const T& a1, const T& a2, const T& a3, const T& b1,
const T& b2, const T& z, const T& precision = 1e-6,
template <bool grad_a1 = true, bool grad_a2 = true, bool grad_a3 = true,
bool grad_b1 = true, bool grad_b2 = true, bool grad_z = true,
typename T1, typename T2, typename T3, typename T4, typename T5,
typename T6, typename T7, typename T8 = double>
void grad_F32(T1* g, const T2& a1, const T3& a2, const T4& a3, const T5& b1,
const T6& b2, const T7& z, const T8& precision = 1e-6,
int max_steps = 1e5) {
check_3F2_converges("grad_F32", a1, a2, a3, b1, b2, z);

using std::exp;
using std::fabs;
using std::log;

for (int i = 0; i < 6; ++i) {
g[i] = 0.0;
}

T log_g_old[6];
T1 log_g_old[6];
for (auto& x : log_g_old) {
x = NEGATIVE_INFTY;
}

T log_t_old = 0.0;
T log_t_new = 0.0;
T1 log_t_old = 0.0;
T1 log_t_new = 0.0;

T log_z = log(z);
T7 log_z = log(z);

double log_t_new_sign = 1.0;
double log_t_old_sign = 1.0;
double log_g_old_sign[6];
T1 log_t_new_sign = 1.0;
T1 log_t_old_sign = 1.0;
T1 log_g_old_sign[6];
for (int i = 0; i < 6; ++i) {
log_g_old_sign[i] = 1.0;
}

std::array<T1, 6> term{0};
for (int k = 0; k <= max_steps; ++k) {
T p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
T1 p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
if (p == 0) {
return;
}

log_t_new += log(fabs(p)) + log_z;
log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
if constexpr (grad_a1) {
term[0]
= log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
+ inv(a1 + k);
log_g_old[0] = log_t_new + log(fabs(term[0]));
log_g_old_sign[0] = term[0] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[0] += log_g_old_sign[0] * exp(log_g_old[0]);
}

if constexpr (grad_a2) {
term[1]
= log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
+ inv(a2 + k);
log_g_old[1] = log_t_new + log(fabs(term[1]));
log_g_old_sign[1] = term[1] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[1] += log_g_old_sign[1] * exp(log_g_old[1]);
}

if constexpr (grad_a3) {
term[2]
= log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
+ inv(a3 + k);
log_g_old[2] = log_t_new + log(fabs(term[2]));
log_g_old_sign[2] = term[2] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[2] += log_g_old_sign[2] * exp(log_g_old[2]);
}

if constexpr (grad_b1) {
term[3]
= log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
- inv(b1 + k);
log_g_old[3] = log_t_new + log(fabs(term[3]));
log_g_old_sign[3] = term[3] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[3] += log_g_old_sign[3] * exp(log_g_old[3]);
}

if constexpr (grad_b2) {
term[4]
= log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
- inv(b2 + k);
log_g_old[4] = log_t_new + log(fabs(term[4]));
log_g_old_sign[4] = term[4] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[4] += log_g_old_sign[4] * exp(log_g_old[4]);
}

// g_old[0] = t_new * (g_old[0] / t_old + 1.0 / (a1 + k));
T term = log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
+ inv(a1 + k);
log_g_old[0] = log_t_new + log(fabs(term));
log_g_old_sign[0] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[1] = t_new * (g_old[1] / t_old + 1.0 / (a2 + k));
term = log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
+ inv(a2 + k);
log_g_old[1] = log_t_new + log(fabs(term));
log_g_old_sign[1] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[2] = t_new * (g_old[2] / t_old + 1.0 / (a3 + k));
term = log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
+ inv(a3 + k);
log_g_old[2] = log_t_new + log(fabs(term));
log_g_old_sign[2] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[3] = t_new * (g_old[3] / t_old - 1.0 / (b1 + k));
term = log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
- inv(b1 + k);
log_g_old[3] = log_t_new + log(fabs(term));
log_g_old_sign[3] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[4] = t_new * (g_old[4] / t_old - 1.0 / (b2 + k));
term = log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
- inv(b2 + k);
log_g_old[4] = log_t_new + log(fabs(term));
log_g_old_sign[4] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[5] = t_new * (g_old[5] / t_old + 1.0 / z);
term = log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
+ inv(z);
log_g_old[5] = log_t_new + log(fabs(term));
log_g_old_sign[5] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

for (int i = 0; i < 6; ++i) {
g[i] += log_g_old_sign[i] * exp(log_g_old[i]);
if constexpr (grad_z) {
term[5]
= log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
+ inv(z);
log_g_old[5] = log_t_new + log(fabs(term[5]));
log_g_old_sign[5] = term[5] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[5] += log_g_old_sign[5] * exp(log_g_old[5]);
}

if (log_t_new <= log(precision)) {
Expand Down
9 changes: 5 additions & 4 deletions stan/math/prim/fun/grad_pFq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ template <bool calc_a = true, bool calc_b = true, bool calc_z = true,
typename T_Rtn = return_type_t<Ta, Tb, Tz>,
typename Ta_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Ta>>,
typename Tb_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Tb>>>
std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val, const Ta& a,
const Tb& b, const Tz& z,
double precision = 1e-14,
int max_steps = 1e6) {
inline std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val,
const Ta& a, const Tb& b,
const Tz& z,
double precision = 1e-14,
int max_steps = 1e6) {
using std::max;
using Ta_Array = Eigen::Array<return_type_t<Ta>, -1, 1>;
using Tb_Array = Eigen::Array<return_type_t<Tb>, -1, 1>;
Expand Down
47 changes: 23 additions & 24 deletions stan/math/prim/fun/hypergeometric_3F2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,34 @@ namespace stan {
namespace math {
namespace internal {
template <typename Ta, typename Tb, typename Tz,
typename T_return = return_type_t<Ta, Tb, Tz>,
typename ArrayAT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
typename ArrayBT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
require_all_vector_t<Ta, Tb>* = nullptr,
require_stan_scalar_t<Tz>* = nullptr>
T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
double precision = 1e-6,
int max_steps = 1e5) {
ArrayAT a_array = as_array_or_scalar(a);
ArrayBT b_array = append_row(as_array_or_scalar(b), 1.0);
inline return_type_t<Ta, Tb, Tz> hypergeometric_3F2_infsum(
const Ta& a, const Tb& b, const Tz& z, double precision = 1e-6,
int max_steps = 1e5) {
using T_return = return_type_t<Ta, Tb, Tz>;
Eigen::Array<scalar_type_t<Ta>, 3, 1> a_array = as_array_or_scalar(a);
Eigen::Array<scalar_type_t<Tb>, 3, 1> b_array
= append_row(as_array_or_scalar(b), 1.0);
check_3F2_converges("hypergeometric_3F2", a_array[0], a_array[1], a_array[2],
b_array[0], b_array[1], z);

T_return t_acc = 1.0;
T_return log_t = 0.0;
T_return log_z = log(fabs(z));
Eigen::ArrayXi a_signs = sign(value_of_rec(a_array));
Eigen::ArrayXi b_signs = sign(value_of_rec(b_array));
plain_type_t<decltype(a_array)> apk = a_array;
plain_type_t<decltype(b_array)> bpk = b_array;
auto log_z = log(fabs(z));
Eigen::Array<int, 3, 1> a_signs = sign(value_of_rec(a_array));
Eigen::Array<int, 3, 1> b_signs = sign(value_of_rec(b_array));
int z_sign = sign(value_of_rec(z));
int t_sign = z_sign * a_signs.prod() * b_signs.prod();

int k = 0;
while (k <= max_steps && log_t >= log(precision)) {
const double log_precision = log(precision);
while (k <= max_steps && log_t >= log_precision) {
// Replace zero values with 1 prior to taking the log so that we accumulate
// 0.0 rather than -inf
const auto& abs_apk = math::fabs((apk == 0).select(1.0, apk));
const auto& abs_bpk = math::fabs((bpk == 0).select(1.0, bpk));
T_return p = sum(log(abs_apk)) - sum(log(abs_bpk));
const auto& abs_apk = math::fabs((a_array == 0).select(1.0, a_array));
const auto& abs_bpk = math::fabs((b_array == 0).select(1.0, b_array));
auto p = sum(log(abs_apk)) - sum(log(abs_bpk));
if (p == NEGATIVE_INFTY) {
return t_acc;
}
Expand All @@ -59,10 +57,10 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
"overflow hypergeometric function did not converge.");
}
k++;
apk.array() += 1.0;
bpk.array() += 1.0;
a_signs = sign(value_of_rec(apk));
b_signs = sign(value_of_rec(bpk));
a_array += 1.0;
b_array += 1.0;
a_signs = sign(value_of_rec(a_array));
b_signs = sign(value_of_rec(b_array));
t_sign = a_signs.prod() * b_signs.prod() * t_sign;
}
if (k == max_steps) {
Expand Down Expand Up @@ -115,7 +113,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
template <typename Ta, typename Tb, typename Tz,
require_all_vector_t<Ta, Tb>* = nullptr,
require_stan_scalar_t<Tz>* = nullptr>
auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
inline auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
check_3F2_converges("hypergeometric_3F2", a[0], a[1], a[2], b[0], b[1], z);
// Boost's pFq throws convergence errors in some cases, fallback to naive
// infinite-sum approach (tests pass for these)
Expand Down Expand Up @@ -143,8 +141,9 @@ auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
*/
template <typename Ta, typename Tb, typename Tz,
require_all_stan_scalar_t<Ta, Tb, Tz>* = nullptr>
auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
const std::initializer_list<Tb>& b, const Tz& z) {
inline auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
const std::initializer_list<Tb>& b,
const Tz& z) {
return hypergeometric_3F2(std::vector<Ta>(a), std::vector<Tb>(b), z);
}

Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/prob.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <stan/math/prim/prob/beta_lccdf.hpp>
#include <stan/math/prim/prob/beta_lcdf.hpp>
#include <stan/math/prim/prob/beta_lpdf.hpp>
#include <stan/math/prim/prob/beta_neg_binomial_lccdf.hpp>
#include <stan/math/prim/prob/beta_neg_binomial_lpmf.hpp>
#include <stan/math/prim/prob/beta_proportion_ccdf_log.hpp>
#include <stan/math/prim/prob/beta_proportion_cdf_log.hpp>
Expand Down
Loading

0 comments on commit 65fc8e6

Please sign in to comment.