Skip to content

Commit

Permalink
Bug Fix: Inaccurate Tails of Incomplete Gamma (#481)
Browse files Browse the repository at this point in the history
  • Loading branch information
akleeman authored Apr 29, 2024
1 parent a5b5a21 commit 6725dc5
Show file tree
Hide file tree
Showing 6 changed files with 5,781 additions and 30 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ swift_cc_test(
"tests/test_block_utils.cc",
"tests/test_call_trace.cc",
"tests/test_callers.cc",
"tests/test_chi_squared_versus_gsl.cc",
"tests/test_compression.cc",
"tests/test_concatenate.cc",
"tests/test_conditional_gaussian.cc",
Expand Down
1 change: 1 addition & 0 deletions include/albatross/Stats
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef ALBATROSS_STATS_H
#define ALBATROSS_STATS_H

#include <albatross/src/details/typecast.hpp>
#include <albatross/src/stats/gaussian.hpp>
#include <albatross/src/stats/gauss_legendre.hpp>
#include <albatross/src/stats/incomplete_gamma.hpp>
Expand Down
17 changes: 13 additions & 4 deletions include/albatross/src/stats/chi_squared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ namespace albatross {

namespace details {

inline double chi_squared_cdf_unsafe(double x, std::size_t degrees_of_freedom) {
return incomplete_gamma(0.5 * cast::to_double(degrees_of_freedom), 0.5 * x);
inline double chi_squared_cdf_unsafe(double x, double degrees_of_freedom) {
return incomplete_gamma(0.5 * degrees_of_freedom, 0.5 * x);
}

inline double chi_squared_cdf_safe(double x, std::size_t degrees_of_freedom) {
inline double chi_squared_cdf_safe(double x, double degrees_of_freedom) {

if (std::isnan(x) || x < 0.) {
return NAN;
Expand All @@ -53,7 +53,16 @@ inline double chi_squared_cdf_safe(double x, std::size_t degrees_of_freedom) {

} // namespace details

inline double chi_squared_cdf(double x, std::size_t degrees_of_freedom) {
template <typename IntType,
typename = std::enable_if_t<std::is_integral<IntType>::value>>
inline double chi_squared_cdf(double x, IntType degrees_of_freedom) {
// due to implicit argument conversions we can't directly use cast::to_double
// here.
return details::chi_squared_cdf_safe(x,
static_cast<double>(degrees_of_freedom));
}

inline double chi_squared_cdf(double x, double degrees_of_freedom) {
return details::chi_squared_cdf_safe(x, degrees_of_freedom);
}

Expand Down
31 changes: 5 additions & 26 deletions include/albatross/src/stats/incomplete_gamma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,32 +75,11 @@ inline double incomplete_gamma_quadrature_recursive(double lb, double ub,

inline std::pair<double, double> incomplete_gamma_quadrature_bounds(double a,
double z) {

if (a > 800) {
return std::make_pair(std::max(0., std::min(z, a) - 11 * sqrt(a)),
std::min(z, a + 10 * sqrt(a)));
} else if (a > 300) {
return std::make_pair(std::max(0., std::min(z, a) - 10 * sqrt(a)),
std::min(z, a + 9 * sqrt(a)));
} else if (a > 90) {
return std::make_pair(std::max(0., std::min(z, a) - 9 * sqrt(a)),
std::min(z, a + 8 * sqrt(a)));
} else if (a > 70) {
return std::make_pair(std::max(0., std::min(z, a) - 8 * sqrt(a)),
std::min(z, a + 7 * sqrt(a)));
} else if (a > 50) {
return std::make_pair(std::max(0., std::min(z, a) - 7 * sqrt(a)),
std::min(z, a + 6 * sqrt(a)));
} else if (a > 40) {
return std::make_pair(std::max(0., std::min(z, a) - 6 * sqrt(a)),
std::min(z, a + 5 * sqrt(a)));
} else if (a > 30) {
return std::make_pair(std::max(0., std::min(z, a) - 5 * sqrt(a)),
std::min(z, a + 4 * sqrt(a)));
} else {
return std::make_pair(std::max(0., std::min(z, a) - 4 * sqrt(a)),
std::min(z, a + 4 * sqrt(a)));
}
// NOTE: GCEM uses a large conditional block to select tighter bounds, but in
// practice those bounds were not tight enough, particularly on the upper
// bound, so we've modified this function to be more conservative
return std::make_pair(std::max(0., std::min(z, a) - 12 * sqrt(a)),
std::min(z, a + 13 * sqrt(a + 1)));
}

inline double incomplete_gamma_quadrature(double a, double z) {
Expand Down
Loading

0 comments on commit 6725dc5

Please sign in to comment.