-
-
Notifications
You must be signed in to change notification settings - Fork 369
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3313 from stan-dev/feature/3299-chainset
Feature/3299 chainset
- Loading branch information
Showing
34 changed files
with
10,345 additions
and
2,483 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
#ifndef STAN_ANALYZE_MCMC_ESS_HPP | ||
#define STAN_ANALYZE_MCMC_ESS_HPP | ||
|
||
#include <stan/math/prim.hpp> | ||
#include <stan/analyze/mcmc/autocovariance.hpp> | ||
#include <algorithm> | ||
#include <cmath> | ||
#include <vector> | ||
#include <limits> | ||
|
||
namespace stan { | ||
namespace analyze { | ||
|
||
/** | ||
* Computes the effective sample size (ESS) for the specified | ||
* parameter across all chains. The number of draws per chain must be > 3, | ||
* and the values across all draws must be finite and not constant. | ||
* See https://arxiv.org/abs/1903.08008, section 3.2 for discussion. | ||
* | ||
* Sample autocovariance is computed using the implementation in this namespace | ||
* which normalizes lag-k autocorrelation estimators by N instead of (N - k), | ||
* yielding biased but more stable estimators as discussed in Geyer (1992); see | ||
* https://projecteuclid.org/euclid.ss/1177011137. | ||
* | ||
* @param chains matrix of draws across all chains | ||
* @return effective sample size for the specified parameter | ||
*/ | ||
double ess(const Eigen::MatrixXd& chains) { | ||
const Eigen::Index num_chains = chains.cols(); | ||
const Eigen::Index draws_per_chain = chains.rows(); | ||
Eigen::MatrixXd acov(draws_per_chain, num_chains); | ||
Eigen::VectorXd chain_mean(num_chains); | ||
Eigen::VectorXd chain_var(num_chains); | ||
|
||
// compute the per-chain autocovariance | ||
for (size_t i = 0; i < num_chains; ++i) { | ||
chain_mean(i) = chains.col(i).mean(); | ||
Eigen::Map<const Eigen::VectorXd> draw_col(chains.col(i).data(), | ||
draws_per_chain); | ||
Eigen::VectorXd cov_col(draws_per_chain); | ||
autocovariance<double>(draw_col, cov_col); | ||
acov.col(i) = cov_col; | ||
chain_var(i) = cov_col(0) * draws_per_chain / (draws_per_chain - 1); | ||
} | ||
|
||
// compute var_plus, eqn (3) | ||
double w_chain_var = math::mean(chain_var); // W (within chain var) | ||
double var_plus | ||
= w_chain_var * (draws_per_chain - 1) / draws_per_chain; // \hat{var}^{+} | ||
if (num_chains > 1) { | ||
var_plus += math::variance(chain_mean); // B (between chain var) | ||
} | ||
|
||
// Geyer's initial positive sequence, eqn (11) | ||
Eigen::VectorXd rho_hat_t = Eigen::VectorXd::Zero(draws_per_chain); | ||
double rho_hat_even = 1.0; | ||
rho_hat_t(0) = rho_hat_even; // lag 0 | ||
|
||
Eigen::VectorXd acov_t(num_chains); | ||
for (size_t i = 0; i < num_chains; ++i) { | ||
acov_t(i) = acov(1, i); | ||
} | ||
double rho_hat_odd = 1 - (w_chain_var - acov_t.mean()) / var_plus; | ||
rho_hat_t(1) = rho_hat_odd; // lag 1 | ||
|
||
// compute autocorrelation at lag t for pair (t, t+1) | ||
// paired autocorrelation is guaranteed to be positive, monotone and convex | ||
size_t t = 1; | ||
while (t < draws_per_chain - 4 && (rho_hat_even + rho_hat_odd > 0) | ||
&& !std::isnan(rho_hat_even + rho_hat_odd)) { | ||
for (size_t i = 0; i < num_chains; ++i) { | ||
acov_t(i) = acov.col(i)(t + 1); | ||
} | ||
rho_hat_even = 1 - (w_chain_var - acov_t.mean()) / var_plus; | ||
for (size_t i = 0; i < num_chains; ++i) { | ||
acov_t(i) = acov.col(i)(t + 2); | ||
} | ||
rho_hat_odd = 1 - (w_chain_var - acov_t.mean()) / var_plus; | ||
if ((rho_hat_even + rho_hat_odd) >= 0) { | ||
rho_hat_t(t + 1) = rho_hat_even; | ||
rho_hat_t(t + 2) = rho_hat_odd; | ||
} | ||
// convert initial positive sequence into an initial monotone sequence | ||
if (rho_hat_t(t + 1) + rho_hat_t(t + 2) > rho_hat_t(t - 1) + rho_hat_t(t)) { | ||
rho_hat_t(t + 1) = (rho_hat_t(t - 1) + rho_hat_t(t)) / 2; | ||
rho_hat_t(t + 2) = rho_hat_t(t + 1); | ||
} | ||
t += 2; | ||
} | ||
|
||
auto max_t = t; // max lag, used for truncation | ||
// see discussion p. 8, par "In extreme antithetic cases, " | ||
if (rho_hat_even > 0) { | ||
rho_hat_t(max_t + 1) = rho_hat_even; | ||
} | ||
|
||
double draws_total = num_chains * draws_per_chain; | ||
// eqn (13): Geyer's truncation rule, w/ modification | ||
double tau_hat = -1 + 2 * rho_hat_t.head(max_t).sum() + rho_hat_t(max_t + 1); | ||
// safety check for negative values and with max ess equal to ess*log10(ess) | ||
tau_hat = std::max(tau_hat, 1 / std::log10(draws_total)); | ||
return (draws_total / tau_hat); | ||
} | ||
|
||
} // namespace analyze | ||
} // namespace stan | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#ifndef STAN_ANALYZE_MCMC_MCSE_HPP | ||
#define STAN_ANALYZE_MCMC_MCSE_HPP | ||
|
||
#include <stan/analyze/mcmc/check_chains.hpp> | ||
#include <stan/analyze/mcmc/split_chains.hpp> | ||
#include <stan/analyze/mcmc/ess.hpp> | ||
#include <stan/math/prim.hpp> | ||
#include <cmath> | ||
#include <limits> | ||
#include <utility> | ||
|
||
namespace stan { | ||
namespace analyze { | ||
|
||
/** | ||
* Computes the mean Monte Carlo error estimate for the central 90% interval. | ||
* See https://arxiv.org/abs/1903.08008, section 4.4. | ||
* Follows implementation in the R posterior package. | ||
* | ||
* @param chains matrix of draws across all chains | ||
* @return mcse | ||
*/ | ||
inline double mcse_mean(const Eigen::MatrixXd& chains) { | ||
const Eigen::Index num_draws = chains.rows(); | ||
if (chains.rows() < 4 || !is_finite_and_varies(chains)) | ||
return std::numeric_limits<double>::quiet_NaN(); | ||
|
||
double sample_var | ||
= (chains.array() - chains.mean()).square().sum() / (chains.size() - 1); | ||
return std::sqrt(sample_var / ess(chains)); | ||
} | ||
|
||
/** | ||
* Computes the standard deviation of the Monte Carlo error estimate | ||
* https://arxiv.org/abs/1903.08008, section 4.4. | ||
* Follows implementation in the R posterior package: | ||
* https://github.com/stan-dev/posterior/blob/98bf52329d68f3307ac4ecaaea659276ee1de8df/R/convergence.R#L478-L496 | ||
* | ||
* @param chains matrix of draws across all chains | ||
* @return mcse | ||
*/ | ||
inline double mcse_sd(const Eigen::MatrixXd& chains) { | ||
if (chains.rows() < 4 || !is_finite_and_varies(chains)) | ||
return std::numeric_limits<double>::quiet_NaN(); | ||
|
||
// center the data, take abs value | ||
Eigen::MatrixXd draws_ctr = (chains.array() - chains.mean()).abs().matrix(); | ||
|
||
// posterior pkg fn `ess_mean` computes on split chains | ||
double ess_mean = ess(split_chains(draws_ctr)); | ||
|
||
// estimated variance (2nd moment) | ||
double Evar = draws_ctr.array().square().mean(); | ||
|
||
// variance of variance, adjusted for ESS | ||
double fourth_moment = draws_ctr.array().pow(4).mean(); | ||
double varvar = (fourth_moment - std::pow(Evar, 2)) / ess_mean; | ||
|
||
// variance of standard deviation - use Taylor series approximation | ||
double varsd = varvar / Evar / 4.0; | ||
return std::sqrt(varsd); | ||
} | ||
|
||
} // namespace analyze | ||
} // namespace stan | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#ifndef STAN_ANALYZE_MCMC_RHAT_HPP | ||
#define STAN_ANALYZE_MCMC_RHAT_HPP | ||
|
||
#include <stan/math/prim.hpp> | ||
#include <algorithm> | ||
#include <cmath> | ||
#include <vector> | ||
#include <limits> | ||
|
||
namespace stan { | ||
namespace analyze { | ||
|
||
/** | ||
* Computes square root of marginal posterior variance of the estimand by the | ||
* weighted average of within-chain variance W and between-chain variance B. | ||
* | ||
* @param chains stores chains in columns | ||
* @return square root of ((N-1)/N)W + B/N | ||
*/ | ||
inline double rhat(const Eigen::MatrixXd& chains) { | ||
const Eigen::Index num_chains = chains.cols(); | ||
const Eigen::Index num_draws = chains.rows(); | ||
|
||
Eigen::RowVectorXd within_chain_means = chains.colwise().mean(); | ||
double across_chain_mean = within_chain_means.mean(); | ||
double between_variance | ||
= num_draws | ||
* (within_chain_means.array() - across_chain_mean).square().sum() | ||
/ (num_chains - 1); | ||
double within_variance = | ||
// Divide each row by chains and get sum of squares for each chain | ||
// (getting a vector back) | ||
((chains.rowwise() - within_chain_means) | ||
.array() | ||
.square() | ||
.colwise() | ||
// divide each sum of square by num_draws, sum the sum of squares, | ||
// and divide by num chains | ||
.sum() | ||
/ (num_draws - 1.0)) | ||
.sum() | ||
/ num_chains; | ||
|
||
return sqrt((between_variance / within_variance + num_draws - 1) / num_draws); | ||
} | ||
|
||
} // namespace analyze | ||
} // namespace stan | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.