# Leave-one-out probability integral transform (PIT) values.
#
# For each column j of `yrep`, sums the exponentiated log weights `lw[, j]` of
# the draws strictly below `y[j]`. If any draws are exactly equal to `y[j]`
# (as happens with discrete data), the PIT is instead drawn uniformly between
# that sum and the sum that also includes the tied draws (randomised PIT; see
# Czado, C., Gneiting, T., Held, L.: Predictive model assessment for count
# data. Biometrics 65(4), 1254-1261 (2009)). Because of the random draw, set
# the RNG seed beforehand if reproducible results are needed.
#
# @param y Vector of observations; `y[j]` is compared against `yrep[, j]`.
# @param yrep Matrix of draws with one column per observation.
# @param lw Matrix of log weights, indexed by the same rows and columns as
#   `yrep`. Must be non-NULL and entirely finite. Presumably normalised so
#   that `exp(lw[, j])` sums to 1 in each column -- this is not checked here.
# @return Numeric vector of length `ncol(yrep)`. Values above 1 are clamped
#   to 1 and a warning is raised.
.loo_pit <- function(y, yrep, lw) {
  if (is.null(lw) || !all(is.finite(lw))) {
    stop("lw needs to be not null and finite.", call. = FALSE)
  }
  pits <- vapply(seq_len(ncol(yrep)), function(j) {
    below <- yrep[, j] < y[j]
    tied <- yrep[, j] == y[j]
    # An empty selection contributes 0 (exp of log-sum-exp over no elements).
    pit <- .exp_log_sum_exp(lw[below, j])
    if (any(tied)) {
      # Ties with the observation: draw uniformly between P(yrep < y) and
      # P(yrep <= y) so the PIT stays uniform for discrete outcomes.
      pit_upper <- pit + .exp_log_sum_exp(lw[tied, j])
      pit <- stats::runif(1, pit, pit_upper)
    }
    pit
  }, FUN.VALUE = 1)
  if (any(pits > 1)) {
    # Pass the pieces straight to warning(); wrapping them in cat() would
    # print to stdout and leave the warning itself with an empty message.
    warning(
      "Some PIT values larger than 1! Largest: ",
      max(pits),
      "\nRounding PIT > 1 to 1.",
      call. = FALSE
    )
  }
  pmin(1, pits)
}
+For discrete observations, probability integral transform is randomised to +ensure theoretical uniformity. Fix random seed for reproducible results +with discrete data. For more details, see Czado et al. (2009). \code{loo_predictive_interval()} methods should return a two-column matrix formatted in the same way as for \code{\link[=predictive_interval]{predictive_interval()}}. } \description{ See the methods in the \pkg{rstanarm} package for examples. } +\references{ +Czado, C., Gneiting, T., and Held, L. (2009). +Predictive Model Assessment for Count Data. +\emph{Biometrics}. 65(4), 1254-1261. +doi:10.1111/j.1541-0420.2009.01191.x. +Journal version: \url{https://doi.org/10.1111/j.1541-0420.2009.01191.x} +} \seealso{ \itemize{ \item The \pkg{rstanarm} package (\href{https://mc-stan.org/rstanarm/}{mc-stan.org/rstanarm}) diff --git a/man/rstan_create_package.Rd b/man/rstan_create_package.Rd index c4184ad..c8b048e 100644 --- a/man/rstan_create_package.Rd +++ b/man/rstan_create_package.Rd @@ -114,7 +114,13 @@ must be manually configured by running \code{\link[=rstan_config]{rstan_config() \code{stanmodel} files in \code{inst/stan} are added, removed, or modified. In order to enable Stan functionality, \pkg{\link{rstantools}} -copies some files to your package. Since these files are licensed as GPL= 3, the same license applies to your package should you choose todistribute it. Even if you don't use \pkg{\link{rstantools}} to createyour package, it is likely that you will be linking to \pkg{\link{Rcpp}} toexport the Stan C++ stanmodel objects to \R. Since\pkg{\link{Rcpp}} is released under GPL >= 2, the same license would applyto your package upon distribution. +copies some files to your package. Since these files are licensed as +GPL >= 3, the same license applies to your package should you choose to +distribute it. 
# Shared fixtures for the default-method tests. The order of the RNG calls
# after set.seed() determines the values, so keep x, y, lw in this order.
set.seed(1111)
x <- matrix(rnorm(150), 50, 3)
y <- rnorm(ncol(x))
lw <- matrix(rnorm(150), 50, 3)
# Normalise each column of log weights by subtracting its log-sum-exp
# (sweep()'s default FUN is "-"), so exp(lw[, j]) sums to 1 per column.
# A plain `function(col)` is used instead of the `\(col)` shorthand so the
# tests also parse on R versions older than 4.1.
lw <- sweep(
  lw,
  MARGIN = 2,
  STATS = apply(lw, 2, function(col) log(sum(exp(col)))),
  check.margin = FALSE
)

test_that("loo_pit.default works", {
  expect_equal_to_reference(
    loo_pit(x, y, lw),
    "loo_pit.RDS"
  )
})

test_that("loo_pit.default works for discrete data", {
  # Rounded inputs compared against the stored discrete reference; the seed
  # is fixed immediately before the call so the result is reproducible.
  set.seed(1111)
  expect_equal_to_reference(
    loo_pit(round(x), round(y), lw),
    "loo_pit_discrete.RDS"
  )
})