From 3494e3f6609916f61562e866ac386d3c225ed00f Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Mon, 25 Mar 2024 20:00:00 +0000 Subject: [PATCH] Add laplace and pathfinder approximate inference (#624) * add laplace algorithm * add pathfinder algorithm * actually use pathfinder * label as experimental * fix tests --- NAMESPACE | 2 + R/estimate_infections.R | 38 +++++++--- R/extract.R | 5 +- R/opts.R | 72 ++++++++++++++++++- R/stan.R | 6 +- inst/stan/estimate_infections.stan | 2 +- ...el_with_vb.Rd => fit_model_approximate.Rd} | 8 +-- man/rstan_opts.Rd | 2 +- man/stan_laplace_opts.Rd | 27 +++++++ man/stan_opts.Rd | 6 +- man/stan_pathfinder_opts.Rd | 30 ++++++++ tests/testthat/test-epinow.R | 53 ++++++++++++-- 12 files changed, 218 insertions(+), 33 deletions(-) rename man/{fit_model_with_vb.Rd => fit_model_approximate.Rd} (81%) create mode 100644 man/stan_laplace_opts.Rd create mode 100644 man/stan_pathfinder_opts.Rd diff --git a/NAMESPACE b/NAMESPACE index a64aaed84..e78bed18b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -104,7 +104,9 @@ export(setup_logging) export(setup_target_folder) export(simulate_infections) export(simulate_secondary) +export(stan_laplace_opts) export(stan_opts) +export(stan_pathfinder_opts) export(stan_sampling_opts) export(stan_vb_opts) export(summarise_key_measures) diff --git a/R/estimate_infections.R b/R/estimate_infections.R index b69390ed9..75d6845cd 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -504,7 +504,7 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf, return(fit) } -#' Fit a Stan Model using Variational Inference +#' Fit a Stan Model using an approximate method #' #' @description `r lifecycle::badge("maturing")` #' Fits a stan model using variational inference. @@ -515,7 +515,8 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf, #' @importFrom rstan vb #' @importFrom rlang abort #' @return A stan model object -fit_model_with_vb <- function(args, future = FALSE, id = "stan") { +fit_model_approximate <- function(args, future = FALSE, id = "stan") { + method <- args$method args$method <- NULL futile.logger::flog.debug( paste0( @@ -536,11 +537,21 @@ fit_model_with_vb <- function(args, future = FALSE, id = "stan") { trials <- 1 } - fit_vb <- function(stan_args) { + fit_approximate <- function(stan_args) { if (inherits(stan_args$object, "stanmodel")) { - sample_func <- rstan::vb + if (method == "vb") { + sample_func <- rstan::vb + } else { + stop("Laplace approximation only available in the cmdstanr backend") + } } else if (inherits(stan_args$object, "CmdStanModel")) { - sample_func <- stan_args$object$variational + if (method == "vb") { + sample_func <- stan_args$object$variational + } else if (method == "laplace") { + sample_func <- stan_args$object$laplace + } else { + sample_func <- stan_args$object$pathfinder + } stan_args$object <- NULL } fit <- do.call(sample_func, stan_args) @@ -552,31 +563,36 @@ fit_model_with_vb <- function(args, future = FALSE, id = "stan") { } return(fit) } - safe_vb <- purrr::safely(fit_vb) # nolint + safe_fit <- purrr::safely(fit_approximate) # nolint fit <- NULL current_trials <- 0 while (current_trials <= trials && is.null(fit)) { - fit <- safe_vb(args) + fit <- safe_fit(args) error <- fit[[2]] fit <- fit[[1]] + if (is(fit, "CmdStanFit") && fit$return_codes() > 0) { + error <- tail(capture.output(fit$output()), 1) + fit <- NULL + } current_trials <- current_trials + 1 } if (is.null(fit)) { futile.logger::flog.error( - "%s: Fitting failed - try increasing stan_args$trials or inspecting", - " the model input", + paste( + "%s: Fitting failed - try increasing stan_args$trials or inspecting", + "the model input" + ), id, name = "EpiNow2.epinow.estimate_infections.fit" ) - rlang::abort("Variational Inference failed due to: ", error) + rlang::abort(paste("Approximate inference failed due to:", error)) } return(fit) } - #' Format Posterior Samples #' #' @description `r lifecycle::badge("stable")` diff --git a/R/extract.R b/R/extract.R index 838003bce..e3ad05c9c 100644 --- a/R/extract.R +++ b/R/extract.R @@ -72,8 +72,9 @@ extract_samples <- function(stan_fit, pars = NULL, include = TRUE) { if (!is.null(pars)) args <- c(args, list(pars = pars)) return(do.call(rstan::extract, args)) } - if (!inherits(stan_fit, "CmdStanMCMC")) { - stop("stan_fit must be a or object") + if (!inherits(stan_fit, "CmdStanMCMC") && + !inherits(stan_fit, "CmdStanFit")) { + stop("stan_fit must be a , or object") } # extract sample from stan object diff --git a/R/opts.R b/R/opts.R index b09e7b2fb..7059ee239 100644 --- a/R/opts.R +++ b/R/opts.R @@ -779,6 +779,61 @@ stan_vb_opts <- function(samples = 2000, return(opts) } +#' Stan Laplace algorithm Options +#' +#' @description `r lifecycle::badge("experimental")` +#' Defines a list specifying the arguments passed to [cmdstanr::laplace()]. +#' +#' @inheritParams stan_opts +#' @inheritParams stan_vb_opts +#' @param ... Additional parameters to pass to [cmdstanr::laplace()]. +#' @return A list of arguments to pass to [cmdstanr::laplace()]. +#' @export +#' @examples +#' stan_laplace_opts() +stan_laplace_opts <- function(backend = "cmdstanr", + trials = 10, + ...) { + if (backend != "cmdstanr") { + stop( + "The Laplace algorithm is only available with the \"cmdstanr\" backend." + ) + } + opts <- list(trials = trials) + opts <- c(opts, ...) + return(opts) +} + +#' Stan pathfinder algorithm Options +#' +#' @description `r lifecycle::badge("experimental")` +#' Defines a list specifying the arguments passed to [cmdstanr::laplace()]. +#' +#' @inheritParams stan_opts +#' @inheritParams stan_vb_opts +#' @param ... Additional parameters to pass to [cmdstanr::laplace()]. +#' @return A list of arguments to pass to [cmdstanr::laplace()]. +#' @export +#' @examples +#' stan_laplace_opts() +stan_pathfinder_opts <- function(backend = "cmdstanr", + samples = 2000, + trials = 10, + ...) { + if (backend != "cmdstanr") { + stop( + "The pathfinder algorithm is only available with the \"cmdstanr\" ", + "backend." + ) + } + opts <- list( + trials = trials, + draws = samples + ) + opts <- c(opts, ...) + return(opts) +} + #' Rstan Options #' #' @description `r lifecycle::badge("deprecated")` @@ -788,7 +843,7 @@ stan_vb_opts <- function(samples = 2000, #' default. #' #' @param method A character string, defaulting to sampling. Currently supports -#' [rstan::sampling()] ("sampling") or [rstan::vb()] ("vb"). +#' [rstan::sampling()] ("sampling") or [rstan::vb()]. #' #' @param ... Additional parameters to pass underlying option functions. #' @importFrom rlang arg_match @@ -839,7 +894,9 @@ rstan_opts <- function(object = NULL, #' #' @param method A character string, defaulting to sampling. Currently supports #' MCMC sampling ("sampling") or approximate posterior sampling via -#' variational inference ("vb"). +#' variational inference ("vb") and, as experimental features if the +#' "cmdstanr" backend is used, approximate posterior sampling with the +#' laplace algorithm ("laplace") or pathfinder ("pathfinder"). #' #' @param backend Character string indicating the backend to use for fitting #' stan models. Supported arguments are "rstan" (default) or "cmdstanr". @@ -880,7 +937,7 @@ rstan_opts <- function(object = NULL, #' stan_opts(method = "vb") stan_opts <- function(object = NULL, samples = 2000, - method = c("sampling", "vb"), + method = c("sampling", "vb", "laplace", "pathfinder"), backend = c("rstan", "cmdstanr"), init_fit = NULL, return_fit = TRUE, @@ -922,7 +979,16 @@ stan_opts <- function(object = NULL, ) } else if (method == "vb") { opts <- c(opts, stan_vb_opts(samples = samples, ...)) + } else if (method == "laplace") { + opts <- c( + opts, stan_laplace_opts(backend = backend, ...) + ) + } else if (method == "pathfinder") { + opts <- c( + opts, stan_pathfinder_opts(samples = samples, backend = backend, ...) + ) } + if (!is.null(init_fit)) { deprecate_warn( when = "1.5.0", diff --git a/R/stan.R b/R/stan.R index d01d4e04b..0cb500678 100644 --- a/R/stan.R +++ b/R/stan.R @@ -83,10 +83,10 @@ fit_model <- function(args, id = "stan") { future = args$future, max_execution_time = args$max_execution_time, id = id ) - } else if (args$method == "vb") { - fit <- fit_model_with_vb(args, id = id) + } else if (args$method %in% c("vb", "laplace", "pathfinder")) { + fit <- fit_model_approximate(args, id = id) } else { - stop("args$method must be one of 'sampling' or 'vb'") + stop("method ", args$method, " unknown") } return(fit) } diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index a1eae062a..4bd1b8fb2 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -60,7 +60,7 @@ parameters{ transformed parameters { vector[fixed ? 0 : noise_terms] noise; // noise generated by the gaussian process - vector[estimate_r > 0 ? ot_h : 0] R; // reproduction number + vector[estimate_r > 0 ? ot_h : 0] R; // reproduction number vector[t] infections; // latent infections vector[ot_h] reports; // estimated reported cases vector[ot] obs_reports; // observed estimated reported cases diff --git a/man/fit_model_with_vb.Rd b/man/fit_model_approximate.Rd similarity index 81% rename from man/fit_model_with_vb.Rd rename to man/fit_model_approximate.Rd index 262b2b435..e5a2e9fc9 100644 --- a/man/fit_model_with_vb.Rd +++ b/man/fit_model_approximate.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/estimate_infections.R -\name{fit_model_with_vb} -\alias{fit_model_with_vb} -\title{Fit a Stan Model using Variational Inference} +\name{fit_model_approximate} +\alias{fit_model_approximate} +\title{Fit a Stan Model using an approximate method} \usage{ -fit_model_with_vb(args, future = FALSE, id = "stan") +fit_model_approximate(args, future = FALSE, id = "stan") } \arguments{ \item{args}{List of stan arguments.} diff --git a/man/rstan_opts.Rd b/man/rstan_opts.Rd index ad80dba9a..f13484bd4 100644 --- a/man/rstan_opts.Rd +++ b/man/rstan_opts.Rd @@ -14,7 +14,7 @@ default.} When using multiple chains iterations per chain is samples / chains.} \item{method}{A character string, defaulting to sampling. Currently supports -\code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}} ("sampling") or \code{\link[rstan:stanmodel-method-vb]{rstan::vb()}} ("vb").} +\code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}} ("sampling") or \code{\link[rstan:stanmodel-method-vb]{rstan::vb()}}.} \item{...}{Additional parameters to pass underlying option functions.} } diff --git a/man/stan_laplace_opts.Rd b/man/stan_laplace_opts.Rd new file mode 100644 index 000000000..6e415d562 --- /dev/null +++ b/man/stan_laplace_opts.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/opts.R +\name{stan_laplace_opts} +\alias{stan_laplace_opts} +\title{Stan Laplace algorithm Options} +\usage{ +stan_laplace_opts(backend = "cmdstanr", trials = 10, ...) +} +\arguments{ +\item{backend}{Character string indicating the backend to use for fitting +stan models. Supported arguments are "rstan" (default) or "cmdstanr".} + +\item{trials}{Numeric, defaults to 10. Number of attempts to use +rstan::vb()] before failing.} + +\item{...}{Additional parameters to pass to \code{\link[cmdstanr:model-method-laplace]{cmdstanr::laplace()}}.} +} +\value{ +A list of arguments to pass to \code{\link[cmdstanr:model-method-laplace]{cmdstanr::laplace()}}. +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} +Defines a list specifying the arguments passed to \code{\link[cmdstanr:model-method-laplace]{cmdstanr::laplace()}}. +} +\examples{ +stan_laplace_opts() +} diff --git a/man/stan_opts.Rd b/man/stan_opts.Rd index d5473d72b..694418b39 100644 --- a/man/stan_opts.Rd +++ b/man/stan_opts.Rd @@ -7,7 +7,7 @@ stan_opts( object = NULL, samples = 2000, - method = c("sampling", "vb"), + method = c("sampling", "vb", "laplace", "pathfinder"), backend = c("rstan", "cmdstanr"), init_fit = NULL, return_fit = TRUE, @@ -26,7 +26,9 @@ When using multiple chains iterations per chain is samples / chains.} \item{method}{A character string, defaulting to sampling. Currently supports MCMC sampling ("sampling") or approximate posterior sampling via -variational inference ("vb").} +variational inference ("vb") and, as experimental features if the +"cmdstanr" backend is used, approximate posterior sampling with the +laplace algorithm ("laplace") or pathfinder ("pathfinder").} \item{backend}{Character string indicating the backend to use for fitting stan models. Supported arguments are "rstan" (default) or "cmdstanr".} diff --git a/man/stan_pathfinder_opts.Rd b/man/stan_pathfinder_opts.Rd new file mode 100644 index 000000000..2a2c13395 --- /dev/null +++ b/man/stan_pathfinder_opts.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/opts.R +\name{stan_pathfinder_opts} +\alias{stan_pathfinder_opts} +\title{Stan pathfinder algorithm Options} +\usage{ +stan_pathfinder_opts(backend = "cmdstanr", samples = 2000, trials = 10, ...) +} +\arguments{ +\item{backend}{Character string indicating the backend to use for fitting +stan models. Supported arguments are "rstan" (default) or "cmdstanr".} + +\item{samples}{Numeric, default 2000. Overall number of posterior samples. +When using multiple chains iterations per chain is samples / chains.} + +\item{trials}{Numeric, defaults to 10. Number of attempts to use +rstan::vb()] before failing.} + +\item{...}{Additional parameters to pass to \code{\link[cmdstanr:model-method-laplace]{cmdstanr::laplace()}}.} +} +\value{ +A list of arguments to pass to \code{\link[cmdstanr:model-method-laplace]{cmdstanr::laplace()}}. +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} +Defines a list specifying the arguments passed to \code{\link[cmdstanr:model-method-laplace]{cmdstanr::laplace()}}. +} +\examples{ +stan_laplace_opts() +} diff --git a/tests/testthat/test-epinow.R b/tests/testthat/test-epinow.R index 83dd87c69..dc644afbe 100644 --- a/tests/testthat/test-epinow.R +++ b/tests/testthat/test-epinow.R @@ -46,12 +46,7 @@ test_that("epinow produces expected output when run with the reported_cases = reported_cases, generation_time = generation_time_opts(example_generation_time), delays = delay_opts(example_incubation_period + reporting_delay), - stan = stan_opts( - samples = 25, warmup = 25, - cores = 1, chains = 2, - control = list(adapt_delta = 0.8), - backend = "cmdstanr" - ), + stan = stan_opts(backend = "cmdstanr"), logs = NULL, verbose = FALSE ) ))) @@ -67,6 +62,52 @@ test_that("epinow produces expected output when run with the ) }) +test_that("epinow produces expected output when run with the + laplace algorithm", { + skip_on_os("windows") + output <- capture.output(suppressMessages(suppressWarnings( + out <- epinow( + reported_cases = reported_cases, + generation_time = generation_time_opts(example_generation_time), + delays = delay_opts(example_incubation_period + reporting_delay), + stan = stan_opts(method = "laplace", backend = "cmdstanr"), + logs = NULL, verbose = FALSE + ) + ))) + expect_equal(names(out), expected_out) + df_non_zero(out$estimates$samples) + df_non_zero(out$estimates$summarised) + df_non_zero(out$estimated_reported_cases$samples) + df_non_zero(out$estimated_reported_cases$summarised) + df_non_zero(out$summary) + expect_equal( + names(out$plots), c("summary", "infections", "reports", "R", "growth_rate") + ) +}) + +test_that("epinow produces expected output when run with the + pathfinder algorithm", { + skip_on_os("windows") + output <- capture.output(suppressMessages(suppressWarnings( + out <- epinow( + reported_cases = reported_cases, + generation_time = generation_time_opts(example_generation_time), + delays = delay_opts(example_incubation_period + reporting_delay), + stan = stan_opts(method = "pathfinder", backend = "cmdstanr"), + logs = NULL, verbose = FALSE + ) + ))) + expect_equal(names(out), expected_out) + df_non_zero(out$estimates$samples) + df_non_zero(out$estimates$summarised) + df_non_zero(out$estimated_reported_cases$samples) + df_non_zero(out$estimated_reported_cases$summarised) + df_non_zero(out$summary) + expect_equal( + names(out$plots), c("summary", "infections", "reports", "R", "growth_rate") + ) +}) + test_that("epinow runs without error when saving to disk", { expect_null(suppressWarnings(epinow( reported_cases = reported_cases,