Skip to content

Commit

Permalink
Add laplace and pathfinder approximate inference (#624)
Browse files Browse the repository at this point in the history
* add laplace algorithm

* add pathfinder algorithm

* actually use pathfinder

* label as experimental

* fix tests
  • Loading branch information
sbfnk authored Mar 25, 2024
1 parent e857695 commit 3494e3f
Show file tree
Hide file tree
Showing 12 changed files with 218 additions and 33 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ export(setup_logging)
export(setup_target_folder)
export(simulate_infections)
export(simulate_secondary)
export(stan_laplace_opts)
export(stan_opts)
export(stan_pathfinder_opts)
export(stan_sampling_opts)
export(stan_vb_opts)
export(summarise_key_measures)
Expand Down
38 changes: 27 additions & 11 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf,
return(fit)
}

#' Fit a Stan Model using Variational Inference
#' Fit a Stan Model using an approximate method
#'
#' @description `r lifecycle::badge("maturing")`
#' Fits a stan model using variational inference.
Expand All @@ -515,7 +515,8 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf,
#' @importFrom rstan vb
#' @importFrom rlang abort
#' @return A stan model object
fit_model_with_vb <- function(args, future = FALSE, id = "stan") {
fit_model_approximate <- function(args, future = FALSE, id = "stan") {
method <- args$method
args$method <- NULL
futile.logger::flog.debug(
paste0(
Expand All @@ -536,11 +537,21 @@ fit_model_with_vb <- function(args, future = FALSE, id = "stan") {
trials <- 1
}

fit_vb <- function(stan_args) {
fit_approximate <- function(stan_args) {
if (inherits(stan_args$object, "stanmodel")) {
sample_func <- rstan::vb
if (method == "vb") {
sample_func <- rstan::vb
} else {
stop("Laplace approximation only available in the cmdstanr backend")
}
} else if (inherits(stan_args$object, "CmdStanModel")) {
sample_func <- stan_args$object$variational
if (method == "vb") {
sample_func <- stan_args$object$variational
} else if (method == "laplace") {
sample_func <- stan_args$object$laplace
} else {
sample_func <- stan_args$object$pathfinder
}
stan_args$object <- NULL
}
fit <- do.call(sample_func, stan_args)
Expand All @@ -552,31 +563,36 @@ fit_model_with_vb <- function(args, future = FALSE, id = "stan") {
}
return(fit)
}
safe_vb <- purrr::safely(fit_vb) # nolint
safe_fit <- purrr::safely(fit_approximate) # nolint
fit <- NULL
current_trials <- 0

while (current_trials <= trials && is.null(fit)) {
fit <- safe_vb(args)
fit <- safe_fit(args)

error <- fit[[2]]
fit <- fit[[1]]
if (is(fit, "CmdStanFit") && fit$return_codes() > 0) {
error <- tail(capture.output(fit$output()), 1)
fit <- NULL
}
current_trials <- current_trials + 1
}

if (is.null(fit)) {
futile.logger::flog.error(
"%s: Fitting failed - try increasing stan_args$trials or inspecting",
" the model input",
paste(
"%s: Fitting failed - try increasing stan_args$trials or inspecting",
"the model input"
),
id,
name = "EpiNow2.epinow.estimate_infections.fit"
)
rlang::abort("Variational Inference failed due to: ", error)
rlang::abort(paste("Approximate inference failed due to:", error))
}
return(fit)
}


#' Format Posterior Samples
#'
#' @description `r lifecycle::badge("stable")`
Expand Down
5 changes: 3 additions & 2 deletions R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ extract_samples <- function(stan_fit, pars = NULL, include = TRUE) {
if (!is.null(pars)) args <- c(args, list(pars = pars))
return(do.call(rstan::extract, args))
}
if (!inherits(stan_fit, "CmdStanMCMC")) {
stop("stan_fit must be a <stanfit> or <CmdStanMCMC> object")
if (!inherits(stan_fit, "CmdStanMCMC") &&
!inherits(stan_fit, "CmdStanFit")) {
stop("stan_fit must be a <stanfit>, <CmdStanMCMC> or <CmdStanFit> object")
}

# extract sample from stan object
Expand Down
72 changes: 69 additions & 3 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,61 @@ stan_vb_opts <- function(samples = 2000,
return(opts)
}

#' Stan Laplace algorithm Options
#'
#' @description `r lifecycle::badge("experimental")`
#' Defines a list specifying the arguments passed to [cmdstanr::laplace()].
#'
#' @inheritParams stan_opts
#' @inheritParams stan_vb_opts
#' @param ... Additional parameters to pass to [cmdstanr::laplace()].
#' @return A list of arguments to pass to [cmdstanr::laplace()].
#' @export
#' @examples
#' stan_laplace_opts()
stan_laplace_opts <- function(backend = "cmdstanr",
trials = 10,
...) {
if (backend != "cmdstanr") {
stop(
"The Laplace algorithm is only available with the \"cmdstanr\" backend."
)
}
opts <- list(trials = trials)
opts <- c(opts, ...)
return(opts)
}

#' Stan pathfinder algorithm Options
#'
#' @description `r lifecycle::badge("experimental")`
#' Defines a list specifying the arguments passed to [cmdstanr::laplace()].
#'
#' @inheritParams stan_opts
#' @inheritParams stan_vb_opts
#' @param ... Additional parameters to pass to [cmdstanr::laplace()].
#' @return A list of arguments to pass to [cmdstanr::laplace()].
#' @export
#' @examples
#' stan_laplace_opts()
stan_pathfinder_opts <- function(backend = "cmdstanr",
samples = 2000,
trials = 10,
...) {
if (backend != "cmdstanr") {
stop(
"The pathfinder algorithm is only available with the \"cmdstanr\" ",
"backend."
)
}
opts <- list(
trials = trials,
draws = samples
)
opts <- c(opts, ...)
return(opts)
}

#' Rstan Options
#'
#' @description `r lifecycle::badge("deprecated")`
Expand All @@ -788,7 +843,7 @@ stan_vb_opts <- function(samples = 2000,
#' default.
#'
#' @param method A character string, defaulting to sampling. Currently supports
#' [rstan::sampling()] ("sampling") or [rstan::vb()] ("vb").
#' [rstan::sampling()] ("sampling") or [rstan::vb()].
#'
#' @param ... Additional parameters to pass underlying option functions.
#' @importFrom rlang arg_match
Expand Down Expand Up @@ -839,7 +894,9 @@ rstan_opts <- function(object = NULL,
#'
#' @param method A character string, defaulting to sampling. Currently supports
#' MCMC sampling ("sampling") or approximate posterior sampling via
#' variational inference ("vb").
#' variational inference ("vb") and, as experimental features if the
#' "cmdstanr" backend is used, approximate posterior sampling with the
#' laplace algorithm ("laplace") or pathfinder ("pathfinder").
#'
#' @param backend Character string indicating the backend to use for fitting
#' stan models. Supported arguments are "rstan" (default) or "cmdstanr".
Expand Down Expand Up @@ -880,7 +937,7 @@ rstan_opts <- function(object = NULL,
#' stan_opts(method = "vb")
stan_opts <- function(object = NULL,
samples = 2000,
method = c("sampling", "vb"),
method = c("sampling", "vb", "laplace", "pathfinder"),
backend = c("rstan", "cmdstanr"),
init_fit = NULL,
return_fit = TRUE,
Expand Down Expand Up @@ -922,7 +979,16 @@ stan_opts <- function(object = NULL,
)
} else if (method == "vb") {
opts <- c(opts, stan_vb_opts(samples = samples, ...))
} else if (method == "laplace") {
opts <- c(
opts, stan_laplace_opts(backend = backend, ...)
)
} else if (method == "pathfinder") {
opts <- c(
opts, stan_pathfinder_opts(samples = samples, backend = backend, ...)
)
}

if (!is.null(init_fit)) {
deprecate_warn(
when = "1.5.0",
Expand Down
6 changes: 3 additions & 3 deletions R/stan.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ fit_model <- function(args, id = "stan") {
future = args$future,
max_execution_time = args$max_execution_time, id = id
)
} else if (args$method == "vb") {
fit <- fit_model_with_vb(args, id = id)
} else if (args$method %in% c("vb", "laplace", "pathfinder")) {
fit <- fit_model_approximate(args, id = id)
} else {
stop("args$method must be one of 'sampling' or 'vb'")
stop("method ", args$method, " unknown")
}
return(fit)
}
2 changes: 1 addition & 1 deletion inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ parameters{

transformed parameters {
vector[fixed ? 0 : noise_terms] noise; // noise generated by the gaussian process
vector<lower = 0, upper = 10 * r_mean>[estimate_r > 0 ? ot_h : 0] R; // reproduction number
vector<lower = 0>[estimate_r > 0 ? ot_h : 0] R; // reproduction number
vector[t] infections; // latent infections
vector[ot_h] reports; // estimated reported cases
vector[ot] obs_reports; // observed estimated reported cases
Expand Down
8 changes: 4 additions & 4 deletions man/fit_model_with_vb.Rd → man/fit_model_approximate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/rstan_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions man/stan_laplace_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions man/stan_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions man/stan_pathfinder_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 3494e3f

Please sign in to comment.