diff --git a/.github/workflows/R-CMD-as-cran-check.yaml b/.github/workflows/R-CMD-as-cran-check.yaml index 16104203f..6707fa6fd 100644 --- a/.github/workflows/R-CMD-as-cran-check.yaml +++ b/.github/workflows/R-CMD-as-cran-check.yaml @@ -36,9 +36,18 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: any::rcmdcheck - needs: check + dependencies: NA + extra-packages: | + rcmdcheck + stan-dev/cmdstanr + testthat + + - name: Install cmdstan + run: | + cmdstanr::check_cmdstan_toolchain(fix = TRUE) + cmdstanr::install_cmdstan(cores = 2, quiet = TRUE) + shell: Rscript {0} - uses: r-lib/actions/check-r-package@v2 with: - upload-snapshots: true \ No newline at end of file + upload-snapshots: true diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index bf666a780..7f5bd6356 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -47,13 +47,6 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Install cmdstan Linux system dependencies - if: runner.os == 'Linux' - run: | - sudo apt-get update - sudo apt-get install -y libcurl4-openssl-dev || true - sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev || true - sudo apt-get install -y libpng-dev || true - uses: r-lib/actions/setup-pandoc@v2 - uses: r-lib/actions/setup-r@v2 @@ -65,8 +58,20 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: any::rcmdcheck - needs: check + dependencies: NA + extra-packages: | + dplyr + rmarkdown + rcmdcheck + stan-dev/cmdstanr + testthat + + - name: Install cmdstan + if: runner.os != 'Windows' + run: | + cmdstanr::check_cmdstan_toolchain(fix = TRUE) + cmdstanr::install_cmdstan(cores = 2, quiet = TRUE) + shell: Rscript {0} - uses: r-lib/actions/check-r-package@v2 with: diff --git a/.github/workflows/lint-only-changed-files.yaml b/.github/workflows/lint-only-changed-files.yaml index bdceceaeb..605b2db69 100644 --- a/.github/workflows/lint-only-changed-files.yaml +++ b/.github/workflows/lint-only-changed-files.yaml @@ -21,11 +21,12 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: + dependencies: NA extra-packages: | + stan-dev/cmdstanr any::gh any::lintr any::purrr - needs: check - name: Add lintr options run: | @@ -44,4 +45,4 @@ jobs: lintr::lint_package(exclusions = exclusions_list) shell: Rscript {0} env: - LINTR_ERROR_ON_LINT: true \ No newline at end of file + LINTR_ERROR_ON_LINT: true diff --git a/.github/workflows/synthetic-validation.yaml b/.github/workflows/synthetic-validation.yaml index 30671cfc4..346da30a4 100644 --- a/.github/workflows/synthetic-validation.yaml +++ b/.github/workflows/synthetic-validation.yaml @@ -24,7 +24,9 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: + dependencies: NA extra-packages: | + here dplyr tidyr scoringutils @@ -48,4 +50,4 @@ jobs: with: name: fits retention-days: 5 - path: synthetic.rds \ No newline at end of file + path: synthetic.rds diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 664af0f84..f352c2870 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -26,9 +26,18 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: any::covr - needs: coverage + dependencies: NA + extra-packages: | + covr + stan-dev/cmdstanr + testthat + + - name: Install cmdstan + run: | + cmdstanr::check_cmdstan_toolchain(fix = TRUE) + cmdstanr::install_cmdstan(cores = 2, quiet = TRUE) + shell: Rscript {0} - name: Test coverage run: covr::codecov(quiet = FALSE) - shell: Rscript {0} \ No newline at end of file + shell: Rscript {0} diff --git a/.gitignore b/.gitignore index 66554b5f0..3df7521ab 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,8 @@ vignettes/results # unused figures man/figures/*.png + +# exclude compiled stan files +inst/stan/* +!inst/stan/*/ +!inst/stan/*.stan diff --git a/DESCRIPTION b/DESCRIPTION index a675ac158..4b78ec9ef 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -109,6 +109,7 @@ Imports: lubridate, methods, patchwork, + posterior, progressr, purrr, R.utils (>= 2.0.0), @@ -122,6 +123,7 @@ Imports: truncnorm, utils Suggests: + cmdstanr, covr, dplyr, here, @@ -143,13 +145,15 @@ LinkingTo: RcppParallel (>= 5.0.1), rstan (>= 2.26.0), StanHeaders (>= 2.26.0) +Additional_repositories: + https://mc-stan.org/r-packages/ Biarch: true Config/testthat/edition: 3 Encoding: UTF-8 Language: en-GB LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 NeedsCompilation: yes SystemRequirements: GNU make C++17 diff --git a/NAMESPACE b/NAMESPACE index 2a6d95c3c..9041b04eb 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -46,6 +46,7 @@ export(estimates_by_report_date) export(expose_stan_fns) export(extract_CrIs) export(extract_inits) +export(extract_samples) export(extract_stan_param) export(fix_dist) export(forecast_infections) @@ -66,6 +67,7 @@ export(make_conf) export(map_prob_change) export(obs_opts) export(opts_list) +export(package_model) export(plot_estimates) export(plot_summary) export(regional_epinow) @@ -91,6 +93,8 @@ export(setup_target_folder) export(simulate_infections) export(simulate_secondary) export(stan_opts) +export(stan_sampling_opts) +export(stan_vb_opts) export(summarise_key_measures) export(summarise_results) export(trunc_opts) @@ -133,6 +137,7 @@ importFrom(data.table,rbindlist) importFrom(data.table,setDT) importFrom(data.table,setDTthreads) importFrom(data.table,setcolorder) +importFrom(data.table,setkey) importFrom(data.table,setnames) importFrom(data.table,setorder) importFrom(data.table,setorderv) @@ -184,6 +189,7 @@ importFrom(lifecycle,deprecate_warn) importFrom(lubridate,days) importFrom(lubridate,wday) importFrom(patchwork,plot_layout) +importFrom(posterior,mcse_mean) importFrom(progressr,progressor) importFrom(progressr,with_progress) importFrom(purrr,compact) diff --git a/R/create.R b/R/create.R index 1b73f26e9..7846bc600 100644 --- a/R/create.R +++ b/R/create.R @@ -630,8 +630,14 @@ create_initial_conditions <- function(data) { #' #' @param data A list of stan data as created by [create_stan_data()] #' -#' @param init Initial conditions passed to `{rstan}`. Defaults to "random" but -#' can also be a function (as supplied by [create_initial_conditions()]). +#' @param init Initial conditions passed to `{rstan}`. Defaults to "random" +#' (initial values randomly drawn between -2 and 2) but can also be a +#' function (as supplied by [create_initial_conditions()]). +#' +#' @param model Character, name of the model for which arguments are +#' to be created. +#' @param fixed_param Logical, defaults to `FALSE`. Should arguments be +#' created to sample from fixed parameters (used by simulation functions). #' #' @param verbose Logical, defaults to `FALSE`. Should verbose progress #' messages be returned. @@ -650,7 +656,28 @@ create_initial_conditions <- function(data) { create_stan_args <- function(stan = stan_opts(), data = NULL, init = "random", + model = "estimate_infections", + fixed_param = FALSE, verbose = FALSE) { + if (fixed_param) { + if (stan$backend == "rstan") { + stan$algorithm <- "Fixed_param" + } else if (stan$backend == "cmdstanr") { + stan$fixed_param <- TRUE + stan$adapt_delta <- NULL + stan$max_treedepth <- NULL + } + } + ## generate stan model + if (is.null(stan$object)) { + stan$object <- stan_model(stan$backend, model) + stan$backend <- NULL + } + # cmdstanr doesn't have an init = "random" argument + if (is.character(init) && init == "random" && + inherits(stan$object, "CmdStanModel")) { + init <- 2 + } # set up shared default arguments args <- list( data = data, diff --git a/R/dist.R b/R/dist.R index be218d0da..f9144569e 100644 --- a/R/dist.R +++ b/R/dist.R @@ -199,6 +199,7 @@ dist_skel <- function(n, dist = FALSE, cum = TRUE, model, #' @return A stan fit of an interval censored distribution #' @author Sam Abbott #' @export +#' @inheritParams stan_opts #' @examples #' \donttest{ #' # integer adjusted exponential model @@ -221,7 +222,8 @@ dist_skel <- function(n, dist = FALSE, cum = TRUE, model, #' ) #' } dist_fit <- function(values = NULL, samples = 1000, cores = 1, - chains = 2, dist = "exp", verbose = FALSE) { + chains = 2, dist = "exp", verbose = FALSE, + backend = "rstan") { if (samples < 1000) { samples <- 1000 warning(sprintf("%s %s", "`samples` must be at least 1000.", @@ -244,7 +246,7 @@ dist_fit <- function(values = NULL, samples = 1000, cores = 1, par_sigma = numeric(0) ) - model <- stanmodels$dist_fit + model <- stan_model(backend, "dist_fit") if (dist == "exp") { data$dist <- 0 @@ -268,16 +270,21 @@ dist_fit <- function(values = NULL, samples = 1000, cores = 1, } # fit model - fit <- rstan::sampling( - model, - data = data, - iter = samples + 1000, - warmup = 1000, - control = list(adapt_delta = adapt_delta), - chains = chains, - cores = cores, - refresh = ifelse(verbose, 50, 0) + args <- create_stan_args( + stan = stan_opts( + model, + samples = samples, + warmup = 1000, + control = list(adapt_delta = adapt_delta), + chains = chains, + cores = cores, + backend = backend + ), + data = data, verbose = verbose, model = "dist_fit" ) + + fit <- fit_model(args, id = "dist_fit") + return(fit) } @@ -533,11 +540,11 @@ bootstrapped_dist_fit <- function(values, dist = "lognormal", out <- list() if (dist == "lognormal") { - out$mean_samples <- sample(rstan::extract(fit)$mu, samples) - out$sd_samples <- sample(rstan::extract(fit)$sigma, samples) + out$mean_samples <- sample(extract(fit)$mu, samples) + out$sd_samples <- sample(extract(fit)$sigma, samples) } else if (dist == "gamma") { - alpha_samples <- sample(rstan::extract(fit)$alpha, samples) - beta_samples <- sample(rstan::extract(fit)$beta, samples) + alpha_samples <- sample(extract(fit)$alpha, samples) + beta_samples <- sample(extract(fit)$beta, samples) out$mean_samples <- alpha_samples / beta_samples out$sd_samples <- sqrt(alpha_samples) / beta_samples } diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 8ec103d19..f52cee981 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -233,7 +233,7 @@ estimate_infections <- function(reported_cases, args$init_fit == "cumulative") { args$init_fit <- init_cumulative_fit(args, warmup = 50, samples = 50, - id = id, verbose = FALSE + id = id, verbose = FALSE, stan$backend ) } args$init <- extract_inits(args$init_fit, @@ -244,14 +244,8 @@ estimate_infections <- function(reported_cases, args$init_fit <- NULL } # Fit model - if (args$method == "sampling") { - fit <- fit_model_with_nuts(args, - future = args$future, - max_execution_time = args$max_execution_time, id = id - ) - } else if (args$method == "vb") { - fit <- fit_model_with_vb(args, id = id) - } + fit <- fit_model(args, id = id) + # Extract parameters of interest from the fit out <- extract_parameter_samples(fit, data, reported_inf_dates = reported_cases$date, @@ -315,28 +309,31 @@ estimate_infections <- function(reported_cases, #' @importFrom futile.logger flog.debug #' @importFrom utils capture.output #' @inheritParams fit_model_with_nuts +#' @inheritParams stan_opts #' @return A stanfit object #' @author Sam Abbott init_cumulative_fit <- function(args, samples = 50, warmup = 50, - id = "init", verbose = FALSE) { + id = "init", verbose = FALSE, + backend = "rstan") { futile.logger::flog.debug( "%s: Fitting to cumulative data to initialise chains", id, name = "EpiNow2.epinow.estimate_infections.fit" ) # copy main run settings and override to use only 100 iterations and a single # chain - initial_args <- list( - object = args$object, - data = args$data, - init = args$init, - iter = samples + warmup, - warmup = warmup, - chains = 1, - cores = 2, - open_progress = FALSE, - show_messages = FALSE, - control = list(adapt_delta = 0.9, max_treedepth = 13), - refresh = ifelse(verbose, 50, -1) + initial_args <- create_stan_args( + stan = stan_opts( + args$object, + samples = samples, + warmup = warmup, + control = list(adapt_delta = 0.9, max_treedepth = 13), + chains = 1, + cores = 2, + backend = backend, + open_progress = FALSE, + show_messages = FALSE + ), + data = args$data, init = args$init ) # change observations to be cumulative in order to protect against noise and # give an approximate fit (though for Rt constrained to be > 1) @@ -345,12 +342,12 @@ init_cumulative_fit <- function(args, samples = 50, warmup = 50, # initial fit if (verbose) { - fit <- do.call(rstan::sampling, initial_args) + fit <- fit_model(initial_args, id = "init_cumulative") } else { out <- tempfile(tmpdir = tempdir(check = TRUE)) capture.output( { - fit <- do.call(rstan::sampling, initial_args) + fit <- fit_model(initial_args, id = "init_cumulative") }, type = c("output", "message"), split = FALSE, @@ -415,10 +412,16 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf, fit_chain <- function(chain, stan_args, max_time, catch = FALSE) { stan_args$chain_id <- chain + if (inherits(stan_args$object, "stanmodel")) { + sample_func <- rstan::sampling + } else if (inherits(stan_args$object, "CmdStanModel")) { + sample_func <- stan_args$object$sample + stan_args$object <- NULL + } if (catch) { fit <- tryCatch( withCallingHandlers( - R.utils::withTimeout(do.call(rstan::sampling, stan_args), + R.utils::withTimeout(do.call(sample_func, stan_args), timeout = max_time, onTimeout = "silent" ), @@ -441,16 +444,17 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf, } ) } else { - fit <- R.utils::withTimeout(do.call(rstan::sampling, stan_args), + fit <- R.utils::withTimeout(do.call(sample_func, stan_args), timeout = max_time, onTimeout = "silent" ) } - if (is.null(fit) || !is.array(fit)) { - return(NULL) - } else { + if ((inherits(fit, "stanfit") && fit@mode != 2L) || + inherits(fit, "CmdStanMCMC")) { return(fit) + } else { + return(NULL) } } @@ -494,7 +498,7 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf, fit <- rstan::sflist2stanfit(fit) } } else { - fit <- fit_chain(1, + fit <- fit_chain(seq_len(args$chains), stan_args = args, max_time = max_execution_time, catch = !id %in% c("estimate_infections", "epinow") ) @@ -542,7 +546,13 @@ fit_model_with_vb <- function(args, future = FALSE, id = "stan") { } fit_vb <- function(stan_args) { - fit <- do.call(rstan::vb, stan_args) + if (inherits(stan_args$object, "stanmodel")) { + sample_func <- rstan::vb + } else if (inherits(stan_args$object, "CmdStanModel")) { + sample_func <- stan_args$object$variational + stan_args$object <- NULL + } + fit <- do.call(sample_func, stan_args) if (length(names(fit)) == 0) { return(NULL) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index b29e3d34c..f0511337b 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -50,7 +50,7 @@ #' @param verbose Logical, should model fitting progress be returned. Defaults #' to [interactive()]. #' -#' @param ... Additional parameters to pass to [rstan::sampling()]. +#' @param ... Additional parameters to pass to [stan_opts()]. #' #' @return A list containing: `predictions` (a `` ordered by date #' with the primary, and secondary observations, and a summary of the model @@ -144,6 +144,7 @@ estimate_secondary <- function(reports, ), truncation = trunc_opts(), obs = obs_opts(), + stan = stan_opts(), burn_in = 14, CrIs = c(0.2, 0.5, 0.9), priors = NULL, @@ -198,15 +199,10 @@ estimate_secondary <- function(reports, c(data, list(estimate_r = 0, fixed = 1, bp_n = 0)) ) # fit - if (is.null(model)) { - model <- stanmodels$estimate_secondary - } - fit <- rstan::sampling(model, - data = data, - init = inits, - refresh = ifelse(verbose, 50, 0), - ... + args <- create_stan_args( + stan = stan, data = data, init = inits, model = "estimate_secondary" ) + fit <- fit_model(args, id = "estimate_secondary") out <- list() out$predictions <- extract_stan_param(fit, "sim_secondary", CrIs = CrIs) @@ -603,12 +599,14 @@ simulate_secondary <- function(data, type = "incidence", family = "poisson", #' @importFrom utils tail #' @importFrom purrr map #' @inheritParams estimate_secondary +#' @inheritParams stan_opts #' @seealso [estimate_secondary()] #' @export forecast_secondary <- function(estimate, primary, primary_variable = "reported_cases", model = NULL, + backend = "rstan", samples = NULL, all_dates = FALSE, CrIs = c(0.2, 0.5, 0.9)) { @@ -640,7 +638,7 @@ forecast_secondary <- function(estimate, updated_primary <- primary ## extract samples from given stanfit object - draws <- rstan::extract(estimate$fit, + draws <- extract_samples(estimate$fit, pars = c( "sim_secondary", "log_lik", "lp__", "secondary" @@ -680,28 +678,26 @@ forecast_secondary <- function(estimate, # combine with data data <- c(data, draws) - # load model - if (is.null(model)) { - model <- stanmodels$simulate_secondary - } - # allocate empty parameters data <- allocate_empty( data, c("frac_obs", "delay_mean", "delay_sd", "rep_phi"), n = data$n ) data$all_dates <- as.integer(all_dates) + ## simulate - sims <- rstan::sampling( - object = model, - data = data, chains = 1, iter = 1, - algorithm = "Fixed_param", - refresh = 0 + args <- create_stan_args( + stan_opts( + model = model, backend = backend, chains = 1, samples = 1, warmup = 1 + ), + data = data, fixed_param = TRUE, model = "simulate_secondary" ) + sims <- fit_model(args, id = "simulate_secondary") + # extract samples and organise dates <- unique(primary_fit$date) - samples <- rstan::extract(sims, "sim_secondary")$sim_secondary + samples <- extract_samples(sims, "sim_secondary")$sim_secondary samples <- as.data.table(samples) colnames(samples) <- c("iterations", "sample", "time", "value") samples <- samples[, c("iterations", "time") := NULL] diff --git a/R/estimate_truncation.R b/R/estimate_truncation.R index 7028e89c9..368b1aa4f 100644 --- a/R/estimate_truncation.R +++ b/R/estimate_truncation.R @@ -147,10 +147,19 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10, max = 10 ), model = NULL, + stan = stan_opts(), CrIs = c(0.2, 0.5, 0.9), weigh_delay_priors = FALSE, verbose = TRUE, ...) { + + if (!is.null(model)) { + lifecycle::deprecate_stop( + "2.0.0", + "estimate_truncation(model)", + "estimate_truncation(stan)" + ) + } # Validate inputs walk(obs, check_reports_valid, model = "estimate_truncation") assert_class(truncation, "dist_spec") @@ -230,7 +239,7 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10, obs_start <- max(nrow(obs) - trunc_max - sum(is.na(obs$`1`)) + 1, 1) obs_dist <- purrr::map_dbl(2:(ncol(obs)), ~ sum(is.na(obs[[.]]))) obs_data <- obs[, -1][, purrr::map(.SD, ~ ifelse(is.na(.), 0, .))] - obs_data <- obs_data[obs_start:.N] + obs_data <- as.matrix(obs_data[obs_start:.N]) # convert to stan list data <- list( @@ -261,23 +270,20 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10, } # fit - if (is.null(model)) { - model <- stanmodels$estimate_truncation - } - fit <- rstan::sampling(model, - data = data, - init = init_fn, - refresh = ifelse(verbose, 50, 0), - ... + args <- create_stan_args( + stan = stan, data = data, init = init_fn, model = "estimate_truncation" ) + fit <- fit_model(args, id = "estimate_truncation") out <- list() # Summarise fit truncation distribution for downstream usage + delay_mean <- extract_stan_param(fit, params = "delay_mean") + delay_sd <- extract_stan_param(fit, params = "delay_sd") out$dist <- dist_spec( - mean = round(rstan::summary(fit, pars = "delay_mean")$summary[1], 3), - mean_sd = round(rstan::summary(fit, pars = "delay_mean")$summary[3], 3), - sd = round(rstan::summary(fit, pars = "delay_sd")$summary[1], 3), - sd_sd = round(rstan::summary(fit, pars = "delay_sd")$summary[3], 3), + mean = round(delay_mean$mean, 3), + mean_sd = round(delay_mean$sd, 3), + sd = round(delay_sd$mean, 3), + sd_sd = round(delay_sd$sd, 3), max = truncation$max ) out$dist$dist <- truncation$dist diff --git a/R/extract.R b/R/extract.R index 20f1a277f..d4b7fbc33 100644 --- a/R/extract.R +++ b/R/extract.R @@ -53,6 +53,74 @@ extract_static_parameter <- function(param, samples) { ) } +#' Extract all samples from a stan fit +#' +#' If the `object` argument is a object, it simply returns the result +#' of [rstan::extract()]. If it is a `` it returns samples +#' in the same format as [rstan::extract()] does for `` objects. +#' @param stan_fit A `` or `` object as returned by +#' [fit_model()]. +#' @param pars Any selection of parameters to extract +#' @param include whether the parameters specified in `pars` should be included +#' (`TRUE`, the default) or excluded (`FALSE`) +#' @return List of data.tables with samples +#' @export +#' +#' @importFrom data.table data.table melt setkey +#' @importFrom rstan extract +extract_samples <- function(stan_fit, pars = NULL, include = TRUE) { + if (inherits(stan_fit, "stanfit")) { + args <- list(object = stan_fit, include = include) + if (!is.null(pars)) args <- c(args, list(pars = pars)) + return(do.call(rstan::extract, args)) + } + if (!inherits(stan_fit, "CmdStanMCMC")) { + stop("stan_fit must be a or object") + } + + # extract sample from stan object + if (!include) { + all_pars <- stan_fit$metadata()$stan_variables + pars <- setdiff(all_pars, pars) + } + samples_df <- data.table::data.table(stan_fit$draws( + variables = pars, format = "df") + ) + # convert to rstan format + samples_df <- suppressWarnings(data.table::melt( + samples_df, id.vars = c(".chain", ".iteration", ".draw") + )) + samples_df <- samples_df[, + index := sub("^.*\\[([0-9,]+)\\]$", "\\1", variable) + ][, + variable := sub("\\[.*$", "", variable) + ] + samples <- split(samples_df, by = "variable") + samples <- purrr::map(samples, \(df) { + permutation <- sample( + seq_len(max(df$.draw)), max(df$.draw), replace = FALSE + ) + df <- df[, new_draw := permutation[.draw]] + setkey(df, new_draw) + max_indices <- strsplit(tail(df$index, 1), split = ",", fixed = TRUE)[[1]] + if (any(grepl("[^0-9]", max_indices))) { + max_indices <- 1 + } else { + max_indices <- as.integer(max_indices) + } + ret <- aperm( + a = array(df$value, dim = c(max_indices, length(permutation))), + perm = c(length(max_indices) + 1, seq_along(max_indices)) + ) + ## permute + dimnames(ret) <- c( + list(iterations = NULL), rep(list(NULL), length(max_indices)) + ) + return(ret) + }) + + return(samples) +} #' Extract Parameter Samples from a Stan Model #' @@ -60,9 +128,7 @@ extract_static_parameter <- function(param, samples) { #' Extracts a custom set of parameters from a stan object and adds #' stratification and dates where appropriate. #' -#' @param stan_fit A fit stan model as returned by [rstan::sampling()]. -#' -#' @param data A list of the data supplied to the [rstan::sampling()] call. +#' @param data A list of the data supplied to the [fit_model()] call. #' #' @param reported_dates A vector of dates to report estimates for. #' @@ -75,6 +141,7 @@ extract_static_parameter <- function(param, samples) { #' @param merge if TRUE, merge samples and data so that parameters can be #' extracted from data. #' +#' @inheritParams extract_samples #' @return A list of ``'s each containing the posterior of a #' parameter #' @author Sam Abbott @@ -84,7 +151,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, reported_inf_dates, drop_length_1 = FALSE, merge = FALSE) { # extract sample from stan object - samples <- rstan::extract(stan_fit) + samples <- extract_samples(stan_fit) ## drop initial length 1 dimensions if requested if (drop_length_1) { @@ -208,6 +275,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, #' @author Sam Abbott #' @inheritParams calc_summary_measures #' @export +#' @importFrom posterior mcse_mean #' @importFrom data.table as.data.table := #' @importFrom rstan summary extract_stan_param <- function(fit, params = NULL, @@ -218,28 +286,37 @@ extract_stan_param <- function(fit, params = NULL, sym_CrIs <- sort(sym_CrIs) CrIs <- round(100 * CrIs, 0) CrIs <- c(paste0("lower_", rev(CrIs)), "median", paste0("upper_", CrIs)) - args <- list(object = fit, probs = sym_CrIs) if (!is.null(params)) { if (length(params) > 1) { var_names <- TRUE } - args <- c(args, pars = params) } else { var_names <- TRUE } - summary <- do.call(rstan::summary, args) - summary <- data.table::as.data.table(summary$summary, - keep.rownames = ifelse(var_names, - "variable", - FALSE + if (inherits(fit, "stanfit")) { # rstan backend + args <- list(object = fit, probs = sym_CrIs) + if (!is.null(params)) args <- c(args, list(pars = params)) + summary <- do.call(rstan::summary, args) + summary <- data.table::as.data.table(summary$summary, + keep.rownames = ifelse(var_names, + "variable", + FALSE + ) ) - ) - cols <- c("mean", "se_mean", "sd", CrIs, "n_eff", "Rhat") + summary <- summary[, c("n_eff", "Rhat") := NULL] + } else if (inherits(fit, "CmdStanMCMC")) { # cmdstanr backend + summary <- fit$summary( + variable = params, + mean, mcse_mean, sd, ~quantile(.x, probs = sym_CrIs) + ) + if (!var_names) summary$variable <- NULL + summary <- data.table::as.data.table(summary) + } + cols <- c("mean", "se_mean", "sd", CrIs) if (var_names) { cols <- c("variable", cols) } colnames(summary) <- cols - summary <- summary[, c("n_eff", "Rhat") := NULL] return(summary) } @@ -281,7 +358,7 @@ extract_inits <- function(fit, current_inits, # extract and generate samples as function init_fun <- function(i) { res <- lapply( - rstan::extract(fit), + extract_samples(fit), function(x) { if (length(dim(x)) == 1) { as.array(x[i]) diff --git a/R/opts.R b/R/opts.R index 1c0604fdf..e5f58c5b3 100644 --- a/R/opts.R +++ b/R/opts.R @@ -500,9 +500,38 @@ obs_opts <- function(family = "negbin", #' Rstan Sampling Options #' +#' @description `r lifecycle::badge("deprecated")` +#' Deprecated; use [stan_sampling_opts()] instead. +#' @inheritParams stan_sampling_opts +#' @return A list of arguments to pass to [rstan::sampling()]. +#' @author Sam Abbott +#' @export +rstan_sampling_opts <- function(cores = getOption("mc.cores", 1L), + warmup = 250, + samples = 2000, + chains = 4, + control = list(), + save_warmup = FALSE, + seed = as.integer(runif(1, 1, 1e8)), + future = FALSE, + max_execution_time = Inf, + ...) { + lifecycle::deprecate_warn( + "2.0.0", "rstan_sampling_opts()", + "stan_sampling_opts()", + "This function will be removed in version 2.1.0." + ) + return(stan_sampling_opts( + cores, warmup, samples, chains, control, save_warmup, seed, future, + max_execution_time, backend = "rstan", ... + )) +} + +#' Stan Sampling Options +#' #' @description `r lifecycle::badge("stable")` -#' Defines a list specifying the arguments passed to -#' [rstan::sampling()]. Custom settings can be supplied which override the +#' Defines a list specifying the arguments passed to either [rstan::sampling()] +#' or [cmdstanr::sample()]. Custom settings can be supplied which override the #' defaults. #' #' @param cores Number of cores to use when executing the chains in parallel, @@ -538,27 +567,30 @@ obs_opts <- function(family = "negbin", #' returned. If less than 2 chains return within the allowed time then #' estimation will fail with an informative error. #' +#' @inheritParams stan_opts +#' #' @param ... Additional parameters to pass to [rstan::sampling()]. #' @importFrom utils modifyList #' @return A list of arguments to pass to [rstan::sampling()]. #' @author Sam Abbott +#' @author Sebastian Funk #' @export #' @examples -#' rstan_sampling_opts(samples = 2000) -rstan_sampling_opts <- function(cores = getOption("mc.cores", 1L), - warmup = 250, - samples = 2000, - chains = 4, - control = list(), - save_warmup = FALSE, - seed = as.integer(runif(1, 1, 1e8)), - future = FALSE, - max_execution_time = Inf, - ...) { +#' stan_sampling_opts(samples = 2000) +stan_sampling_opts <- function(cores = getOption("mc.cores", 1L), + warmup = 250, + samples = 2000, + chains = 4, + control = list(), + save_warmup = FALSE, + seed = as.integer(runif(1, 1, 1e8)), + future = FALSE, + max_execution_time = Inf, + backend = "rstan", + ...) { dot_args <- list(...) + backend <- arg_match(backend, values = c("rstan", "cmdstanr")) opts <- list( - cores = cores, - warmup = warmup, chains = chains, save_warmup = save_warmup, seed = seed, @@ -566,18 +598,58 @@ rstan_sampling_opts <- function(cores = getOption("mc.cores", 1L), max_execution_time = max_execution_time ) control_def <- list(adapt_delta = 0.95, max_treedepth = 15) - opts$control <- modifyList(control_def, control) + control_def <- modifyList(control_def, control) + if (any(c("iter", "iter_sampling") %in% names(dot_args))) { + warning( + "Number of samples should be specified using the `samples` and `warmup`", + "arguments rather than `iter` or `iter_sampliing` which will be ignored." + ) + } dot_args$iter <- NULL - opts$iter <- ceiling(samples / opts$chains) + opts$warmup + dot_args$iter_sampling <- NULL + if (backend == "rstan") { + opts <- c(opts, list( + cores = cores, + warmup = warmup, + control = control_def, + iter = ceiling(samples / opts$chains) + warmup + )) + } else if (backend == "cmdstanr") { + opts <- c(opts, list( + parallel_chains = cores, + iter_warmup = warmup, + iter_sampling = ceiling(samples / opts$chains) + ), control_def) + } opts <- c(opts, dot_args) return(opts) } #' Rstan Variational Bayes Options #' +#' @description `r lifecycle::badge("deprecated")` +#' Deprecated; use [stan_vb_opts()] instead. +#' @inheritParams stan_vb_opts +#' @return A list of arguments to pass to [rstan::vb()]. +#' @author Sam Abbott +#' @export +rstan_vb_opts <- function(samples = 2000, + trials = 10, + iter = 10000, ...) { + lifecycle::deprecate_warn( + "2.0.0", "rstan_vb_opts()", + "stan_vb_opts()", + "This function will be removed in version 2.1.0." + ) + return(stan_vb_opts(samples, trials, iter, ...)) +} + +#' Stan Variational Bayes Options +#' #' @description `r lifecycle::badge("stable")` -#' Defines a list specifying the arguments passed to -#' [rstan::vb()]. Custom settings can be supplied which override the defaults. +#' Defines a list specifying the arguments passed to [rstan::vb()] or +#' [cmdstanr::variational()]. Custom settings can be supplied which override the +#' defaults. #' #' @param samples Numeric, default 2000. Overall number of approximate posterior #' samples. @@ -588,16 +660,19 @@ rstan_sampling_opts <- function(cores = getOption("mc.cores", 1L), #' @param iter Numeric, defaulting to 10000. Number of iterations to use in #' [rstan::vb()]. #' -#' @param ... Additional parameters to pass to [rstan::vb()]. +#' @param ... Additional parameters to pass to [rstan::vb()] or +#' [cmdstanr::variational()], depending on the chosen backend. #' -#' @return A list of arguments to pass to [rstan::vb()]. +#' @return A list of arguments to pass to [rstan::vb()] or +#' [cmdstanr::variational()], depending on the chosen backend. #' @author Sam Abbott +#' @author Sebastian Funk #' @export #' @examples -#' rstan_vb_opts(samples = 1000) -rstan_vb_opts <- function(samples = 2000, - trials = 10, - iter = 10000, ...) { +#' stan_vb_opts(samples = 1000) +stan_vb_opts <- function(samples = 2000, + trials = 10, + iter = 10000, ...) { opts <- list( trials = trials, iter = iter, @@ -609,10 +684,8 @@ rstan_vb_opts <- function(samples = 2000, #' Rstan Options #' -#' @description `r lifecycle::badge("stable")` -#' Defines a list specifying the arguments passed to underlying `rstan` -#' functions via [rstan_sampling_opts()] and [rstan_vb_opts()].Custom settings -#' can be supplied which override the defaults. +#' @description `r lifecycle::badge("deprecated")` +#' Deprecated; specify options in [stan_opts()] instead. #' #' @param object Stan model object. By default uses the compiled package #' default. @@ -627,14 +700,14 @@ rstan_vb_opts <- function(samples = 2000, #' @export #' @inheritParams rstan_sampling_opts #' @seealso [rstan_sampling_opts()] [rstan_vb_opts()] -#' @examples -#' rstan_opts(samples = 1000) -#' -#' # using vb -#' rstan_opts(method = "vb") rstan_opts <- function(object = NULL, samples = 2000, method = "sampling", ...) { + lifecycle::deprecate_warn( + "2.0.0", "rstan_opts()", + "stan_opts()", + "This function will be removed in version 2.1.0." + ) method <- arg_match(method, values = c("sampling", "vb")) # shared everywhere opts if (is.null(object)) { @@ -645,9 +718,13 @@ rstan_opts <- function(object = NULL, method = method ) if (method == "sampling") { - opts <- c(opts, rstan_sampling_opts(samples = samples, ...)) + opts <- c( + opts, stan_sampling_opts(samples = samples, backend = "rstan", ...) + ) } else if (method == "vb") { - opts <- c(opts, rstan_vb_opts(samples = samples, ...)) + opts <- c( + opts, stan_vb_opts(samples = samples, ...) + ) } return(opts) } @@ -656,11 +733,21 @@ rstan_opts <- function(object = NULL, #' #' @description `r lifecycle::badge("stable")` #' Defines a list specifying the arguments passed to underlying stan -#' backend functions via [rstan_sampling_opts()] and [rstan_vb_opts()]. Custom +#' backend functions via [stan_sampling_opts()] and [stan_vb_opts()]. Custom #' settings can be supplied which override the defaults. #' +#' @param object Stan model object. By default uses the compiled package +#' default if using the "rstan" backend, and the default model obtained using +#' [package_model()] if using the "cmdstanr" backend. If wanting alternative +#' options to the default with the "cmdstanr" backend, pass the result of +#' a call to [package_model()] with desired arguments instead. +#' +#' @param method A character string, defaulting to sampling. Currently supports +#' MCMC sampling ("sampling") or approximate posterior sampling via +#' variational inference ("vb"). +#' #' @param backend Character string indicating the backend to use for fitting -#' stan models. Currently only "rstan" is supported. +#' stan models. Supported arguments are "rstan" (default) or "cmdstanr". #' #' @param init_fit `r lifecycle::badge("experimental")` #' Character string or `stanfit` object, defaults to NULL. Should an initial @@ -678,33 +765,57 @@ rstan_opts <- function(object = NULL, #' @param return_fit Logical, defaults to TRUE. Should the fit stan model be #' returned. #' -#' @param ... Additional parameters to pass underlying option functions. +#' @param ... Additional parameters to pass to underlying option functions, +#' [stan_sampling_opts()] or [stan_vb_opts()], depending on the method #' #' @importFrom rlang arg_match -#' @return A `` object of arguments to pass to the appropriate +#' @return A `` object of arguments to pass to the appropriate #' rstan functions. #' @author Sam Abbott +#' @author Sebastian Funk #' @export #' @inheritParams rstan_opts -#' @seealso [rstan_opts()] +#' @seealso [stan_sampling_opts()] [stan_vb_opts()] #' @examples #' # using default of [rstan::sampling()] #' stan_opts(samples = 1000) #' #' # using vb #' stan_opts(method = "vb") -stan_opts <- function(samples = 2000, +stan_opts <- function(object = NULL, + samples = 2000, + method = "sampling", backend = "rstan", init_fit = NULL, return_fit = TRUE, ...) { - backend <- arg_match(backend, values = "rstan") - if (backend == "rstan") { - opts <- rstan_opts( - samples = samples, - ... + method <- arg_match(method, values = c("sampling", "vb")) + backend <- arg_match(backend, values = c("rstan", "cmdstanr")) + if (backend == "cmdstanr" && !requireNamespace("cmdstanr", quietly = TRUE)) { + stop( + "The `cmdstanr` package needs to be installed for using the ", + "\"cmdstanr\" backend." ) } + opts <- list() + if (!is.null(object) && !missing(backend)) { + warning( + "`backend` option will be ignored as a stan model object has been passed." + ) + } else { + opts <- c(opts, list(backend = backend)) + } + opts <- c(opts, list( + object = object, + method = method + )) + if (method == "sampling") { + opts <- c( + opts, stan_sampling_opts(samples = samples, backend = backend, ...) + ) + } else if (method == "vb") { + opts <- c(opts, stan_vb_opts(samples = samples, ...)) + } if (!is.null(init_fit)) { if (is.character(init_fit)) { init_fit <- arg_match(init_fit, values = "cumulative") diff --git a/R/simulate_infections.R b/R/simulate_infections.R index af917fd66..6e98de85b 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -46,6 +46,7 @@ simulate_infections <- function(...) { #' #' @param verbose Logical defaults to [interactive()]. Should a progress bar #' (from `progressr`) be shown. +#' @inheritParams stan_opts #' @importFrom rstan extract sampling #' @importFrom purrr list_transpose map safely compact #' @importFrom future.apply future_lapply @@ -80,7 +81,7 @@ simulate_infections <- function(...) { #' ) #' #' # update Rt trajectory and simulate new infections using it -#' R <- c(rep(NA_real_, 26), rep(0.5, 10), rep(0.8, 7)) +#' R <- c(rep(NA_real_, 26), rep(0.5, 10), rep(0.8, 14)) #' sims <- forecast_infections(est, R) #' plot(sims) #' @@ -111,6 +112,7 @@ forecast_infections <- function(estimates, model = NULL, samples = NULL, batch_size = 10, + backend = "rstan", verbose = interactive()) { ## check inputs assert_class(estimates, "estimate_infections") @@ -212,9 +214,9 @@ forecast_infections <- function(estimates, ) # Load model - if (is.null(model)) { - model <- stanmodels$simulate_infections - } + stan <- stan_opts( + model = model, backend = backend, chains = 1, samples = 1, warmup = 1 + ) ## set up batch simulation batch_simulate <- function(estimates, draws, model, @@ -231,14 +233,14 @@ forecast_infections <- function(estimates, n = data$n ) - ## simulate - sims <- sampling( - object = model, - data = data, chains = 1, iter = 1, - algorithm = "Fixed_param", - refresh = 0 + args <- create_stan_args( + stan, data = data, fixed_param = TRUE, model = "simulate_infections", + verbose = FALSE ) + ## simulate + sims <- fit_model(args, id = "simulate_infections") + out <- extract_parameter_samples(sims, data, reported_inf_dates = dates, reported_dates = dates[-(1:shift)], @@ -261,12 +263,18 @@ forecast_infections <- function(estimates, safe_batch <- safely(batch_simulate) + if (backend == "cmdstanr") { + lapply_func <- lapply ## future_lapply can't handle cmdstanr + } else { + lapply_func <- function(...) future_lapply(future.seed = TRUE, ...) + } + ## simulate in batches with_progress({ if (verbose) { p <- progressor(along = batches) } - out <- future_lapply(batches, + out <- lapply_func(batches, function(batch) { if (verbose) { p() @@ -276,8 +284,7 @@ forecast_infections <- function(estimates, shift, dates, batch[[1]], batch[[2]] )[[1]] - }, - future.seed = TRUE + } ) }) diff --git a/R/stan.R b/R/stan.R new file mode 100644 index 000000000..86847ac69 --- /dev/null +++ b/R/stan.R @@ -0,0 +1,96 @@ +#' Load and compile the nowcasting model +#' +#' The function has been adapted from a similar function in the epinowcast +#' package (Copyright holder: epinowcast authors, under MIT License). +#' +#' @param model A character string indicating the model to use. One of +#' "estimate_infections" (default), "simulate_infections", "estimate_secondary", +#' "simulate_secondary", "estimate_truncation" or "dist_fit". +#' +#' @param include A character string specifying the path to any stan +#' files to include in the model. If missing the package default is used. +#' +#' @param verbose Logical, defaults to `TRUE`. Should verbose +#' messages be shown. +#' +#' @param ... Additional arguments passed to [cmdstanr::cmdstan_model()]. +#' +#' @importFrom rlang arg_match +#' @return A `cmdstanr` model. +#' @export +package_model <- function(model = "estimate_infections", + include = system.file("stan", package = "EpiNow2"), + verbose = FALSE, + ...) { + model <- arg_match( + model, + c( + "estimate_infections", "simulate_infections", "estimate_secondary", + "simulate_secondary", "estimate_truncation", "dist_fit" + ) + ) + model_file <- system.file( + "stan", paste0(model, ".stan"), + package = "EpiNow2" + ) + if (verbose) { + message(sprintf("Using model %s.", model)) + message(sprintf("include is %s.", toString(include))) + } + + monitor <- suppressMessages + if (verbose) { + monitor <- function(x) { + return(x) + } + } + model <- monitor(cmdstanr::cmdstan_model( + model_file, + include_paths = include, + ... + )) + return(model) +} + +##' Return a stan model object for the appropriate backend +##' +##' @inheritParams stan_opts +##' @inheritParams package_model +##' @return A stan model object (either \code{rstan::stanmodel} or +##' \code{cmdstanr::CmdStanModel}, depending on the backend) +##' @author Sebastian Funk +##' @importFrom rlang arg_match +##' @keywords internal +stan_model <- function(backend = c("rstan", "cmdstanr"), + model = c("estimate_infections", "simulate_infections", + "estimate_secondary", "simulate_secondary", + "estimate_truncation", "dist_fit")) { + backend <- arg_match(backend) + model <- arg_match(model) + if (backend == "cmdstanr") { + object <- package_model(model = model) + } else { + object <- stanmodels[[model]] + } + return(object) +} + +#' Fit a model using the chosen backend. +#' +#' Internal function for dispatch to fitting with NUTS or VB. +#' @inheritParams fit_model_with_nuts +#' @keywords internal +fit_model <- function(args, id = "stan") { + if (args$method == "sampling") { + fit <- fit_model_with_nuts( + args, + future = args$future, + max_execution_time = args$max_execution_time, id = id + ) + } else if (args$method == "vb") { + fit <- fit_model_with_vb(args, id = id) + } else { + stop("args$method must be one of 'sampling' or 'vb'") + } + return(fit) +} diff --git a/R/utilities.R b/R/utilities.R index c253a5704..0f6f79dad 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -447,6 +447,6 @@ globalVariables( "central_lower", "central_upper", "mean_sd", "sd_sd", "average_7_day", "..lowers", "..upper_CrI", "..uppers", "timing", "dataset", "last_confirm", "report_date", "secondary", "id", "conv", "meanlog", "primary", "scaled", - "scaling", "sdlog", "lookup" + "scaling", "sdlog", "lookup", "new_draw", ".draw" ) ) diff --git a/inst/dev/recover-synthetic/eval_rt.R b/inst/dev/recover-synthetic/eval_rt.R index 90af08c50..20ff6215c 100644 --- a/inst/dev/recover-synthetic/eval_rt.R +++ b/inst/dev/recover-synthetic/eval_rt.R @@ -23,13 +23,13 @@ dl <- bind_rows(looic, .id = "model") %>% R_samples <- lapply(synthetic$models, function(x) { if ("R[1]" %in% names(x$fit)) { - rstan::extract(x$fit, "R")$R + extract(x$fit, "R")$R } else { - rstan::extract(x$fit, "gen_R")$gen_R + extract(x$fit, "gen_R")$gen_R } }) inf_samples <- lapply(synthetic$models, function(x) { - rstan::extract(x$fit, "infections")$infections + extract(x$fit, "infections")$infections }) calc_crps <- function(x, truth) { diff --git a/man/create_stan_args.Rd b/man/create_stan_args.Rd index 43688ee30..aff1b64a0 100644 --- a/man/create_stan_args.Rd +++ b/man/create_stan_args.Rd @@ -8,6 +8,8 @@ create_stan_args( stan = stan_opts(), data = NULL, init = "random", + model = "estimate_infections", + fixed_param = FALSE, verbose = FALSE ) } @@ -18,8 +20,15 @@ settings if desired.} \item{data}{A list of stan data as created by \code{\link[=create_stan_data]{create_stan_data()}}} -\item{init}{Initial conditions passed to \code{{rstan}}. Defaults to "random" but -can also be a function (as supplied by \code{\link[=create_initial_conditions]{create_initial_conditions()}}).} +\item{init}{Initial conditions passed to \code{{rstan}}. Defaults to "random" +(initial values randomly drawn between -2 and 2) but can also be a +function (as supplied by \code{\link[=create_initial_conditions]{create_initial_conditions()}}).} + +\item{model}{Character, name of the model for which arguments are +to be created.} + +\item{fixed_param}{Logical, defaults to \code{FALSE}. Should arguments be +created to sample from fixed parameters (used by simulation functions).} \item{verbose}{Logical, defaults to \code{FALSE}. Should verbose progress messages be returned.} diff --git a/man/dist_fit.Rd b/man/dist_fit.Rd index 902f3d383..a61b95e75 100644 --- a/man/dist_fit.Rd +++ b/man/dist_fit.Rd @@ -10,7 +10,8 @@ dist_fit( cores = 1, chains = 2, dist = "exp", - verbose = FALSE + verbose = FALSE, + backend = "rstan" ) } \arguments{ @@ -31,6 +32,9 @@ also supported.} \item{verbose}{Logical, defaults to FALSE. Should verbose progress messages be printed.} + +\item{backend}{Character string indicating the backend to use for fitting +stan models. Supported arguments are "rstan" (default) or "cmdstanr".} } \value{ A stan fit of an interval censored distribution diff --git a/man/estimate_secondary.Rd b/man/estimate_secondary.Rd index 688ff020a..6bdc44b1a 100644 --- a/man/estimate_secondary.Rd +++ b/man/estimate_secondary.Rd @@ -11,6 +11,7 @@ estimate_secondary( 30)), truncation = trunc_opts(), obs = obs_opts(), + stan = stan_opts(), burn_in = 14, CrIs = c(0.2, 0.5, 0.9), priors = NULL, @@ -42,6 +43,10 @@ an approach to estimating truncation from data.} \item{obs}{A list of options as generated by \code{\link[=obs_opts]{obs_opts()}} defining the observation model. Defaults to \code{\link[=obs_opts]{obs_opts()}}.} +\item{stan}{A list of stan options as generated by \code{\link[=stan_opts]{stan_opts()}}. Defaults +to \code{\link[=stan_opts]{stan_opts()}}. Can be used to override \code{data}, \code{init}, and \code{verbose} +settings if desired.} + \item{burn_in}{Integer, defaults to 14 days. The number of data points to use for estimation but not to fit to at the beginning of the time series. This must be less than the number of observations.} @@ -66,7 +71,7 @@ parameters.} \item{verbose}{Logical, should model fitting progress be returned. Defaults to \code{\link[=interactive]{interactive()}}.} -\item{...}{Additional parameters to pass to \code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}}.} +\item{...}{Additional parameters to pass to \code{\link[=stan_opts]{stan_opts()}}.} } \value{ A list containing: \code{predictions} (a \verb{} ordered by date diff --git a/man/estimate_truncation.Rd b/man/estimate_truncation.Rd index 3ab811610..64b908272 100644 --- a/man/estimate_truncation.Rd +++ b/man/estimate_truncation.Rd @@ -11,6 +11,7 @@ estimate_truncation( trunc_dist = "lognormal", truncation = dist_spec(mean = 0, sd = 0, mean_sd = 1, sd_sd = 1, max = 10), model = NULL, + stan = stan_opts(), CrIs = c(0.2, 0.5, 0.9), weigh_delay_priors = FALSE, verbose = TRUE, @@ -36,6 +37,10 @@ an approach to estimating truncation from data.} \item{model}{A compiled stan model to override the default model. May be useful for package developers or those developing extensions.} +\item{stan}{A list of stan options as generated by \code{\link[=stan_opts]{stan_opts()}}. Defaults +to \code{\link[=stan_opts]{stan_opts()}}. Can be used to override \code{data}, \code{init}, and \code{verbose} +settings if desired.} + \item{CrIs}{Numeric vector of credible intervals to calculate.} \item{weigh_delay_priors}{Logical. If TRUE, all delay distribution priors diff --git a/man/extract_parameter_samples.Rd b/man/extract_parameter_samples.Rd index ef35695b6..808da3670 100644 --- a/man/extract_parameter_samples.Rd +++ b/man/extract_parameter_samples.Rd @@ -14,9 +14,10 @@ extract_parameter_samples( ) } \arguments{ -\item{stan_fit}{A fit stan model as returned by \code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}}.} +\item{stan_fit}{A \verb{} or \verb{} object as returned by +\code{\link[=fit_model]{fit_model()}}.} -\item{data}{A list of the data supplied to the \code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}} call.} +\item{data}{A list of the data supplied to the \code{\link[=fit_model]{fit_model()}} call.} \item{reported_dates}{A vector of dates to report estimates for.} diff --git a/man/extract_samples.Rd b/man/extract_samples.Rd new file mode 100644 index 000000000..dc6f4653f --- /dev/null +++ b/man/extract_samples.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract.R +\name{extract_samples} +\alias{extract_samples} +\title{Extract all samples from a stan fit} +\usage{ +extract_samples(stan_fit, pars = NULL, include = TRUE) +} +\arguments{ +\item{stan_fit}{A \verb{} or \verb{} object as returned by +\code{\link[=fit_model]{fit_model()}}.} + +\item{pars}{Any selection of parameters to extract} + +\item{include}{whether the parameters specified in \code{pars} should be included +(\code{TRUE}, the default) or excluded (\code{FALSE})} +} +\value{ +List of data.tables with samples +} +\description{ +If the \code{object} argument is a \if{html}{\out{}} object, it simply returns the result +of \code{\link[rstan:stanfit-method-extract]{rstan::extract()}}. If it is a \verb{} it returns samples +in the same format as \code{\link[rstan:stanfit-method-extract]{rstan::extract()}} does for \verb{} objects. +} diff --git a/man/fit_model.Rd b/man/fit_model.Rd new file mode 100644 index 000000000..0f9cf9e93 --- /dev/null +++ b/man/fit_model.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/stan.R +\name{fit_model} +\alias{fit_model} +\title{Fit a model using the chosen backend.} +\usage{ +fit_model(args, id = "stan") +} +\arguments{ +\item{args}{List of stan arguments.} + +\item{id}{A character string used to assign logging information on error. +Used by \code{\link[=regional_epinow]{regional_epinow()}} to assign errors to regions. Alter the default to +run with error catching.} +} +\description{ +Internal function for dispatch to fitting with NUTS or VB. +} +\keyword{internal} diff --git a/man/forecast_infections.Rd b/man/forecast_infections.Rd index ec90e946c..822896ded 100644 --- a/man/forecast_infections.Rd +++ b/man/forecast_infections.Rd @@ -11,6 +11,7 @@ forecast_infections( model = NULL, samples = NULL, batch_size = 10, + backend = "rstan", verbose = interactive() ) } @@ -34,6 +35,9 @@ default is to use all samples in the \code{estimates} input.} simulate. May decrease run times due to reduced IO costs but this is still being evaluated. If set to NULL then all simulations are done at once.} +\item{backend}{Character string indicating the backend to use for fitting +stan models. Supported arguments are "rstan" (default) or "cmdstanr".} + \item{verbose}{Logical defaults to \code{\link[=interactive]{interactive()}}. Should a progress bar (from \code{progressr}) be shown.} } @@ -68,7 +72,7 @@ est <- estimate_infections(reported_cases, ) # update Rt trajectory and simulate new infections using it -R <- c(rep(NA_real_, 26), rep(0.5, 10), rep(0.8, 7)) +R <- c(rep(NA_real_, 26), rep(0.5, 10), rep(0.8, 14)) sims <- forecast_infections(est, R) plot(sims) diff --git a/man/forecast_secondary.Rd b/man/forecast_secondary.Rd index 256445fb3..affbd86f0 100644 --- a/man/forecast_secondary.Rd +++ b/man/forecast_secondary.Rd @@ -9,6 +9,7 @@ forecast_secondary( primary, primary_variable = "reported_cases", model = NULL, + backend = "rstan", samples = NULL, all_dates = FALSE, CrIs = c(0.2, 0.5, 0.9) @@ -32,6 +33,9 @@ defaulting to "reported_cases". Only used when primary is of class \item{model}{A compiled stan model as returned by \code{\link[rstan:stan_model]{rstan::stan_model()}}.} +\item{backend}{Character string indicating the backend to use for fitting +stan models. Supported arguments are "rstan" (default) or "cmdstanr".} + \item{samples}{Numeric, number of posterior samples to simulate from. The default is to use all samples in the \code{primary} input when present. If not present the default is to use 1000 samples.} diff --git a/man/init_cumulative_fit.Rd b/man/init_cumulative_fit.Rd index 5fea3e2f2..adc6b2927 100644 --- a/man/init_cumulative_fit.Rd +++ b/man/init_cumulative_fit.Rd @@ -9,7 +9,8 @@ init_cumulative_fit( samples = 50, warmup = 50, id = "init", - verbose = FALSE + verbose = FALSE, + backend = "rstan" ) } \arguments{ @@ -25,6 +26,9 @@ run with error catching.} \item{verbose}{Logical, should fitting progress be returned. Defaults to \code{FALSE}.} + +\item{backend}{Character string indicating the backend to use for fitting +stan models. Supported arguments are "rstan" (default) or "cmdstanr".} } \value{ A stanfit object diff --git a/man/package_model.Rd b/man/package_model.Rd new file mode 100644 index 000000000..5a377b0fd --- /dev/null +++ b/man/package_model.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/stan.R +\name{package_model} +\alias{package_model} +\title{Load and compile the nowcasting model} +\usage{ +package_model( + model = "estimate_infections", + include = system.file("stan", package = "EpiNow2"), + verbose = FALSE, + ... +) +} +\arguments{ +\item{model}{A character string indicating the model to use. One of +"estimate_infections" (default), "simulate_infections", "estimate_secondary", +"simulate_secondary", "estimate_truncation" or "dist_fit".} + +\item{include}{A character string specifying the path to any stan +files to include in the model. If missing the package default is used.} + +\item{verbose}{Logical, defaults to \code{TRUE}. Should verbose +messages be shown.} + +\item{...}{Additional arguments passed to \code{\link[cmdstanr:cmdstan_model]{cmdstanr::cmdstan_model()}}.} +} +\value{ +A \code{cmdstanr} model. +} +\description{ +The function has been adapted from a similar function in the epinowcast +package (Copyright holder: epinowcast authors, under MIT License). +} diff --git a/man/rstan_opts.Rd b/man/rstan_opts.Rd index fa9f76cbb..b60a77645 100644 --- a/man/rstan_opts.Rd +++ b/man/rstan_opts.Rd @@ -22,16 +22,8 @@ When using multiple chains iterations per chain is samples / chains.} A list of arguments to pass to the appropriate rstan functions. } \description{ -\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} -Defines a list specifying the arguments passed to underlying \code{rstan} -functions via \code{\link[=rstan_sampling_opts]{rstan_sampling_opts()}} and \code{\link[=rstan_vb_opts]{rstan_vb_opts()}}.Custom settings -can be supplied which override the defaults. -} -\examples{ -rstan_opts(samples = 1000) - -# using vb -rstan_opts(method = "vb") +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} +Deprecated; specify options in \code{\link[=stan_opts]{stan_opts()}} instead. } \seealso{ \code{\link[=rstan_sampling_opts]{rstan_sampling_opts()}} \code{\link[=rstan_vb_opts]{rstan_vb_opts()}} diff --git a/man/rstan_sampling_opts.Rd b/man/rstan_sampling_opts.Rd index 62a00ffcb..7d1c92465 100644 --- a/man/rstan_sampling_opts.Rd +++ b/man/rstan_sampling_opts.Rd @@ -57,13 +57,8 @@ estimation will fail with an informative error.} A list of arguments to pass to \code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}}. } \description{ -\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} -Defines a list specifying the arguments passed to -\code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}}. Custom settings can be supplied which override the -defaults. -} -\examples{ -rstan_sampling_opts(samples = 2000) +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} +Deprecated; use \code{\link[=stan_sampling_opts]{stan_sampling_opts()}} instead. } \author{ Sam Abbott diff --git a/man/rstan_vb_opts.Rd b/man/rstan_vb_opts.Rd index bafe58b24..7f353928a 100644 --- a/man/rstan_vb_opts.Rd +++ b/man/rstan_vb_opts.Rd @@ -16,18 +16,15 @@ rstan::vb()] before failing.} \item{iter}{Numeric, defaulting to 10000. Number of iterations to use in \code{\link[rstan:stanmodel-method-vb]{rstan::vb()}}.} -\item{...}{Additional parameters to pass to \code{\link[rstan:stanmodel-method-vb]{rstan::vb()}}.} +\item{...}{Additional parameters to pass to \code{\link[rstan:stanmodel-method-vb]{rstan::vb()}} or +\code{\link[cmdstanr:model-method-variational]{cmdstanr::variational()}}, depending on the chosen backend.} } \value{ A list of arguments to pass to \code{\link[rstan:stanmodel-method-vb]{rstan::vb()}}. } \description{ -\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} -Defines a list specifying the arguments passed to -\code{\link[rstan:stanmodel-method-vb]{rstan::vb()}}. Custom settings can be supplied which override the defaults. -} -\examples{ -rstan_vb_opts(samples = 1000) +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} +Deprecated; use \code{\link[=stan_vb_opts]{stan_vb_opts()}} instead. } \author{ Sam Abbott diff --git a/man/stan_model.Rd b/man/stan_model.Rd new file mode 100644 index 000000000..edd371a96 --- /dev/null +++ b/man/stan_model.Rd @@ -0,0 +1,31 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/stan.R +\name{stan_model} +\alias{stan_model} +\title{Return a stan model object for the appropriate backend} +\usage{ +stan_model( + backend = c("rstan", "cmdstanr"), + model = c("estimate_infections", "simulate_infections", "estimate_secondary", + "simulate_secondary", "estimate_truncation", "dist_fit") +) +} +\arguments{ +\item{backend}{Character string indicating the backend to use for fitting +stan models. Supported arguments are "rstan" (default) or "cmdstanr".} + +\item{model}{A character string indicating the model to use. One of +"estimate_infections" (default), "simulate_infections", "estimate_secondary", +"simulate_secondary", "estimate_truncation" or "dist_fit".} +} +\value{ +A stan model object (either \code{rstan::stanmodel} or +\code{cmdstanr::CmdStanModel}, depending on the backend) +} +\description{ +Return a stan model object for the appropriate backend +} +\author{ +Sebastian Funk +} +\keyword{internal} diff --git a/man/stan_opts.Rd b/man/stan_opts.Rd index be49eb8d2..ea0ed150c 100644 --- a/man/stan_opts.Rd +++ b/man/stan_opts.Rd @@ -5,7 +5,9 @@ \title{Stan Options} \usage{ stan_opts( + object = NULL, samples = 2000, + method = "sampling", backend = "rstan", init_fit = NULL, return_fit = TRUE, @@ -13,11 +15,21 @@ stan_opts( ) } \arguments{ +\item{object}{Stan model object. By default uses the compiled package +default if using the "rstan" backend, and the default model obtained using +\code{\link[=package_model]{package_model()}} if using the "cmdstanr" backend. If wanting alternative +options to the default with the "cmdstanr" backend, pass the result of +a call to \code{\link[=package_model]{package_model()}} with desired arguments instead.} + \item{samples}{Numeric, default 2000. Overall number of posterior samples. When using multiple chains iterations per chain is samples / chains.} +\item{method}{A character string, defaulting to sampling. Currently supports +MCMC sampling ("sampling") or approximate posterior sampling via +variational inference ("vb").} + \item{backend}{Character string indicating the backend to use for fitting -stan models. Currently only "rstan" is supported.} +stan models. Supported arguments are "rstan" (default) or "cmdstanr".} \item{init_fit}{\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} Character string or \code{stanfit} object, defaults to NULL. Should an initial @@ -35,16 +47,17 @@ James Scott.} \item{return_fit}{Logical, defaults to TRUE. Should the fit stan model be returned.} -\item{...}{Additional parameters to pass underlying option functions.} +\item{...}{Additional parameters to pass to underlying option functions, +\code{\link[=stan_sampling_opts]{stan_sampling_opts()}} or \code{\link[=stan_vb_opts]{stan_vb_opts()}}, depending on the method} } \value{ -A \verb{} object of arguments to pass to the appropriate +A \verb{} object of arguments to pass to the appropriate rstan functions. } \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} Defines a list specifying the arguments passed to underlying stan -backend functions via \code{\link[=rstan_sampling_opts]{rstan_sampling_opts()}} and \code{\link[=rstan_vb_opts]{rstan_vb_opts()}}. Custom +backend functions via \code{\link[=stan_sampling_opts]{stan_sampling_opts()}} and \code{\link[=stan_vb_opts]{stan_vb_opts()}}. Custom settings can be supplied which override the defaults. } \examples{ @@ -55,8 +68,10 @@ stan_opts(samples = 1000) stan_opts(method = "vb") } \seealso{ -\code{\link[=rstan_opts]{rstan_opts()}} +\code{\link[=stan_sampling_opts]{stan_sampling_opts()}} \code{\link[=stan_vb_opts]{stan_vb_opts()}} } \author{ Sam Abbott + +Sebastian Funk } diff --git a/man/stan_sampling_opts.Rd b/man/stan_sampling_opts.Rd new file mode 100644 index 000000000..d67c6e0b3 --- /dev/null +++ b/man/stan_sampling_opts.Rd @@ -0,0 +1,76 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/opts.R +\name{stan_sampling_opts} +\alias{stan_sampling_opts} +\title{Stan Sampling Options} +\usage{ +stan_sampling_opts( + cores = getOption("mc.cores", 1L), + warmup = 250, + samples = 2000, + chains = 4, + control = list(), + save_warmup = FALSE, + seed = as.integer(runif(1, 1, 1e+08)), + future = FALSE, + max_execution_time = Inf, + backend = "rstan", + ... +) +} +\arguments{ +\item{cores}{Number of cores to use when executing the chains in parallel, +which defaults to 1 but it is recommended to set the mc.cores option to be +as many processors as the hardware and RAM allow (up to the number of +chains).} + +\item{warmup}{Numeric, defaults to 250. Number of warmup samples per chain.} + +\item{samples}{Numeric, default 2000. Overall number of posterior samples. +When using multiple chains iterations per chain is samples / chains.} + +\item{chains}{Numeric, defaults to 4. Number of MCMC chains to use.} + +\item{control}{List, defaults to empty. control parameters to pass to +underlying \code{rstan} function. By default \code{adapt_delta = 0.95} and +\code{max_treedepth = 15} though these settings can be overwritten.} + +\item{save_warmup}{Logical, defaults to FALSE. Should warmup progress be +saved.} + +\item{seed}{Numeric, defaults uniform random number between 1 and 1e8. Seed +of sampling process.} + +\item{future}{Logical, defaults to \code{FALSE}. Should stan chains be run in +parallel using \code{future}. This allows users to have chains fail gracefully +(i.e when combined with \code{max_execution_time}). Should be combined with a +call to \code{\link[future:plan]{future::plan()}}.} + +\item{max_execution_time}{Numeric, defaults to Inf (seconds). If set wil +kill off processing of each chain if not finished within the specified +timeout. When more than 2 chains finish successfully estimates will still be +returned. If less than 2 chains return within the allowed time then +estimation will fail with an informative error.} + +\item{backend}{Character string indicating the backend to use for fitting +stan models. Supported arguments are "rstan" (default) or "cmdstanr".} + +\item{...}{Additional parameters to pass to \code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}}.} +} +\value{ +A list of arguments to pass to \code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}}. +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +Defines a list specifying the arguments passed to either \code{\link[rstan:stanmodel-method-sampling]{rstan::sampling()}} +or \code{\link[cmdstanr:model-method-sample]{cmdstanr::sample()}}. Custom settings can be supplied which override the +defaults. +} +\examples{ +stan_sampling_opts(samples = 2000) +} +\author{ +Sam Abbott + +Sebastian Funk +} diff --git a/man/stan_vb_opts.Rd b/man/stan_vb_opts.Rd new file mode 100644 index 000000000..cb4936929 --- /dev/null +++ b/man/stan_vb_opts.Rd @@ -0,0 +1,39 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/opts.R +\name{stan_vb_opts} +\alias{stan_vb_opts} +\title{Stan Variational Bayes Options} +\usage{ +stan_vb_opts(samples = 2000, trials = 10, iter = 10000, ...) +} +\arguments{ +\item{samples}{Numeric, default 2000. Overall number of approximate posterior +samples.} + +\item{trials}{Numeric, defaults to 10. Number of attempts to use +rstan::vb()] before failing.} + +\item{iter}{Numeric, defaulting to 10000. Number of iterations to use in +\code{\link[rstan:stanmodel-method-vb]{rstan::vb()}}.} + +\item{...}{Additional parameters to pass to \code{\link[rstan:stanmodel-method-vb]{rstan::vb()}} or +\code{\link[cmdstanr:model-method-variational]{cmdstanr::variational()}}, depending on the chosen backend.} +} +\value{ +A list of arguments to pass to \code{\link[rstan:stanmodel-method-vb]{rstan::vb()}} or +\code{\link[cmdstanr:model-method-variational]{cmdstanr::variational()}}, depending on the chosen backend. +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +Defines a list specifying the arguments passed to \code{\link[rstan:stanmodel-method-vb]{rstan::vb()}} or +\code{\link[cmdstanr:model-method-variational]{cmdstanr::variational()}}. Custom settings can be supplied which override the +defaults. +} +\examples{ +stan_vb_opts(samples = 1000) +} +\author{ +Sam Abbott + +Sebastian Funk +} diff --git a/tests/testthat/test-create_stan_args.R b/tests/testthat/test-create_stan_args.R index 2a54af1c6..c76aad09b 100644 --- a/tests/testthat/test-create_stan_args.R +++ b/tests/testthat/test-create_stan_args.R @@ -1,8 +1,7 @@ test_that("create_stan_args returns the expected defaults when the exact method is used", { expect_equal(names(create_stan_args()), c( - "data", "init", "refresh", "object", "method", - "cores", "warmup", "chains", "save_warmup", "seed", - "future", "max_execution_time", "control", "iter" + "data", "init", "refresh", "object", "method", "chains", "save_warmup", + "seed", "future", "max_execution_time", "cores", "warmup", "control", "iter" )) }) diff --git a/tests/testthat/test-epinow.R b/tests/testthat/test-epinow.R index 8dc6bd340..68810f5ce 100644 --- a/tests/testthat/test-epinow.R +++ b/tests/testthat/test-epinow.R @@ -38,6 +38,35 @@ test_that("epinow produces expected output when run with default settings", { expect_equal(names(out$plots), c("infections", "reports", "R", "growth_rate", "summary")) }) +test_that("epinow produces expected output when run with the + cmdstanr backend", { + skip_on_os("windows") + output <- capture.output(suppressMessages(suppressWarnings( + out <- epinow( + reported_cases = reported_cases, + generation_time = generation_time_opts(example_generation_time), + delays = delay_opts(example_incubation_period + reporting_delay), + stan = stan_opts( + samples = 25, warmup = 25, + cores = 1, chains = 2, + control = list(adapt_delta = 0.8), + backend = "cmdstanr" + ), + logs = NULL, verbose = FALSE + ) + ))) + + expect_equal(names(out), expected_out) + df_non_zero(out$estimates$samples) + df_non_zero(out$estimates$summarised) + df_non_zero(out$estimated_reported_cases$samples) + df_non_zero(out$estimated_reported_cases$summarised) + df_non_zero(out$summary) + expect_equal( + names(out$plots), c("infections", "reports", "R", "growth_rate", "summary") + ) +}) + test_that("epinow runs without error when saving to disk", { expect_null(suppressWarnings(epinow( reported_cases = reported_cases, diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index a887c2381..5485f309b 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -5,14 +5,16 @@ skip_on_cran() # make some example secondary incidence data cases <- example_confirmed cases <- as.data.table(cases)[, primary := confirm] + +inc_cases <- copy(cases) # Assume that only 40 percent of cases are reported -cases[, scaling := 0.4] +inc_cases[, scaling := 0.4] # Parameters of the assumed log normal delay distribution -cases[, meanlog := 1.8][, sdlog := 0.5] +inc_cases[, meanlog := 1.8][, sdlog := 0.5] # Simulate secondary cases -cases <- simulate_secondary(cases, type = "incidence") -cases[ +inc_cases <- simulate_secondary(inc_cases, type = "incidence") +inc_cases[ , c("confirm", "scaling", "meanlog", "sdlog", "index", "scaled", "conv") := NULL @@ -20,7 +22,7 @@ cases[ # # fit model to example data specifying a weak prior for fraction reported # with a secondary case -inc <- estimate_secondary(cases[1:60], +inc <- estimate_secondary(inc_cases[1:60], obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), verbose = FALSE ) @@ -36,18 +38,17 @@ inc_posterior <- inc$posterior[variable %in% params] #### Prevalence data example #### # make some example prevalence data -cases <- example_confirmed -cases <- as.data.table(cases)[, primary := confirm] +prev_cases <- copy(cases) # Assume that only 30 percent of cases are reported -cases[, scaling := 0.3] +prev_cases[, scaling := 0.3] # Parameters of the assumed log normal delay distribution -cases[, meanlog := 1.6][, sdlog := 0.8] +prev_cases[, meanlog := 1.6][, sdlog := 0.8] # Simulate secondary cases -cases <- simulate_secondary(cases, type = "prevalence") +prev_cases <- simulate_secondary(prev_cases, type = "prevalence") # fit model to example prevalence data -prev <- estimate_secondary(cases[1:100], +prev <- estimate_secondary(prev_cases[1:100], secondary = secondary_opts(type = "prevalence"), obs = obs_opts( week_effect = FALSE, @@ -92,12 +93,45 @@ test_that("estimate_secondary can recover simulated parameters", { ) }) +test_that("estimate_secondary can recover simulated parameters with the + cmdstanr backend", { + skip_on_os("windows") + output <- capture.output(suppressMessages(suppressWarnings( + inc_cmdstanr <- estimate_secondary(inc_cases[1:60], + obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + verbose = FALSE, stan = stan_opts(backend = "cmdstanr") + ) + ))) + inc_posterior_cmdstanr <- inc_cmdstanr$posterior[variable %in% params] + expect_equal( + inc_posterior_cmdstanr[, mean], c(1.8, 0.5, 0.4), + tolerance = 0.1 + ) + expect_equal( + inc_posterior_cmdstanr[, median], c(1.8, 0.5, 0.4), + tolerance = 0.1 + ) +}) + test_that("forecast_secondary can return values from simulated data and plot them", { - inc_preds <- forecast_secondary(inc, cases[seq(61, .N)][, value := primary]) + inc_preds <- forecast_secondary( + inc, inc_cases[seq(61, .N)][, value := primary] + ) expect_equal(names(inc_preds), c("samples", "forecast", "predictions")) # validation plot of observations vs estimates - expect_error(plot(inc_preds, new_obs = cases, from = "2020-05-01"), NA) + expect_error(plot(inc_preds, new_obs = inc_cases, from = "2020-05-01"), NA) +}) + +test_that("forecast_secondary can return values from simulated data when using + the cmdstanr backend", { + skip_on_os("windows") + capture.output(suppressMessages(suppressWarnings( + inc_preds <- forecast_secondary( + inc, inc_cases[seq(61, .N)][, value := primary], backend = "cmdstanr" + ) + ))) + expect_equal(names(inc_preds), c("samples", "forecast", "predictions")) }) test_that("estimate_secondary works with weigh_delay_priors = TRUE", { @@ -105,7 +139,7 @@ test_that("estimate_secondary works with weigh_delay_priors = TRUE", { mean = 2.5, mean_sd = 0.5, sd = 0.47, sd_sd = 0.25, max = 30 ) inc_weigh <- estimate_secondary( - cases[1:60], delays = delay_opts(delays), + inc_cases[1:60], delays = delay_opts(delays), obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), weigh_delay_priors = TRUE, verbose = FALSE ) diff --git a/tests/testthat/test-estimate_truncation.R b/tests/testthat/test-estimate_truncation.R index dad8f4b12..c1fd03e8e 100644 --- a/tests/testthat/test-estimate_truncation.R +++ b/tests/testthat/test-estimate_truncation.R @@ -54,6 +54,23 @@ test_that("estimate_truncation can return values from simulated data and plot expect_error(plot(est), NA) }) +test_that("estimate_truncation can return values from simulated data with the + cmdstanr backend", { + # fit model to example data + skip_on_os("windows") + output <- capture.output(suppressMessages(suppressWarnings( + est <- estimate_truncation(example_data, + verbose = FALSE, chains = 2, iter = 1000, warmup = 250, + stan = stan_opts(backend = "cmdstanr") + )))) + expect_equal( + names(est), + c("dist", "obs", "last_obs", "cmf", "data", "fit") + ) + expect_s3_class(est$dist, "dist_spec") + expect_error(plot(est), NA) +}) + test_that("deprecated arguments are recognised", { options(warn = 2) expect_error(estimate_truncation(example_data, diff --git a/tests/testthat/test-simulate_infections.R b/tests/testthat/test-simulate_infections.R index 4d626fa8a..3221e2721 100644 --- a/tests/testthat/test-simulate_infections.R +++ b/tests/testthat/test-simulate_infections.R @@ -20,6 +20,15 @@ test_that("forecast_infections works to simulate a passed in estimate_infections expect_equal(names(sims), c("samples", "summarised", "observations")) }) +test_that("forecast_infections works to simulate a passed in estimate_infections + object when using the cmdstanr backend", { + skip_on_os("windows") + output <- capture.output(suppressMessages(suppressWarnings( + sims <- forecast_infections(out, backend = "cmdstanr") + ))) + expect_equal(names(sims), c("samples", "summarised", "observations")) +}) + test_that("forecast_infections works to simulate a passed in estimate_infections object with an adjusted Rt", { R <- c(rep(NA_real_, 40), rep(0.5, 17)) sims <- forecast_infections(out, R)