From 9d733d2d91018661e10c80e54a543b4b07784d80 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Fri, 6 Dec 2024 14:12:47 +0000 Subject: [PATCH] Change all priors to use `` (#871) * new parameter interface in stan code * adapt R code to new param interface * render docs * update tests * update examples and other code snippets * add news item * add progressr to lintr workflow * switch benchmarks back to previous syntax otherwise they won't work on main * change benchmark back it won't work anyway as the R code has changed too much * bumping all vignettes * add reviewer Co-authored-by: Sam Abbott * cherry pick vignettes from main * make all priors consistent with previous versions --------- Co-authored-by: Sam --- .../workflows/lint-only-changed-files.yaml | 1 + NAMESPACE | 2 + NEWS.md | 5 +- R/create.R | 166 +++++++++++++----- R/dist_spec.R | 51 ++++++ R/epinow.R | 2 +- R/estimate_infections.R | 2 +- R/estimate_secondary.R | 15 +- R/extract.R | 13 +- R/opts.R | 140 ++++++++++----- R/regional_epinow.R | 2 +- R/simulate_infections.R | 31 ++-- R/simulate_secondary.R | 24 ++- data-raw/estimate-infections.R | 4 +- inst/dev/benchmark-functions.R | 2 +- inst/dev/recover-synthetic/rt.R | 12 +- .../stan/data/estimate_infections_params.stan | 4 + inst/stan/data/estimate_secondary_params.stan | 2 + inst/stan/data/gaussian_process.stan | 2 - inst/stan/data/observation_model.stan | 4 - inst/stan/data/params.stan | 13 ++ inst/stan/data/rt.stan | 2 - .../data/simulation_observation_model.stan | 2 - inst/stan/estimate_infections.stan | 61 ++++--- inst/stan/estimate_secondary.stan | 46 +++-- inst/stan/functions/gaussian_process.stan | 7 +- inst/stan/functions/observation_model.stan | 22 +-- inst/stan/functions/params.stan | 53 ++++++ inst/stan/functions/rt.stan | 21 +-- inst/stan/simulate_infections.stan | 121 +++++++------ inst/stan/simulate_secondary.stan | 102 ++++++----- man/create_obs_model.Rd | 2 +- man/create_stan_params.Rd | 22 +++ man/epinow.Rd | 2 +- 
man/equals-.dist_spec.Rd | 27 +++ man/estimate_infections.Rd | 2 +- man/estimate_secondary.Rd | 4 +- man/forecast_infections.Rd | 4 +- man/gp_opts.Rd | 26 +-- man/obs_opts.Rd | 30 ++-- man/regional_epinow.Rd | 2 +- man/rt_opts.Rd | 11 +- tests/testthat/test-create_gp_data.R | 2 +- tests/testthat/test-create_obs_model.R | 40 +---- tests/testthat/test-create_rt_date.R | 5 - tests/testthat/test-create_stan_params.R | 53 ++++++ tests/testthat/test-estimate_secondary.R | 24 +-- tests/testthat/test-gp_opts.R | 2 +- tests/testthat/test-obs_opts.R | 2 +- tests/testthat/test-rt_opts.R | 23 ++- tests/testthat/test-simulate-infections.R | 4 +- tests/testthat/test-simulate-secondary.R | 4 +- tests/testthat/test-stan-rt.R | 22 +-- vignettes/EpiNow2.Rmd.orig | 4 +- vignettes/epinow.Rmd.orig | 2 +- .../estimate_infections_options.Rmd.orig | 2 +- .../estimate_infections_workflow.Rmd.orig | 4 +- 57 files changed, 809 insertions(+), 450 deletions(-) create mode 100644 inst/stan/data/estimate_infections_params.stan create mode 100644 inst/stan/data/estimate_secondary_params.stan create mode 100644 inst/stan/data/params.stan create mode 100644 inst/stan/functions/params.stan create mode 100644 man/create_stan_params.Rd create mode 100644 man/equals-.dist_spec.Rd create mode 100644 tests/testthat/test-create_stan_params.R diff --git a/.github/workflows/lint-only-changed-files.yaml b/.github/workflows/lint-only-changed-files.yaml index 3d67b098d..a77c2d01d 100644 --- a/.github/workflows/lint-only-changed-files.yaml +++ b/.github/workflows/lint-only-changed-files.yaml @@ -29,6 +29,7 @@ jobs: any::gh any::lintr any::purrr + progressr - name: Add lintr options run: | diff --git a/NAMESPACE b/NAMESPACE index 0b7f2ef01..3ca6f47d8 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,8 @@ # Generated by roxygen2: do not edit by hand +S3method("!=",dist_spec) S3method("+",dist_spec) +S3method("==",dist_spec) S3method(c,dist_spec) S3method(collapse,dist_spec) S3method(collapse,multi_dist_spec) diff 
--git a/NEWS.md b/NEWS.md index 009ba40fc..28df3b2a2 100644 --- a/NEWS.md +++ b/NEWS.md @@ -15,8 +15,9 @@ estimate_infections() ``` - - A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs. - - A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk. +- A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs. +- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk. +- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and reviewed by @seabbs. ## Documentation diff --git a/R/create.R b/R/create.R index 2ffcfaf62..ed4524003 100644 --- a/R/create.R +++ b/R/create.R @@ -319,8 +319,6 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, # map settings to underlying gp stan requirements rt_data <- list( - r_mean = rt$prior$mean, - r_sd = rt$prior$sd, estimate_r = as.numeric(rt$use_rt), bp_n = ifelse(rt$use_breakpoints, max(breakpoints) - 1, 0), breakpoints = breakpoints, @@ -433,8 +431,6 @@ create_gp_data <- function(gp = gp_opts(), data) { ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd), ls_min = gp$ls_min, ls_max = gp$ls_max, - alpha_mean = gp$alpha_mean, - alpha_sd = gp$alpha_sd, gp_type = data.table::fcase( gp$kernel == "se", 0, gp$kernel == "periodic", 1, @@ -472,7 +468,7 @@ create_gp_data <- function(gp = gp_opts(), data) { #' #' # Applying a observation scaling to the data #' create_obs_model( -#' obs_opts(scale = list(mean = 0.4, sd = 0.01)), dates = dates +#' obs_opts(scale = Normal(mean = 0.4, sd = 0.01)), dates = dates #' ) #' #' # Apply a custom week week length @@ -481,13 +477,9 @@ create_gp_data <- 
function(gp = gp_opts(), data) { create_obs_model <- function(obs = obs_opts(), dates) { data <- list( model_type = as.numeric(obs$family == "negbin"), - phi_mean = obs$phi$mean, - phi_sd = obs$phi$sd, week_effect = ifelse(obs$week_effect, obs$week_length, 1), obs_weight = obs$weight, - obs_scale = as.integer(obs$scale$sd > 0 || obs$scale$mean != 1), - obs_scale_mean = obs$scale$mean, - obs_scale_sd = obs$scale$sd, + obs_scale = as.integer(obs$scale != Fixed(1)), likelihood = as.numeric(obs$likelihood), return_likelihood = as.numeric(obs$return_likelihood) ) @@ -589,15 +581,30 @@ create_stan_data <- function(data, seeding_time, ) ) + # parameters + stan_data <- c( + stan_data, + create_stan_params( + alpha = gp$alpha, + R0 = rt$prior, + frac_obs = obs$scale, + rep_phi = obs$phi, + lower_bounds = c( + alpha = 0, + R0 = 0, + frac_obs = 0, + rep_phi = 0 + ) + ) + ) + # rescale mean shifted prior for back calculation if observation scaling is # used - if (stan_data$obs_scale == 1) { - stan_data$shifted_cases <- - stan_data$shifted_cases / stan_data$obs_scale_mean - stan_data$prior_infections <- log( - exp(stan_data$prior_infections) / stan_data$obs_scale_mean - ) - } + stan_data$shifted_cases <- + stan_data$shifted_cases / mean(obs$scale) + stan_data$prior_infections <- log( + exp(stan_data$prior_infections) / mean(obs$scale) + ) return(stan_data) } @@ -647,34 +654,15 @@ create_initial_conditions <- function(data) { out$rescaled_rho < data$ls_min, data$ls_min + 0.001, default = out$rescaled_rho )) - - out$alpha <- array( - truncnorm::rtruncnorm( - 1, a = 0, mean = data$alpha_mean, sd = data$alpha_sd - ) - ) } else { out$eta <- array(numeric(0)) out$rescaled_rho <- array(numeric(0)) - out$alpha <- array(numeric(0)) - } - if (data$model_type == 1) { - out$rep_phi <- array( - truncnorm::rtruncnorm( - 1, - a = 0, mean = data$phi_mean, sd = data$phi_sd - ) - ) } if (data$estimate_r == 1) { out$initial_infections <- array(rnorm(1, data$prior_infections, 0.2)) if 
(data$seeding_time > 1) { out$initial_growth <- array(rnorm(1, data$prior_growth, 0.02)) } - out$log_R <- array(rnorm( - n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd), - sd = convert_to_logsd(data$r_mean, data$r_sd) - )) } if (data$bp_n > 0) { @@ -684,20 +672,17 @@ create_initial_conditions <- function(data) { out$bp_sd <- array(numeric(0)) out$bp_effects <- array(numeric(0)) } - if (data$obs_scale_sd > 0) { - out$frac_obs <- array(truncnorm::rtruncnorm(1, - a = 0, b = 1, - mean = data$obs_scale_mean, - sd = data$obs_scale_sd - )) - } else { - out$frac_obs <- array(numeric(0)) - } if (data$week_effect > 0) { out$day_of_week_simplex <- array( rep(1 / data$week_effect, data$week_effect) ) } + out$params <- array(truncnorm::rtruncnorm( + data$n_params_variable, + a = data$params_lower, + b = data$params_upper, + mean = 0, sd = 1 + )) return(out) } return(init_fun) @@ -877,3 +862,94 @@ create_stan_delays <- function(..., time_points = 1L) { return(ret) } + +##' Create parameters for stan +##' +##' @param ... Named parameter distributions. The names are assigned to IDs. +##' @param lower_bounds Named vector of parameter lower bounds. The names +##' have to correspond to the names given to the parameters passed. +##' If `NULL` (default) no parameters are given a lower bound. +##' @return A list of variables as expected by the stan model +##' @importFrom data.table fcase +##' @keywords internal +create_stan_params <- function(..., lower_bounds = NULL) { + params <- list(...)
+ + ## set IDs of any parameters that is NULL to 0 and remove + null_params <- vapply(params, is.null, logical(1)) + null_ids <- rep(0, sum(null_params)) + if (length(null_ids) > 0) { + names(null_ids) <- paste(names(null_params)[null_params], "id", sep = "_") + params <- params[!null_params] + } + + ## initialise variables + params_fixed_lookup <- rep(0L, length(params)) + params_variable_lookup <- rep(0L, length(params)) + + ## identify fixed/variable parameters + fixed <- vapply(params, get_distribution, character(1)) == "fixed" + params_fixed_lookup[fixed] <- seq_along(which(fixed)) + params_variable_lookup[!fixed] <- seq_along(which(!fixed)) + + ## lower bounds + params_lower <- rep(-Inf, length(params[!fixed])) + names(params_lower) <- names(params[!fixed]) + lower_bounds <- lower_bounds[names(params_lower)] + params_lower[names(lower_bounds)] <- lower_bounds + + ## upper bounds + params_upper <- vapply(params[!fixed], max, numeric(1)) + + ## prior distributions + prior_dist_name <- vapply(params[!fixed], get_distribution, character(1)) + prior_dist <- fcase( + prior_dist_name == "lognormal", 0L, + prior_dist_name == "gamma", 1L, + prior_dist_name == "normal", 2L + ) + ## parameters + prior_dist_params <- lapply(params[!fixed], get_parameters) + prior_dist_params_lengths <- lengths(prior_dist_params) + + ## check none of the parameters are uncertain + prior_uncertain <- vapply(prior_dist_params, function(x) { + !all(vapply(x, is.numeric, logical(1))) + }, logical(1)) + if (any(prior_uncertain)) { + uncertain_priors <- names(params[!fixed])[prior_uncertain] # nolint: object_usage_linter + cli_abort( + c( + "!" = "Parameter prior distribution{?s} for {.var {uncertain_priors}} + cannot have uncertain parameters." 
+ ) + ) + } + + prior_dist_params <- unlist(prior_dist_params) + if (is.null(prior_dist_params)) { + prior_dist_params <- numeric(0) + } + + ## extract distributions and parameters + ret <- list( + n_params_variable = length(params) - sum(fixed), + n_params_fixed = sum(fixed), + params_lower = array(params_lower), + params_upper = array(params_upper), + params_fixed_lookup = array(params_fixed_lookup), + params_variable_lookup = array(params_variable_lookup), + params_value = array(vapply( + params[fixed], \(x) get_parameters(x)$value, numeric(1) + )), + prior_dist = array(prior_dist), + prior_dist_params_length = sum(prior_dist_params_lengths), + prior_dist_params = array(prior_dist_params) + ) + ids <- seq_along(params) + if (length(ids) > 0) { + names(ids) <- paste(names(params), "id", sep = "_") + } + ret <- c(ret, as.list(ids), as.list(null_ids)) + return(ret) +} diff --git a/R/dist_spec.R b/R/dist_spec.R index 1186a6438..37d368362 100644 --- a/R/dist_spec.R +++ b/R/dist_spec.R @@ -125,6 +125,57 @@ discrete_pmf <- function(distribution = c(e1, e2) } +##' Compares two delay distributions +##' +##' @param e1 The first delay distribution (of type `dist_spec`) to +##' compare. +##' +##' @param e2 The second delay distribution (of type `dist_spec`) to +##' compare.
+##' @method == dist_spec +##' @return TRUE or FALSE +##' @export +##' @examples +##' Fixed(1) == Normal(1, 0.5) +## nolint start: cyclocomp_linter +`==.dist_spec` <- function(e1, e2) { + ## both must have same number of distributions + if (ndist(e1) != ndist(e2)) return(FALSE) + ## loop over constituent distributions + for (i in seq_len(ndist(e1))) { + ## distributions need to be the same + if (get_distribution(e1, i) != get_distribution(e2, i)) return(FALSE) + if (get_distribution(e1, i) == "nonparametric") { + ## if nonparametric then PMFs need to be the same + if (!identical(get_pmf(e1, i), get_pmf(e2, i))) return(FALSE) + } else { + ## if parametric then all parameters need to be the same + params1 <- get_parameters(e1, i) + params2 <- get_parameters(e2, i) + for (param in names(params1)) { + ## all parameters must be the same type + if ((is(params1[[param]], "dist_spec") && + is(params2[[param]], "dist_spec")) || + (is.numeric(params1[[param]]) && is.numeric(params2[[param]]))) { + ## if parameters are the same type they need to be same value + if (!(params1[[param]] == params2[[param]])) return(FALSE) + } else { + return(FALSE) + } + } + } + } + return(TRUE) +} +## nolint end: cyclocomp_linter + +##' @rdname equals-.dist_spec +##' @method != dist_spec +##' @export +`!=.dist_spec` <- function(e1, e2) { + !(e1 == e2) +} + #' Combines multiple delay distributions for further processing #' #' @description `r lifecycle::badge("experimental")` diff --git a/R/epinow.R b/R/epinow.R index 5baa06093..6ca1fd078 100644 --- a/R/epinow.R +++ b/R/epinow.R @@ -65,7 +65,7 @@ #' out <- epinow( #' data = reported_cases, #' generation_time = gt_opts(generation_time), -#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)), +#' rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1)), #' delays = delay_opts(incubation_period + reporting_delay) #' ) #' # summary of the latest estimates diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 8c1f457c9..1b4fff209 100644 --- 
a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -107,7 +107,7 @@ #' def <- estimate_infections(reported_cases, #' generation_time = gt_opts(generation_time), #' delays = delay_opts(incubation_period + reporting_delay), -#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)) +#' rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1)) #' ) #' # real time estimates #' summary(def) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 9c24eafe2..13054e956 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -101,7 +101,7 @@ #' # fit model to example data specifying a weak prior for fraction reported #' # with a secondary case #' inc <- estimate_secondary(cases[1:60], -#' obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE) +#' obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE) #' ) #' plot(inc, primary = TRUE) #' @@ -129,7 +129,7 @@ #' secondary = secondary_opts(type = "prevalence"), #' obs = obs_opts( #' week_effect = FALSE, -#' scale = list(mean = 0.4, sd = 0.1) +#' scale = Normal(mean = 0.4, sd = 0.1) #' ) #' ) #' plot(prev, primary = TRUE) @@ -250,6 +250,15 @@ estimate_secondary <- function(data, # observation model data stan_data <- c(stan_data, create_obs_model(obs, dates = reports$date)) + stan_data <- c(stan_data, create_stan_params( + frac_obs = obs$scale, + rep_phi = obs$phi, + lower_bounds = c( + frac_obs = 0, + rep_phi = 0 + ) + )) + # update data to use specified priors rather than defaults stan_data <- update_secondary_args(stan_data, priors = priors, verbose = verbose @@ -674,7 +683,7 @@ forecast_secondary <- function(estimate, # allocate empty parameters data <- allocate_empty( - data, c("frac_obs", "delay_params", "rep_phi"), + data, c("params", "delay_params"), n = data$n ) data$all_dates <- as.integer(all_dates) diff --git a/R/extract.R b/R/extract.R index 3c6d04489..cf8f74a9c 100644 --- a/R/extract.R +++ b/R/extract.R @@ -46,10 +46,12 @@ extract_parameter <- function(param, 
samples, dates) { #' value #' @keywords internal extract_static_parameter <- function(param, samples) { + id <- samples[[paste(param, "id", sep = "_")]] + lookup <- samples[["params_variable_lookup"]][id] data.table::data.table( parameter = param, - sample = seq_along(samples[[param]]), - value = samples[[param]] + sample = seq_along(samples[["params"]][, lookup]), + value = samples[["params"]][, lookup] ) } @@ -239,16 +241,9 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, } if (data$model_type == 1) { out$reporting_overdispersion <- extract_static_parameter("rep_phi", samples) - out$reporting_overdispersion <- out$reporting_overdispersion[, - value := value.V1][, - value.V1 := NULL - ] } if ("obs_scale_sd" %in% names(data) && data$obs_scale_sd > 0) { out$fraction_observed <- extract_static_parameter("frac_obs", samples) - out$fraction_observed <- out$fraction_observed[, value := value.V1][, - value.V1 := NULL - ] } return(out) } diff --git a/R/opts.R b/R/opts.R index d56a31a32..04d9e5577 100644 --- a/R/opts.R +++ b/R/opts.R @@ -297,9 +297,10 @@ trunc_opts <- function(dist = Fixed(0), default_cdf_cutoff = 0.001, #' reproduction number. Custom settings can be supplied which override the #' defaults. #' -#' @param prior List containing named numeric elements "mean" and "sd". The -#' mean and standard deviation of the log normal Rt prior. Defaults to mean of -#' 1 and standard deviation of 1. +#' @param prior A `dist_spec` giving the prior of the initial reproduction +#' number. Ignored if `use_rt` is `FALSE`. Defaults to a LogNormal distribution +#' with mean of 1 and standard deviation of 1: `LogNormal(mean = 1, sd = 1)`. +#' A lower limit of 0 will be enforced automatically. #' #' @param use_rt Logical, defaults to `TRUE`. Should Rt be used to generate #' infections and hence reported cases.
@@ -339,11 +340,11 @@ trunc_opts <- function(dist = Fixed(0), default_cdf_cutoff = 0.001, #' rt_opts() #' #' # add a custom length scale -#' rt_opts(prior = list(mean = 2, sd = 1)) +#' rt_opts(prior = LogNormal(mean = 2, sd = 1)) #' #' # add a weekly random walk #' rt_opts(rw = 7) -rt_opts <- function(prior = list(mean = 1, sd = 1), +rt_opts <- function(prior = LogNormal(mean = 1, sd = 1), use_rt = TRUE, rw = 0, use_breakpoints = TRUE, @@ -351,7 +352,6 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), gp_on = c("R_t-1", "R0"), pop = 0) { rt <- list( - prior = prior, use_rt = use_rt, rw = rw, use_breakpoints = use_breakpoints, @@ -365,15 +365,37 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), rt$use_breakpoints <- TRUE } - if (!("mean" %in% names(rt$prior) && "sd" %in% names(rt$prior))) { - cli_abort( + if (is.list(prior) && !is(prior, "dist_spec")) { + cli_warn( c( - "!" = "{.var prior} must have both {.var mean} and {.var sd} - specified.", - "i" = "Did you forget to specify {.var mean} and/or {.var sd}?" + "!" = "Specifying {.var prior} as a list is deprecated.", + "i" = "Use a {.cls dist_spec} instead." ) ) + if (!("mean" %in% names(prior) && "sd" %in% names(prior))) { + cli_abort( + c( + "!" = "{.var prior} must have both {.var mean} and {.var sd} + specified.", + "i" = "Did you forget to specify {.var mean} and/or {.var sd}?" + ) + ) + } + prior <- LogNormal(mean = prior$mean, sd = prior$sd) } + + if (rt$use_rt) { + rt$prior <- prior + } else { + if (!missing(prior)) { + cli_warn( + c( + "!" = "Rt {.var prior} is ignored if {.var use_rt} is FALSE." + ) + ) + } + } + attr(rt, "class") <- c("rt_opts", class(rt)) return(rt) } @@ -453,14 +475,17 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"), #' scale. Updated in [create_gp_data()] to be the length of the input data if #' this is smaller. #' -#' @param alpha_mean Numeric, defaults to 0. The mean of the magnitude parameter -#' of the Gaussian process kernel. 
Should be approximately the expected standard -#' deviation of the Gaussian process (logged Rt in case of the renewal model, -#' logged infections in case of the nonmechanistic model). +#' @param alpha A `dist_spec` giving the prior distribution of the magnitude +#' parameter of the Gaussian process kernel. Should be approximately the +#' expected standard deviation of the Gaussian process (logged Rt in case of +#' the renewal model, logged infections in case of the nonmechanistic model). +#' Defaults to a half-normal distribution with mean 0 and sd 0.01: +#' `Normal(mean = 0, sd = 0.01)` (a lower limit of 0 will be enforced +#' automatically to ensure positivity). #' -#' @param alpha_sd Numeric, defaults to 0.01. The standard deviation of the -#' magnitude parameter of the Gaussian process kernel. Can be tuned to adjust -#' how far alpha is allowed to deviate form its prior mean (`alpha_mean`). +#' @param alpha_mean Deprecated; use `alpha` instead. +#' +#' @param alpha_sd Deprecated; use `alpha` instead. #' #' @param kernel Character string, the type of kernel required.
Currently #' supporting the Matern kernel ("matern"), squared exponential kernel ("se"), @@ -508,18 +533,28 @@ gp_opts <- function(basis_prop = 0.2, ls_sd = 7, ls_min = 0, ls_max = 60, - alpha_mean = 0, - alpha_sd = 0.01, + alpha = Normal(mean = 0, sd = 0.01), kernel = c("matern", "se", "ou", "periodic"), matern_order = 3 / 2, matern_type, - w0 = 1.0) { + w0 = 1.0, + alpha_mean, alpha_sd) { if (!missing(matern_type)) { lifecycle::deprecate_warn( "1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)" ) } + if (!missing(alpha_mean)) { + lifecycle::deprecate_warn( + "1.7.0", "gp_opts(alpha_mean)", "gp_opts(alpha)" + ) + } + if (!missing(alpha_sd)) { + lifecycle::deprecate_warn( + "1.7.0", "gp_opts(alpha_sd)", "gp_opts(alpha)" + ) + } if (!missing(matern_type)) { if (!missing(matern_order) && matern_type != matern_order) { @@ -557,8 +592,7 @@ gp_opts <- function(basis_prop = 0.2, ls_sd = ls_sd, ls_min = ls_min, ls_max = ls_max, - alpha_mean = alpha_mean, - alpha_sd = alpha_sd, + alpha = alpha, kernel = kernel, matern_order = matern_order, w0 = w0 @@ -575,13 +609,12 @@ gp_opts <- function(basis_prop = 0.2, #' model. Custom settings can be supplied which override the defaults. #' @param family Character string defining the observation model. Options are #' Negative binomial ("negbin"), the default, and Poisson. -#' @param phi Overdispersion parameter of the reporting process, used only if -#' `familiy` is "negbin". Can be supplied either as a single numeric value -#' (fixed overdispersion) or a list with numeric elements mean (`mean`) and -#' standard deviation (`sd`) defining a normally distributed prior. -#' Internally parameterised such that the overdispersion is one over the -#' square of this prior overdispersion. Defaults to a list with elements -#' `mean = 0` and `sd = 0.25`. +#' @param phi A `dist_spec` specifying a prior on the overdispersion parameter +#' of the reporting process, used only if `family` is "negbin".
Internally +#' parameterised such that the overdispersion is one over the square of this +#' prior overdispersion phi. Defaults to a half-normal distribution with mean +#' of 0 and standard deviation of 0.25: `Normal(mean = 0, sd = 0.25)`. A lower +#' limit of zero will be enforced automatically. #' @param weight Numeric, defaults to 1. Weight to give the observed data in the #' log density. #' @param week_effect Logical defaulting to `TRUE`. Should a day of the week @@ -589,11 +622,12 @@ gp_opts <- function(basis_prop = 0.2, #' @param week_length Numeric assumed length of the week in days, defaulting to #' 7 days. This can be modified if data aggregated over a period other than a #' week or if data has a non-weekly periodicity. -#' @param scale Scaling factor to be applied to map latent infections (convolved -#' to date of report). Can be supplied either as a single numeric value (fixed -#' scale) or a list with numeric elements mean (`mean`) and standard deviation -#' (`sd`) defining a normally distributed scaling factor. Defaults to 1, i.e. -#' no scaling. +#' @param scale A `dist_spec` specifying a prior on the scaling factor to be +#' applied to map latent infections (convolved to date of report). Defaults +#' to a fixed value of 1, i.e. no scaling: `Fixed(1)`. A lower limit of zero +#' will be enforced automatically. If setting to a prior distribution and no +#' overreporting is expected, it might be sensible to set a maximum of 1 via +#' the `max` option when declaring the distribution. #' @param na Deprecated; use the [fill_missing()] function instead #' @param likelihood Logical, defaults to `TRUE`. Should the likelihood be #' included in the model.
@@ -611,13 +645,13 @@ gp_opts <- function(basis_prop = 0.2, #' obs_opts(week_effect = TRUE) #' #' # Scale reported data -#' obs_opts(scale = list(mean = 0.2, sd = 0.02)) +#' obs_opts(scale = Normal(mean = 0.2, sd = 0.02)) obs_opts <- function(family = c("negbin", "poisson"), - phi = list(mean = 0, sd = 0.25), + phi = Normal(mean = 0, sd = 0.25), weight = 1, week_effect = TRUE, week_length = 7, - scale = 1, + scale = Fixed(1), na = c("missing", "accumulate"), likelihood = TRUE, return_likelihood = FALSE) { @@ -679,16 +713,32 @@ obs_opts <- function(family = c("negbin", "poisson"), for (param in c("phi", "scale")) { if (is.numeric(obs[[param]])) { - obs[[param]] <- list(mean = obs[[param]], sd = 0) - } - if (!(all(c("mean", "sd") %in% names(obs[[param]])))) { - cli_abort( + cli_warn( c( - "!" = "Both a {.var mean} and {.var sd} are needed if specifying - {.strong {param}} as list.", - "i" = "Did you forget to specify {.var mean} and/or {.var sd}?" + "!" = "Specifying {.var {param}} as a numeric value is deprecated.", + "i" = "Use a {.cls dist_spec} instead using {.fn Fixed()}." + ) + ) + obs[[param]] <- Fixed(obs[[param]]) + } else if (is.list(obs[[param]]) && !is(obs[[param]], "dist_spec")) { + cli_warn( + c( + "!" = "Specifying {.var {param}} as a list is deprecated.", + "i" = "Use a {.cls dist_spec} instead." ) ) + if (!(all(c("mean", "sd") %in% names(obs[[param]])))) { + cli_abort( + c( + "!" = "Both a {.var mean} and {.var sd} are needed if specifying + {.var {param}} as list.", + "i" = "Did you forget to specify {.var mean} and/or {.var sd}?" 
+ ) + ) + } + obs[[param]] <- Normal(mean = obs[[param]]$mean, sd = obs[[param]]$sd) + } else { + assert_class(obs[[param]], "dist_spec") } } diff --git a/R/regional_epinow.R b/R/regional_epinow.R index 09d305031..a5f08d062 100644 --- a/R/regional_epinow.R +++ b/R/regional_epinow.R @@ -79,7 +79,7 @@ #' data = cases, #' generation_time = gt_opts(example_generation_time), #' delays = delay_opts(example_incubation_period + example_reporting_delay), -#' rt = rt_opts(prior = list(mean = 2, sd = 0.2)), +#' rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.2)), #' stan = stan_opts( #' samples = 100, warmup = 200 #' ), diff --git a/R/simulate_infections.R b/R/simulate_infections.R index 3170220e7..f3b66cc3f 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -152,7 +152,7 @@ simulate_infections <- function(estimates, R, initial_infections, obs, dates = R$date )) - if (data$obs_scale_sd > 0) { + if (get_distribution(obs$scale) != "fixed") { cli_abort( c( "!" = "Cannot simulate from uncertain observation scaling.", @@ -160,16 +160,9 @@ simulate_infections <- function(estimates, R, initial_infections, ) ) } - if (data$obs_scale) { - data$frac_obs <- array(data$obs_scale_mean, dim = c(1, 1)) - } else { - data$frac_obs <- array(dim = c(1, 0)) - } - data$obs_scale_mean <- NULL - data$obs_scale_sd <- NULL if (obs$family == "negbin") { - if (data$phi_sd > 0) { + if (get_distribution(obs$phi) != "fixed") { cli_abort( c( "!" 
= "Cannot simulate from uncertain overdispersion.", @@ -177,12 +170,18 @@ simulate_infections <- function(estimates, R, initial_infections, ) ) } - data$rep_phi <- array(data$phi_mean, dim = c(1, 1)) } else { - data$rep_phi <- array(dim = c(1, 0)) + obs$phi <- NULL } - data$phi_mean <- NULL - data$phi_sd <- NULL + + data <- c(data, create_stan_params( + alpha = NULL, + R0 = NULL, + frac_obs = obs$scale, + rep_phi = obs$phi + )) + ## set empty params matrix - variable parameters not supported here + data$params <- array(dim = c(1, 0)) ## day of week effect if (is.null(day_of_week_effect)) { @@ -278,8 +277,8 @@ simulate_infections <- function(estimates, R, initial_infections, #' est <- estimate_infections(reported_cases, #' generation_time = generation_time_opts(example_generation_time), #' delays = delay_opts(example_incubation_period + example_reporting_delay), -#' rt = rt_opts(prior = list(mean = 2, sd = 0.1), rw = 7), -#' obs = obs_opts(scale = list(mean = 0.1, sd = 0.01)), +#' rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1), rw = 7), +#' obs = obs_opts(scale = Normal(mean = 0.1, sd = 0.01)), #' gp = NULL, horizon = 0 #' ) #' @@ -436,7 +435,7 @@ forecast_infections <- function(estimates, ## allocate empty parameters data <- allocate_empty( - data, c("frac_obs", "delay_params", "rep_phi"), + data, c("delay_params", "params"), n = data$n ) diff --git a/R/simulate_secondary.R b/R/simulate_secondary.R index 0bcb82314..df112fbcd 100644 --- a/R/simulate_secondary.R +++ b/R/simulate_secondary.R @@ -94,7 +94,7 @@ simulate_secondary <- function(primary, obs, dates = primary$date )) - if (data$obs_scale_sd > 0) { + if (get_distribution(obs$scale) != "fixed") { cli_abort( c( "!" 
= "Cannot simulate from uncertain observation scaling.", @@ -102,16 +102,9 @@ simulate_secondary <- function(primary, ) ) } - if (data$obs_scale) { - data$frac_obs <- array(data$obs_scale_mean, dim = c(1, 1)) - } else { - data$frac_obs <- array(dim = c(1, 0)) - } - data$obs_scale_mean <- NULL - data$obs_scale_sd <- NULL if (obs$family == "negbin") { - if (data$phi_sd > 0) { + if (get_distribution(obs$phi) != "fixed") { cli_abort( c( "!" = "Cannot simulate from uncertain overdispersion.", @@ -119,12 +112,17 @@ simulate_secondary <- function(primary, ) ) } - data$rep_phi <- array(data$phi_mean, dim = c(1, 1)) } else { - data$rep_phi <- array(dim = c(1, 0)) + obs$phi <- NULL } - data$phi_mean <- NULL - data$phi_sd <- NULL + + data <- c(data, create_stan_params( + frac_obs = obs$scale, + rep_phi = obs$phi + )) + + ## set empty params matrix - variable parameters not supported here + data$params <- array(dim = c(1, 0)) ## day of week effect if (is.null(day_of_week_effect)) { diff --git a/data-raw/estimate-infections.R b/data-raw/estimate-infections.R index 24f1215c1..165a767d5 100644 --- a/data-raw/estimate-infections.R +++ b/data-raw/estimate-infections.R @@ -14,7 +14,7 @@ reporting_delay <- LogNormal(mean = 2, sd = 1, max = 10L) example_estimate_infections <- estimate_infections(reported_cases, generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.1)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1)), stan = stan_opts(samples = 200, control = list(adapt_delta = 0.95)) ) @@ -28,7 +28,7 @@ example_regional_epinow <- regional_epinow( generation_time = gt_opts(example_generation_time), data = cases, delays = delay_opts(example_incubation_period + reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.2)), stan = stan_opts(samples = 200, control = list(adapt_delta = 0.95)) ) diff --git 
a/inst/dev/benchmark-functions.R b/inst/dev/benchmark-functions.R index 3e884e72b..c03104463 100644 --- a/inst/dev/benchmark-functions.R +++ b/inst/dev/benchmark-functions.R @@ -17,7 +17,7 @@ create_profiles <- function(dir = file.path("inst", "stan"), data = reported_cases, generation_time = gt_opts(fixed_generation_time), delays = delay_opts(delays), - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.2)), stan = stan_opts( samples = 1000, chains = 2, object = compiled_model, cores = 2 diff --git a/inst/dev/recover-synthetic/rt.R b/inst/dev/recover-synthetic/rt.R index a20428701..fd223ae3e 100644 --- a/inst/dev/recover-synthetic/rt.R +++ b/inst/dev/recover-synthetic/rt.R @@ -7,14 +7,14 @@ old_opts <- options() options(mc.cores = 4) #' get example delays -obs <- obs_opts(scale = list(mean = 0.1, sd = 0.025), return_likelihood = TRUE) +obs <- obs_opts(scale = Normal(mean = 0.1, sd = 0.025), return_likelihood = TRUE) # fit model to data to recover realistic parameter estimates and define settings # shared simulation settings init <- estimate_infections(example_confirmed[1:100], generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.1), rw = 14), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1), rw = 14), gp = NULL, horizon = 0, obs = obs ) @@ -59,7 +59,7 @@ for (method in c("nuts")) { estimate_infections(sim_cases, generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.25)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.25)), stan = stanopts, obs = obs, horizon = 0 @@ -90,7 +90,7 @@ for (method in c("nuts")) { generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), rt = rt_opts( - prior = list(mean = 2, sd = 0.25), 
+ prior = LogNormal(mean = 2, sd = 0.25), rw = 7 ), gp = NULL, @@ -109,7 +109,7 @@ for (method in c("nuts")) { generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), rt = rt_opts( - prior = list(mean = 2, sd = 0.25), rw = 14, gp_on = "R0" + prior = LogNormal(mean = 2, sd = 0.25), rw = 14, gp_on = "R0" ), stan = stanopts, obs = obs, @@ -130,7 +130,7 @@ for (method in c("nuts")) { example_incubation_period + example_reporting_delay ), rt = rt_opts( - prior = list(mean = 2, sd = 0.25), + prior = LogNormal(mean = 2, sd = 0.25), rw = 1 ), gp = NULL, diff --git a/inst/stan/data/estimate_infections_params.stan b/inst/stan/data/estimate_infections_params.stan new file mode 100644 index 000000000..3351f5ea3 --- /dev/null +++ b/inst/stan/data/estimate_infections_params.stan @@ -0,0 +1,4 @@ +int alpha_id; // parameter id of alpha (GP magnitude) +int R0_id; // parameter id of R0 +int frac_obs_id; // parameter id of frac_obs +int rep_phi_id; // parameter id of rep_phi_id diff --git a/inst/stan/data/estimate_secondary_params.stan b/inst/stan/data/estimate_secondary_params.stan new file mode 100644 index 000000000..736ce31df --- /dev/null +++ b/inst/stan/data/estimate_secondary_params.stan @@ -0,0 +1,2 @@ +int frac_obs_id; // parameter id of frac_obs +int rep_phi_id; // parameter id of rep_phi_id diff --git a/inst/stan/data/gaussian_process.stan b/inst/stan/data/gaussian_process.stan index 8154ffdfe..7990dba8a 100644 --- a/inst/stan/data/gaussian_process.stan +++ b/inst/stan/data/gaussian_process.stan @@ -4,8 +4,6 @@ real ls_sdlog; // sdlog for gp lengthscale prior real ls_min; // Lower bound for the lengthscale real ls_max; // Upper bound for the lengthscale - real alpha_mean; // mean of the alpha gp kernal parameter - real alpha_sd; // standard deviation of the alpha gp kernal parameter int gp_type; // type of gp, 0 = squared exponential, 1 = periodic, 2 = Matern real nu; // smoothness parameter for Matern 
kernel (used if gp_type = 2) real w0; // fundamental frequency for periodic kernel (used if gp_type = 1) diff --git a/inst/stan/data/observation_model.stan b/inst/stan/data/observation_model.stan index 0ce9ef3bb..76ad215ea 100644 --- a/inst/stan/data/observation_model.stan +++ b/inst/stan/data/observation_model.stan @@ -1,11 +1,7 @@ array[t - seeding_time] int day_of_week; // day of the week indicator (1 - 7) int model_type; // type of model: 0 = poisson otherwise negative binomial - real phi_mean; // Mean and sd of the normal prior for the - real phi_sd; // reporting process int week_effect; // length of week effect int obs_scale; // logical controlling scaling of observations - real obs_scale_mean; // mean scaling factor for observations - real obs_scale_sd; // standard deviation of observation scaling real obs_weight; // weight given to observation in log density int likelihood; // Should the likelihood be included in the model int return_likelihood; // Should the likehood be returned by the model diff --git a/inst/stan/data/params.stan b/inst/stan/data/params.stan new file mode 100644 index 000000000..5ac81a1c4 --- /dev/null +++ b/inst/stan/data/params.stan @@ -0,0 +1,13 @@ +int n_params_variable; // number of parameters +int n_params_fixed; // number of parameters +vector[n_params_variable] params_lower; // lower bounds of the priors +vector[n_params_variable] params_upper; // upper bounds of the priors + +array[n_params_fixed + n_params_variable] int params_fixed_lookup; // fixed parameter lookup +array[n_params_fixed + n_params_variable] int params_variable_lookup; // variable parameter lookup + +vector[n_params_fixed] params_value; // fixed parameter values + +array[n_params_variable] int prior_dist; // 0 = lognormal; 1 = gamma; 2 = normal +int prior_dist_params_length; // number of parameters across all parametric delay distributions +vector[prior_dist_params_length] prior_dist_params; diff --git a/inst/stan/data/rt.stan b/inst/stan/data/rt.stan index 
11b1989ae..b736f1ade 100644 --- a/inst/stan/data/rt.stan +++ b/inst/stan/data/rt.stan @@ -1,8 +1,6 @@ int estimate_r; // should the reproduction no be estimated (1 = yes) real prior_infections; // prior for initial infections real prior_growth; // prior on initial growth rate - real r_mean; // prior mean of reproduction number - real r_sd; // prior standard deviation of reproduction number int bp_n; // no of breakpoints (0 = no breakpoints) array[t - seeding_time] int breakpoints; // when do breakpoints occur int future_fixed; // is underlying future Rt assumed to be fixed diff --git a/inst/stan/data/simulation_observation_model.stan b/inst/stan/data/simulation_observation_model.stan index c8cab6b35..2b83cab66 100644 --- a/inst/stan/data/simulation_observation_model.stan +++ b/inst/stan/data/simulation_observation_model.stan @@ -2,7 +2,5 @@ int week_effect; // should a day of the week effect be estimated array[n, week_effect] real day_of_week_simplex; int obs_scale; - array[n, obs_scale] real frac_obs; int model_type; - array[n, model_type] real rep_phi; // overdispersion of the reporting process int trunc_id; // id of truncation diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 7d7ea1018..8202c962c 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -7,6 +7,7 @@ functions { #include functions/infections.stan #include functions/observation_model.stan #include functions/generated_quantities.stan +#include functions/params.stan } data { @@ -16,6 +17,8 @@ data { #include data/rt.stan #include data/backcalc.stan #include data/observation_model.stan +#include data/params.stan +#include data/estimate_infections_params.stan } transformed data { @@ -27,9 +30,6 @@ transformed data { ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from ); matrix[noise_terms, gp_type == 1 ? 
2*M : M] PHI = setup_gp(M, L, noise_terms, gp_type == 1, w0); // basis function - // Rt - real r_logmean = log(r_mean^2 / sqrt(r_sd^2 + r_mean^2)); - real r_logsd = sqrt(log(1 + (r_sd^2 / r_mean^2))); array[delay_types] int delay_type_max; profile("assign max") { @@ -41,12 +41,11 @@ transformed data { } parameters { + vector[n_params_variable] params; // gaussian process array[fixed ? 0 : 1] real rescaled_rho; // length scale of noise GP - array[fixed ? 0 : 1] real alpha; // scale of noise GP vector[fixed ? 0 : gp_type == 1 ? 2*M : M] eta; // unconstrained noise // Rt - vector[estimate_r] log_R; // baseline reproduction number estimate (log) array[estimate_r] real initial_infections; // seed infections array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate array[bp_n > 0 ? 1 : 0] real bp_sd; // standard deviation of breakpoint effect @@ -54,8 +53,6 @@ parameters { // observation model vector[delay_params_length] delay_params; // delay parameters simplex[week_effect] day_of_week_simplex; // day of week reporting effect - array[obs_scale_sd > 0 ? 
1 : 0] real frac_obs; // fraction of cases that are ultimately observed - array[model_type] real rep_phi; // overdispersion of the reporting process } transformed parameters { @@ -69,8 +66,12 @@ transformed parameters { // GP in noise - spectral densities profile("update gp") { if (!fixed) { + real alpha = get_param( + alpha_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); noise = update_gp( - PHI, M, L, alpha[1], rescaled_rho, eta, gp_type, nu + PHI, M, L, alpha, rescaled_rho, eta, gp_type, nu ); } } @@ -85,9 +86,12 @@ transformed parameters { 1, 1, 0 ); } - profile("R") { + profile("R0") { + real R0 = get_param( + R0_id, params_fixed_lookup, params_variable_lookup, params_value, params + ); R = update_Rt( - ot_h, log_R[estimate_r], noise, breakpoints, bp_effects, stationary + ot_h, R0, noise, breakpoints, bp_effects, stationary ); } profile("infections") { @@ -133,9 +137,11 @@ transformed parameters { // scaling of reported cases by fraction observed if (obs_scale) { profile("scale") { - reports = scale_obs( - reports, obs_scale_sd > 0 ? 
frac_obs[1] : obs_scale_mean + real frac_obs = get_param( + frac_obs_id, params_fixed_lookup, params_variable_lookup, params_value, + params ); + reports = scale_obs(reports, frac_obs); } } @@ -169,7 +175,7 @@ model { // priors for noise GP if (!fixed) { profile("gp lp") { - gaussian_process_lp(alpha[1], eta, alpha_mean, alpha_sd); + gaussian_process_lp(eta); if (gp_type != 3) { lengthscale_lp(rescaled_rho[1], ls_meanlog, ls_sdlog, ls_min, ls_max); } @@ -184,29 +190,32 @@ model { ); } + // parameter priors + profile("param lp") { + params_lp( + params, prior_dist, prior_dist_params, params_lower, params_upper + ); + } + if (estimate_r) { // priors on Rt profile("rt lp") { rt_lp( - log_R, initial_infections, initial_growth, bp_effects, bp_sd, bp_n, - seeding_time, r_logmean, r_logsd, prior_infections, prior_growth + initial_infections, initial_growth, bp_effects, bp_sd, bp_n, + seeding_time, prior_infections, prior_growth ); } } - // prior observation scaling - if (obs_scale_sd > 0) { - profile("scale lp") { - frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1]; - } - } - // observed reports from mean of reports (update likelihood) if (likelihood) { profile("report lp") { + real rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); report_lp( - cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, model_type, - obs_weight + cases, cases_time, obs_reports, rep_phi, model_type, obs_weight ); } } @@ -220,6 +229,10 @@ generated quantities { vector[fixed ? 
0 : 1] rho; profile("generated quantities") { + real rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); if (!fixed && gp_type != 3) { vector[noise_terms] x = linspaced_vector(noise_terms, 1, noise_terms); rho[1] = rescaled_rho[1] * sd(x); diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index 15836243c..4141b5bf1 100644 --- a/inst/stan/estimate_secondary.stan +++ b/inst/stan/estimate_secondary.stan @@ -4,6 +4,7 @@ functions { #include functions/delays.stan #include functions/observation_model.stan #include functions/secondary.stan +#include functions/params.stan } data { @@ -18,6 +19,8 @@ data { #include data/secondary.stan #include data/delays.stan #include data/observation_model.stan +#include data/params.stan +#include data/estimate_secondary_params.stan } transformed data{ @@ -31,8 +34,7 @@ parameters{ // observation model vector[delay_params_length] delay_params; simplex[week_effect] day_of_week_simplex; // day of week reporting effect - array[obs_scale] real frac_obs; // fraction of cases that are ultimately observed - array[model_type] real rep_phi; // overdispersion of the reporting process + vector[n_params_variable] params; } transformed parameters { @@ -45,7 +47,11 @@ transformed parameters { // scaling of primary reports by fraction observed if (obs_scale) { - scaled = scale_obs(primary, obs_scale_sd > 0 ? 
frac_obs[1] : obs_scale_mean); + real frac_obs = get_param( + frac_obs_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); + scaled = scale_obs(primary, frac_obs); } else { scaled = primary; } @@ -99,15 +105,21 @@ model { delay_dist, delay_weight ); - // prior primary report scaling - if (obs_scale) { - frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1]; - } + // parameter priors + profile("param lp") { + params_lp( + params, prior_dist, prior_dist_params, params_lower, params_upper + ); + } // observed secondary reports from mean of secondary reports (update likelihood) if (likelihood) { + real rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); report_lp( obs[(burn_in + 1):t][obs_time], obs_time, secondary[(burn_in + 1):t], - rep_phi, phi_mean, phi_sd, model_type, 1 + rep_phi, model_type, 1 ); } } @@ -115,11 +127,17 @@ model { generated quantities { array[t - burn_in] int sim_secondary; vector[return_likelihood > 1 ? 
t - burn_in : 0] log_lik; - // simulate secondary reports - sim_secondary = report_rng(secondary[(burn_in + 1):t], rep_phi, model_type); - // log likelihood of model - if (return_likelihood) { - log_lik = report_log_lik(obs[(burn_in + 1):t], secondary[(burn_in + 1):t], - rep_phi, model_type, obs_weight); + { + real rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); + // simulate secondary reports + sim_secondary = report_rng(secondary[(burn_in + 1):t], rep_phi, model_type); + // log likelihood of model + if (return_likelihood) { + log_lik = report_log_lik(obs[(burn_in + 1):t], secondary[(burn_in + 1):t], + rep_phi, model_type, obs_weight); + } } } diff --git a/inst/stan/functions/gaussian_process.stan b/inst/stan/functions/gaussian_process.stan index ab5ee5eb7..e35906f02 100644 --- a/inst/stan/functions/gaussian_process.stan +++ b/inst/stan/functions/gaussian_process.stan @@ -208,13 +208,8 @@ void lengthscale_lp(real rho, real ls_meanlog, real ls_sdlog, /** * Priors for Gaussian process (excluding length scale) * - * @param alpha Scaling parameter * @param eta Vector of noise terms - * @param alpha_mean Mean of alpha - * @param alpha_sd Standard deviation of alpha */ -void gaussian_process_lp(real alpha, vector eta, real alpha_mean, - real alpha_sd) { - alpha ~ normal(alpha_mean, alpha_sd) T[0,]; +void gaussian_process_lp(vector eta) { eta ~ std_normal(); } diff --git a/inst/stan/functions/observation_model.stan b/inst/stan/functions/observation_model.stan index 9a96d958e..1639725f4 100644 --- a/inst/stan/functions/observation_model.stan +++ b/inst/stan/functions/observation_model.stan @@ -97,24 +97,18 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd, * @param cases Array of integer observed cases. * @param cases_time Array of integer time indices for observed cases. * @param reports Vector of expected reports. - * @param rep_phi Array of real values for reporting overdispersion. 
- * @param phi_mean Real value for mean of reporting overdispersion prior. - * @param phi_sd Real value for standard deviation of reporting overdispersion prior. + * @param rep_phi Real values for reporting overdispersion. * @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial). * @param weight Real value for weighting the log density contribution. * @param accumulate Array of integers indicating, for each time point, whether * to accumulate reports (1) or not (0). */ void report_lp(array[] int cases, array[] int cases_time, vector reports, - array[] real rep_phi, real phi_mean, real phi_sd, - int model_type, real weight) { + real rep_phi, int model_type, real weight) { int n = num_elements(cases_time); // number of observations vector[n] obs_reports = reports[cases_time]; // reports at observation time if (model_type) { - real dispersion = inv_square(phi_sd > 0 ? rep_phi[model_type] : phi_mean); - if (phi_sd > 0) { - rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,]; - } + real dispersion = inv_square(rep_phi); if (weight == 1) { cases ~ neg_binomial_2(obs_reports, dispersion); } else { @@ -166,7 +160,7 @@ vector accumulate_reports(vector reports, array[] int accumulate) { * @return A vector of log likelihoods for each time point. 
*/ vector report_log_lik(array[] int cases, vector reports, - array[] real rep_phi, int model_type, real weight) { + real rep_phi, int model_type, real weight) { int t = num_elements(reports); vector[t] log_lik; @@ -176,7 +170,7 @@ vector report_log_lik(array[] int cases, vector reports, log_lik[i] = poisson_lpmf(cases[i] | reports[i]) * weight; } } else { - real dispersion = inv_square(rep_phi[model_type]); + real dispersion = inv_square(rep_phi); for (i in 1:t) { log_lik[i] = neg_binomial_2_lpmf(cases[i] | reports[i], dispersion) * weight; } @@ -190,17 +184,17 @@ vector report_log_lik(array[] int cases, vector reports, * This function generates random samples of reported cases based on the specified model type. * * @param reports Vector of expected reports. - * @param rep_phi Array of real values for reporting overdispersion. + * @param rep_phi Real value for reporting overdispersion. * @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial). * * @return An array of integer sampled reports. 
*/ -array[] int report_rng(vector reports, array[] real rep_phi, int model_type) { +array[] int report_rng(vector reports, real rep_phi, int model_type) { int t = num_elements(reports); array[t] int sampled_reports; real dispersion = 1e5; if (model_type) { - dispersion = inv_square(rep_phi[model_type]); + dispersion = inv_square(rep_phi); } for (s in 1:t) { diff --git a/inst/stan/functions/params.stan b/inst/stan/functions/params.stan new file mode 100644 index 000000000..3861106c2 --- /dev/null +++ b/inst/stan/functions/params.stan @@ -0,0 +1,53 @@ +real get_param(int id, + array[] int params_fixed_lookup, + array[] int params_variable_lookup, + vector params_value, vector params) { + if (id == 0) { + return 0; // parameter not used + } else if (params_fixed_lookup[id]) { + return params_value[params_fixed_lookup[id]]; + } else { + return params[params_variable_lookup[id]]; + } +} + +vector get_param(int id, + array[] int params_fixed_lookup, + array[] int params_variable_lookup, + vector params_value, matrix params) { + int n_samples = rows(params); + if (id == 0) { + return rep_vector(0, n_samples) ; // parameter not used + } else if (params_fixed_lookup[id]) { + return rep_vector(params_value[params_fixed_lookup[id]], n_samples); + } else { + return params[, params_variable_lookup[id]]; + } +} + +void params_lp(vector params, array[] int prior_dist, + vector prior_dist_params, vector params_lower, + vector params_upper) { + int params_id = 1; + int num_params = num_elements(params); + for (id in 1:num_params) { + if (prior_dist[id] == 0) { // lognormal + params[id] ~ + lognormal(prior_dist_params[params_id], prior_dist_params[params_id + 1]) + T[params_lower[id], params_upper[id]]; + params_id += 2; + } else if (prior_dist[id] == 1) { + params[id] ~ + gamma(prior_dist_params[params_id], prior_dist_params[params_id + 1]) + T[params_lower[id], params_upper[id]]; + params_id += 2; + } else if (prior_dist[id] == 2) { + params[id] ~ + 
normal(prior_dist_params[params_id], prior_dist_params[params_id + 1]) + T[params_lower[id], params_upper[id]]; + params_id += 2; + } else { + reject("dist must be <= 2"); + } + } +} diff --git a/inst/stan/functions/rt.stan b/inst/stan/functions/rt.stan index ad2d877b1..e5ebb30f4 100644 --- a/inst/stan/functions/rt.stan +++ b/inst/stan/functions/rt.stan @@ -4,7 +4,7 @@ * process. * * @param t Length of the time series - * @param log_R Logarithm of the base reproduction number + * @param R0 Initial reproduction number * @param noise Vector of Gaussian process noise values * @param bps Array of breakpoint indices * @param bp_effects Vector of breakpoint effects @@ -12,19 +12,19 @@ * (1) or non-stationary (0) * @return A vector of length t containing the updated Rt values */ -vector update_Rt(int t, real log_R, vector noise, array[] int bps, +vector update_Rt(int t, real R0, vector noise, array[] int bps, vector bp_effects, int stationary) { // define control parameters int bp_n = num_elements(bp_effects); int gp_n = num_elements(noise); // initialise intercept - vector[t] R = rep_vector(log_R, t); + vector[t] logR = rep_vector(log(R0), t); //initialise breakpoints + rw if (bp_n) { vector[bp_n + 1] bp0; bp0[1] = 0; bp0[2:(bp_n + 1)] = cumulative_sum(bp_effects); - R = R + bp0[bps]; + logR = logR + bp0[bps]; } //initialise gaussian process if (gp_n) { @@ -39,32 +39,27 @@ vector update_Rt(int t, real log_R, vector noise, array[] int bps, gp[2:(gp_n + 1)] = noise; gp = cumulative_sum(gp); } - R = R + gp; + logR = logR + gp; } - return exp(R); + return exp(logR); } /** * Calculate the log-probability of the reproduction number (Rt) priors * - * @param log_R Logarithm of the base reproduction number * @param initial_infections Array of initial infection values * @param initial_growth Array of initial growth rates * @param bp_effects Vector of breakpoint effects * @param bp_sd Array of breakpoint standard deviations * @param bp_n Number of breakpoints * @param seeding_time 
Time point at which seeding occurs - * @param r_logmean Log-mean of the prior distribution for the base reproduction number - * @param r_logsd Log-standard deviation of the prior distribution for the base reproduction number * @param prior_infections Prior mean for initial infections * @param prior_growth Prior mean for initial growth rates */ -void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_growth, +void rt_lp(array[] real initial_infections, array[] real initial_growth, vector bp_effects, array[] real bp_sd, int bp_n, int seeding_time, - real r_logmean, real r_logsd, real prior_infections, - real prior_growth) { - log_R ~ normal(r_logmean, r_logsd); + real prior_infections, real prior_growth) { //breakpoint effects on Rt if (bp_n > 0) { bp_sd[1] ~ normal(0, 0.1) T[0,]; diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index 245f80c49..3e8131994 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -7,6 +7,7 @@ functions { #include functions/infections.stan #include functions/observation_model.stan #include functions/generated_quantities.stan +#include functions/params.stan } data { @@ -21,6 +22,10 @@ data { #include data/simulation_delays.stan // observation model #include data/simulation_observation_model.stan + // parameters +#include data/params.stan +#include data/estimate_infections_params.stan + matrix[n, n_params_variable] params; // parameters } transformed data { @@ -36,66 +41,76 @@ generated quantities { matrix[n, t - seeding_time] reports; // observed cases array[n, t - seeding_time] int imputed_reports; matrix[n, t - seeding_time - 1] r; - for (i in 1:n) { - // generate infections from Rt trace - vector[delay_type_max[gt_id] + 1] gt_rev_pmf; - gt_rev_pmf = get_delay_rev_pmf( - gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, 
delay_dist, - 1, 1, 0 + { + vector[n] rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, + params_value, params ); - - infections[i] = to_row_vector(generate_infections( - to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i], - initial_growth[i], pop, future_time - )); - - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, + vector[n] frac_obs = get_param( + frac_obs_id, params_fixed_lookup, params_variable_lookup, + params_value, params + ); + for (i in 1:n) { + // generate infections from Rt trace + vector[delay_type_max[gt_id] + 1] gt_rev_pmf; + gt_rev_pmf = get_delay_rev_pmf( + gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 0 - ); - // convolve from latent infections to mean of observations - reports[i] = to_row_vector(convolve_to_report( - to_vector(infections[i]), delay_rev_pmf, seeding_time) + 1, 1, 0 ); - } else { - reports[i] = to_row_vector( - infections[i, (seeding_time + 1):t] - ); - } - // weekly reporting effect - if (week_effect > 1) { - reports[i] = to_row_vector( - day_of_week_effect(to_vector(reports[i]), day_of_week, - to_vector(day_of_week_simplex[i]))); - } - // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 1 + infections[i] = to_row_vector(generate_infections( + to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i], + initial_growth[i], pop, future_time + )); + + if (delay_id) { + vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( + 
delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, + delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, + 0, 1, 0 + ); + // convolve from latent infections to mean of observations + reports[i] = to_row_vector(convolve_to_report( + to_vector(infections[i]), delay_rev_pmf, seeding_time) + ); + } else { + reports[i] = to_row_vector( + infections[i, (seeding_time + 1):t] + ); + } + + // weekly reporting effect + if (week_effect > 1) { + reports[i] = to_row_vector( + day_of_week_effect(to_vector(reports[i]), day_of_week, + to_vector(day_of_week_simplex[i]))); + } + // truncate near time cases to observed reports + if (trunc_id) { + vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( + trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, + delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, + 0, 1, 1 + ); + reports[i] = to_row_vector(truncate_obs( + to_vector(reports[i]), trunc_rev_cmf, 0) + ); + } + // scale observations + if (obs_scale) { + reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i])); + } + // simulate reported cases + imputed_reports[i] = report_rng( + to_vector(reports[i]), rep_phi[i], model_type ); - reports[i] = to_row_vector(truncate_obs( - to_vector(reports[i]), trunc_rev_cmf, 0) + r[i] = to_row_vector( + calculate_growth(to_vector(infections[i]), seeding_time + 1) ); } - // scale observations - if (obs_scale) { - reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i, 1])); - } - // simulate reported cases - imputed_reports[i] = report_rng( - to_vector(reports[i]), rep_phi[i], model_type - ); - r[i] = to_row_vector( - calculate_growth(to_vector(infections[i]), seeding_time + 1) - ); } } diff --git a/inst/stan/simulate_secondary.stan b/inst/stan/simulate_secondary.stan index 8bd4386f1..ab75ba040 100644 --- 
a/inst/stan/simulate_secondary.stan +++ b/inst/stan/simulate_secondary.stan @@ -4,6 +4,7 @@ functions { #include functions/delays.stan #include functions/observation_model.stan #include functions/secondary.stan +#include functions/params.stan } data { @@ -16,10 +17,11 @@ data { array[t - h] int obs; // observed secondary data matrix[n, t] primary; // observed primary data #include data/secondary.stan - // delay from infection to report #include data/simulation_delays.stan - // observation model #include data/simulation_observation_model.stan +#include data/params.stan +#include data/estimate_secondary_params.stan + matrix[n, n_params_variable] params; // parameters } transformed data { @@ -31,56 +33,66 @@ transformed data { generated quantities { array[n, all_dates ? t : h] int sim_secondary; - for (i in 1:n) { - vector[t] secondary; - vector[t] scaled; - vector[t] convolved = rep_vector(1e-5, t); + { + vector[n] rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, + params_value, params + ); + vector[n] frac_obs = get_param( + frac_obs_id, params_fixed_lookup, params_variable_lookup, + params_value, params + ); + for (i in 1:n) { + vector[t] secondary; + vector[t] scaled; + vector[t] convolved = rep_vector(1e-5, t); - if (obs_scale) { - scaled = scale_obs(to_vector(primary[i]), frac_obs[i, 1]); - } else { - scaled = to_vector(primary[i]); - } + if (obs_scale) { + scaled = scale_obs(to_vector(primary[i]), frac_obs[i]); + } else { + scaled = to_vector(primary[i]); + } - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 0 + if (delay_id) { + vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( + delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, + 
delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, + 0, 1, 0 + ); + convolved = convolved + convolve_to_report(scaled, delay_rev_pmf, 0); + } else { + convolved = convolved + scaled; + } + + // calculate secondary reports from primary + secondary = calculate_secondary( + scaled, convolved, obs, cumulative, historic, primary_hist_additive, + current, primary_current_additive, t - h + 1 ); - convolved = convolved + convolve_to_report(scaled, delay_rev_pmf, 0); - } else { - convolved = convolved + scaled; - } - // calculate secondary reports from primary - secondary = calculate_secondary( - scaled, convolved, obs, cumulative, historic, primary_hist_additive, - current, primary_current_additive, t - h + 1 - ); + // weekly reporting effect + if (week_effect > 1) { + secondary = day_of_week_effect(secondary, day_of_week, to_vector(day_of_week_simplex[i])); + } - // weekly reporting effect - if (week_effect > 1) { - secondary = day_of_week_effect(secondary, day_of_week, to_vector(day_of_week_simplex[i])); - } + // truncate near time cases to observed reports + if (trunc_id) { + vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( + trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, + delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, + 0, 1, 1 + ); + secondary = truncate_obs( + secondary, trunc_rev_cmf, 0 + ); + } - // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 1 - ); - secondary = truncate_obs( - secondary, trunc_rev_cmf, 0 + // simulate secondary reports + sim_secondary[i] = report_rng( + tail(secondary, 
all_dates ? t : h), rep_phi[i], model_type ); } - - // simulate secondary reports - sim_secondary[i] = report_rng( - tail(secondary, all_dates ? t : h), rep_phi[i], model_type - ); } } diff --git a/man/create_obs_model.Rd b/man/create_obs_model.Rd index 736385fcf..b2c743fde 100644 --- a/man/create_obs_model.Rd +++ b/man/create_obs_model.Rd @@ -32,7 +32,7 @@ create_obs_model(obs_opts(family = "poisson"), dates = dates) # Applying a observation scaling to the data create_obs_model( - obs_opts(scale = list(mean = 0.4, sd = 0.01)), dates = dates + obs_opts(scale = Normal(mean = 0.4, sd = 0.01)), dates = dates ) # Apply a custom week week length diff --git a/man/create_stan_params.Rd b/man/create_stan_params.Rd new file mode 100644 index 000000000..6a2e11bdc --- /dev/null +++ b/man/create_stan_params.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/create.R +\name{create_stan_params} +\alias{create_stan_params} +\title{Create parameters for stan} +\usage{ +create_stan_params(..., lower_bounds = NULL) +} +\arguments{ +\item{...}{Named delay distributions. The names are assigned to IDs} + +\item{lower_bounds}{Named vector of lower bounds for any delay(s). The names +have to correspond to the names given to the delay distributions passed. 
+If \code{NULL} (default) no parameters are given a lower bound.} +} +\value{ +A list of variables as expected by the stan model +} +\description{ +Create parameters for stan +} +\keyword{internal} diff --git a/man/epinow.Rd b/man/epinow.Rd index 2654dfe0b..bfe6dc102 100644 --- a/man/epinow.Rd +++ b/man/epinow.Rd @@ -166,7 +166,7 @@ reported_cases <- example_confirmed[1:40] out <- epinow( data = reported_cases, generation_time = gt_opts(generation_time), - rt = rt_opts(prior = list(mean = 2, sd = 0.1)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1)), delays = delay_opts(incubation_period + reporting_delay) ) # summary of the latest estimates diff --git a/man/equals-.dist_spec.Rd b/man/equals-.dist_spec.Rd new file mode 100644 index 000000000..879c0331d --- /dev/null +++ b/man/equals-.dist_spec.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dist_spec.R +\name{==.dist_spec} +\alias{==.dist_spec} +\alias{!=.dist_spec} +\title{Compares two delay distributions} +\usage{ +\method{==}{dist_spec}(e1, e2) + +\method{!=}{dist_spec}(e1, e2) +} +\arguments{ +\item{e1}{The first delay distribution (of type <dist_spec>) to +combine.} + +\item{e2}{The second delay distribution (of type <dist_spec>) to +combine.} +} +\value{ +TRUE or FALSE +} +\description{ +Compares two delay distributions +} +\examples{ +Fixed(1) == Normal(1, 0.5) +} diff --git a/man/estimate_infections.Rd b/man/estimate_infections.Rd index 528a76edf..0372966a0 100644 --- a/man/estimate_infections.Rd +++ b/man/estimate_infections.Rd @@ -152,7 +152,7 @@ reporting_delay <- LogNormal(mean = 2, sd = 1, max = 10) def <- estimate_infections(reported_cases, generation_time = gt_opts(generation_time), delays = delay_opts(incubation_period + reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.1)) + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1)) ) # real time estimates summary(def) diff --git a/man/estimate_secondary.Rd b/man/estimate_secondary.Rd index 
b985c121c..fe526b41d 100644 --- a/man/estimate_secondary.Rd +++ b/man/estimate_secondary.Rd @@ -142,7 +142,7 @@ cases <- convolve_and_scale(cases, type = "incidence") # fit model to example data specifying a weak prior for fraction reported # with a secondary case inc <- estimate_secondary(cases[1:60], - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE) + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE) ) plot(inc, primary = TRUE) @@ -170,7 +170,7 @@ prev <- estimate_secondary(cases[1:100], secondary = secondary_opts(type = "prevalence"), obs = obs_opts( week_effect = FALSE, - scale = list(mean = 0.4, sd = 0.1) + scale = Normal(mean = 0.4, sd = 0.1) ) ) plot(prev, primary = TRUE) diff --git a/man/forecast_infections.Rd b/man/forecast_infections.Rd index e9c4fbde4..24c5b2c5d 100644 --- a/man/forecast_infections.Rd +++ b/man/forecast_infections.Rd @@ -65,8 +65,8 @@ reported_cases <- example_confirmed[1:50] est <- estimate_infections(reported_cases, generation_time = generation_time_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.1), rw = 7), - obs = obs_opts(scale = list(mean = 0.1, sd = 0.01)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1), rw = 7), + obs = obs_opts(scale = Normal(mean = 0.1, sd = 0.01)), gp = NULL, horizon = 0 ) diff --git a/man/gp_opts.Rd b/man/gp_opts.Rd index 4b21c2494..3bbe91930 100644 --- a/man/gp_opts.Rd +++ b/man/gp_opts.Rd @@ -11,12 +11,13 @@ gp_opts( ls_sd = 7, ls_min = 0, ls_max = 60, - alpha_mean = 0, - alpha_sd = 0.01, + alpha = Normal(mean = 0, sd = 0.01), kernel = c("matern", "se", "ou", "periodic"), matern_order = 3/2, matern_type, - w0 = 1 + w0 = 1, + alpha_mean, + alpha_sd ) } \arguments{ @@ -45,14 +46,13 @@ process length scale will be used with recommended parameters scale. 
Updated in \code{\link[=create_gp_data]{create_gp_data()}} to be the length of the input data if this is smaller.} -\item{alpha_mean}{Numeric, defaults to 0. The mean of the magnitude parameter -of the Gaussian process kernel. Should be approximately the expected standard -deviation of the Gaussian process (logged Rt in case of the renewal model, -logged infections in case of the nonmechanistic model).} - -\item{alpha_sd}{Numeric, defaults to 0.01. The standard deviation of the -magnitude parameter of the Gaussian process kernel. Can be tuned to adjust -how far alpha is allowed to deviate form its prior mean (\code{alpha_mean}).} +\item{alpha}{A \verb{<dist_spec>} giving the prior distribution of the magnitude +parameter of the Gaussian process kernel. Should be approximately the +expected standard deviation of the Gaussian process (logged Rt in case of +the renewal model, logged infections in case of the nonmechanistic model). +Defaults to a half-normal distribution with mean 0 and sd 0.01: +\code{Normal(mean = 0, sd = 0.01)} (a lower limit of 0 will be enforced +automatically to ensure positivity)} \item{kernel}{Character string, the type of kernel required. Currently supporting the Matern kernel ("matern"), squared exponential kernel ("se"), @@ -69,6 +69,10 @@ Kernel to use. Currently, the orders 1/2, 3/2, 5/2 and Inf are supported.} \item{w0}{Numeric, defaults to 1.0. Fundamental frequency for periodic kernel. 
They are only used if \code{kernel} is set to "periodic".} + +\item{alpha_mean}{Deprecated; use \code{alpha} instead.} + +\item{alpha_sd}{Deprecated; use \code{alpha} instead.} } \value{ A \verb{<gp_opts>} object of settings defining the Gaussian process diff --git a/man/obs_opts.Rd b/man/obs_opts.Rd index 432f1c3af..044e7a97c 100644 --- a/man/obs_opts.Rd +++ b/man/obs_opts.Rd @@ -6,11 +6,11 @@ \usage{ obs_opts( family = c("negbin", "poisson"), - phi = list(mean = 0, sd = 0.25), + phi = Normal(mean = 0, sd = 0.25), weight = 1, week_effect = TRUE, week_length = 7, - scale = 1, + scale = Fixed(1), na = c("missing", "accumulate"), likelihood = TRUE, return_likelihood = FALSE @@ -20,13 +20,12 @@ obs_opts( \item{family}{Character string defining the observation model. Options are Negative binomial ("negbin"), the default, and Poisson.} -\item{phi}{Overdispersion parameter of the reporting process, used only if -\code{familiy} is "negbin". Can be supplied either as a single numeric value -(fixed overdispersion) or a list with numeric elements mean (\code{mean}) and -standard deviation (\code{sd}) defining a normally distributed prior. -Internally parameterised such that the overdispersion is one over the -square of this prior overdispersion. Defaults to a list with elements -\code{mean = 0} and \code{sd = 0.25}.} +\item{phi}{A \verb{<dist_spec>} specifying a prior on the overdispersion parameter +of the reporting process, used only if \code{family} is "negbin". Internally +parameterised such that the overdispersion is one over the square of this +prior overdispersion phi. Defaults to a half-normal distribution with mean +of 0 and standard deviation of 0.25: \code{Normal(mean = 0, sd = 0.25)}. A lower +limit of zero will be enforced automatically.} \item{weight}{Numeric, defaults to 1. Weight to give the observed data in the log density.} @@ -38,11 +37,12 @@ effect be used in the observation model.} 7 days. 
This can be modified if data aggregated over a period other than a week or if data has a non-weekly periodicity.} -\item{scale}{Scaling factor to be applied to map latent infections (convolved -to date of report). Can be supplied either as a single numeric value (fixed -scale) or a list with numeric elements mean (\code{mean}) and standard deviation -(\code{sd}) defining a normally distributed scaling factor. Defaults to 1, i.e. -no scaling.} +\item{scale}{A \verb{<dist_spec>} specifying a prior on the scaling factor to be +applied to map latent infections (convolved to date of report). Defaults +to a fixed value of 1, i.e. no scaling: \code{Fixed(1)}. A lower limit of zero +will be enforced automatically. If setting to a prior distribution and no +overreporting is expected, it might be sensible to set a maximum of 1 via +the \code{max} option when declaring the distribution.} \item{na}{Deprecated; use the \code{\link[=fill_missing]{fill_missing()}} function instead} @@ -68,5 +68,5 @@ obs_opts() obs_opts(week_effect = TRUE) # Scale reported data -obs_opts(scale = list(mean = 0.2, sd = 0.02)) +obs_opts(scale = Normal(mean = 0.2, sd = 0.02)) } diff --git a/man/regional_epinow.Rd b/man/regional_epinow.Rd index 4a1c754f7..6eca0c83a 100644 --- a/man/regional_epinow.Rd +++ b/man/regional_epinow.Rd @@ -156,7 +156,7 @@ def <- regional_epinow( data = cases, generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.2)), stan = stan_opts( samples = 100, warmup = 200 ), diff --git a/man/rt_opts.Rd b/man/rt_opts.Rd index 24774d891..c48cfe21e 100644 --- a/man/rt_opts.Rd +++ b/man/rt_opts.Rd @@ -5,7 +5,7 @@ \title{Time-Varying Reproduction Number Options} \usage{ rt_opts( - prior = list(mean = 1, sd = 1), + prior = LogNormal(mean = 1, sd = 1), use_rt = TRUE, rw = 0, use_breakpoints = TRUE, @@ -15,9 +15,10 @@ rt_opts( ) } 
\arguments{ -\item{prior}{List containing named numeric elements "mean" and "sd". The -mean and standard deviation of the log normal Rt prior. Defaults to mean of -1 and standard deviation of 1.} +\item{prior}{A \verb{<dist_spec>} giving the prior of the initial reproduction +number. Ignored if \code{use_rt} is \code{FALSE}. Defaults to a LogNormal distribution +with mean of 1 and standard deviation of 1: \code{LogNormal(mean = 1, sd = 1)}. +A lower limit of 0 will be enforced automatically.} \item{use_rt}{Logical, defaults to \code{TRUE}. Should Rt be used to generate infections and hence reported cases.} @@ -69,7 +70,7 @@ defaults. rt_opts() # add a custom length scale -rt_opts(prior = list(mean = 2, sd = 1)) +rt_opts(prior = LogNormal(mean = 2, sd = 1)) # add a weekly random walk rt_opts(rw = 7) diff --git a/tests/testthat/test-create_gp_data.R b/tests/testthat/test-create_gp_data.R index b1f7cc765..d33dccd76 100644 --- a/tests/testthat/test-create_gp_data.R +++ b/tests/testthat/test-create_gp_data.R @@ -11,7 +11,7 @@ test_that("create_gp_data returns correct default values when GP is disabled", { expect_equal(gp_data$ls_sdlog, convert_to_logsd(21, 7)) expect_equal(gp_data$ls_min, 0) expect_equal(gp_data$ls_max, 3.54, tolerance = 0.01) - expect_equal(gp_data$alpha_sd, 0.01) + expect_equal(gp_data$alpha, NULL) expect_equal(gp_data$gp_type, 2) # Default to Matern expect_equal(gp_data$nu, 3 / 2) expect_equal(gp_data$w0, 1.0) diff --git a/tests/testthat/test-create_obs_model.R b/tests/testthat/test-create_obs_model.R index a9115e261..8a767f6d2 100644 --- a/tests/testthat/test-create_obs_model.R +++ b/tests/testthat/test-create_obs_model.R @@ -3,10 +3,9 @@ dates <- seq(as.Date("2020-03-15"), by = "days", length.out = 15) test_that("create_obs_model works with default settings", { obs <- create_obs_model(dates = dates) - expect_equal(length(obs), 11) + expect_equal(length(obs), 7) expect_equal(names(obs), c( - "model_type", "phi_mean", "phi_sd", "week_effect", "obs_weight", - 
"obs_scale", "obs_scale_mean", "obs_scale_sd", + "model_type", "week_effect", "obs_weight", "obs_scale", "likelihood", "return_likelihood", "day_of_week" )) expect_equal(obs$model_type, 1) @@ -15,8 +14,6 @@ test_that("create_obs_model works with default settings", { expect_equal(obs$likelihood, 1) expect_equal(obs$return_likelihood, 0) expect_equal(obs$day_of_week, c(7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7)) - expect_equal(obs$obs_scale_mean, 1) - expect_equal(obs$obs_scale_sd, 0) }) test_that("create_obs_model can be used with a Poisson model", { @@ -24,24 +21,6 @@ test_that("create_obs_model can be used with a Poisson model", { expect_equal(obs$model_type, 0) }) -test_that("create_obs_model can be used with a scaling", { - obs <- create_obs_model( - dates = dates, - obs = obs_opts(scale = list(mean = 0.4, sd = 0.01)) - ) - expect_equal(obs$obs_scale_mean, 0.4) - expect_equal(obs$obs_scale_sd, 0.01) -}) - -test_that("create_obs_model can be used with fixed scaling", { - obs <- create_obs_model( - dates = dates, - obs = obs_opts(scale = 0.4) - ) - expect_equal(obs$obs_scale_mean, 0.4) - expect_equal(obs$obs_scale_sd, 0) -}) - test_that("create_obs_model can be used with no week effect", { obs <- create_obs_model(dates = dates, obs = obs_opts(week_effect = FALSE)) expect_equal(obs$week_effect, 1) @@ -52,18 +31,3 @@ test_that("create_obs_model can be used with a custom week length", { obs <- create_obs_model(dates = dates, obs = obs_opts(week_length = 3)) expect_equal(obs$day_of_week, c(3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2)) }) - -test_that("create_obs_model can be used with a user set phi", { - obs <- create_obs_model( - dates = dates, obs = obs_opts(phi = list(mean = 10, sd = 0.1)) - ) - expect_equal(obs$phi_mean, 10) - expect_equal(obs$phi_sd, 0.1) - obs <- create_obs_model( - dates = dates, - obs = obs_opts(phi = 0.5) - ) - expect_equal(obs$phi_mean, 0.5) - expect_equal(obs$phi_sd, 0) - expect_error(obs_opts(phi = c("Hi", "World"))) -}) diff --git 
a/tests/testthat/test-create_rt_date.R b/tests/testthat/test-create_rt_date.R index 748ae80d4..3fc47b1cc 100644 --- a/tests/testthat/test-create_rt_date.R +++ b/tests/testthat/test-create_rt_date.R @@ -2,8 +2,6 @@ test_that("create_rt_data returns expected default values", { result <- create_rt_data() expect_type(result, "list") - expect_equal(result$r_mean, 1) - expect_equal(result$r_sd, 1) expect_equal(result$estimate_r, 1) expect_equal(result$bp_n, 0) expect_equal(result$breakpoints, numeric(0)) @@ -24,7 +22,6 @@ test_that("create_rt_data handles NULL rt input correctly", { test_that("create_rt_data handles custom rt_opts correctly", { custom_rt <- rt_opts( - prior = list(mean = 2, sd = 0.5), use_rt = FALSE, rw = 0, use_breakpoints = FALSE, @@ -35,8 +32,6 @@ test_that("create_rt_data handles custom rt_opts correctly", { result <- create_rt_data(rt = custom_rt, horizon = 7) - expect_equal(result$r_mean, 2) - expect_equal(result$r_sd, 0.5) expect_equal(result$estimate_r, 0) expect_equal(result$pop, 1000000) expect_equal(result$stationary, 1) diff --git a/tests/testthat/test-create_stan_params.R b/tests/testthat/test-create_stan_params.R new file mode 100644 index 000000000..f0b8a7b4b --- /dev/null +++ b/tests/testthat/test-create_stan_params.R @@ -0,0 +1,53 @@ +test_that("create_stan_params can be used with a scaling", { + obs <- obs_opts(scale = Normal(mean = 0.4, sd = 0.01)) + params <- create_stan_params( + frac_obs = obs$scale, lower_bounds = c(frac_obs = 0) + ) + expect_equal(params$prior_dist, array(2L)) + expect_equal(params$prior_dist_params, array(c(0.4, 0.01))) + expect_equal(params$params_lower, array(0)) + expect_equal(params$frac_obs_id, 1L) +}) + +test_that("create_stan_params can be used with fixed scaling", { + obs <- obs_opts(scale = Fixed(0.4)) + params <- create_stan_params( + frac_obs = obs$scale + ) + expect_equal(params$params_value, array(0.4)) + expect_equal(length(params$prior_dist_params), 0L) +}) + +test_that("create_stan_params can be 
used with a user set phi", { + obs <- obs_opts( + phi = Normal(mean = 10, sd = 0.1) + ) + params <- create_stan_params( + phi = obs$phi + ) + expect_equal(params$prior_dist, array(2L)) + expect_equal(params$prior_dist_params, array(c(10, 0.1))) + expect_equal(params$phi_id, 1L) +}) + +test_that("create_stan_params can be used with fixed phi", { + obs <- obs_opts(phi = Fixed(0.5)) + params <- create_stan_params( + phi = obs$phi + ) + expect_equal(params$params_value, array(0.5)) + expect_equal(length(params$prior_dist_params), 0L) +}) + +test_that("create_stan_params can be used with NULL parameters", { + params <- create_stan_params( + test = NULL + ) + expect_equal(params$test_id, 0) +}) + +test_that("create_stan_params warns about uncertain parameters", { + expect_error(create_stan_params( + test = Normal(mean = 0, sd = Normal(1, 1)) + ), "cannot have uncertain parameters") +}) diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index dfb9277a6..669679507 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -23,14 +23,16 @@ inc_cases[ # fit model to example data specifying a weak prior for fraction reported # with a secondary case inc <- estimate_secondary(inc_cases[1:60], - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts( + scale = Normal(mean = 0.2, sd = 0.2, max = 1), week_effect = FALSE + ), verbose = FALSE ) # extract posterior variables of interest params <- c( "meanlog" = "delay_params[1]", "sdlog" = "delay_params[2]", - "scaling" = "frac_obs[1]" + "scaling" = "params[1]" ) inc_posterior <- inc$posterior[variable %in% params] @@ -58,7 +60,7 @@ prev <- estimate_secondary(prev_cases[1:100], secondary = secondary_opts(type = "prevalence"), obs = obs_opts( week_effect = FALSE, - scale = list(mean = 0.4, sd = 0.1) + scale = Normal(mean = 0.4, sd = 0.1) ), verbose = FALSE ) @@ -90,7 +92,7 @@ test_that("estimate_secondary 
successfully returns estimates when passed NA valu delays = delay_opts( LogNormal(meanlog = 1.8, sdlog = 0.5, max = 30) ), - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), verbose = FALSE ) prev_cases_na <- data.table::copy(prev_cases) @@ -100,7 +102,7 @@ test_that("estimate_secondary successfully returns estimates when passed NA valu delays = delay_opts( LogNormal(mean = 1.8, sd = 0.5, max = 30) ), - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), verbose = FALSE ) expect_true(is.list(inc_na$data)) @@ -125,7 +127,7 @@ test_that("estimate_secondary successfully returns estimates when accumulating t ) ), obs = obs_opts( - scale = list(mean = 0.4, sd = 0.05), week_effect = FALSE + scale = Normal(mean = 0.4, sd = 0.05), week_effect = FALSE ), verbose = FALSE ) expect_true(is.list(inc_weekly$data)) @@ -133,7 +135,7 @@ test_that("estimate_secondary successfully returns estimates when accumulating t test_that("estimate_secondary works when only estimating scaling", { inc <- estimate_secondary(inc_cases[1:60], - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), delay = delay_opts(), verbose = FALSE ) @@ -162,7 +164,7 @@ test_that("estimate_secondary can recover simulated parameters with the skip_on_os("windows") output <- capture.output(suppressMessages(suppressWarnings( inc_cmdstanr <- estimate_secondary(inc_cases[1:60], - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), verbose = FALSE, stan = stan_opts(backend = "cmdstanr") ) ))) @@ -215,7 +217,7 @@ test_that("estimate_secondary works with weigh_delay_priors = TRUE", { ) inc_weigh <- estimate_secondary( inc_cases[1:60], 
delays = delay_opts(delays), - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), weigh_delay_priors = TRUE, verbose = FALSE ) expect_s3_class(inc_weigh, "estimate_secondary") @@ -225,7 +227,7 @@ test_that("estimate_secondary works with filter_leading_zeros set", { modified_data <- inc_cases[1:10, secondary := 0] out <- estimate_secondary( modified_data, - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), filter_leading_zeros = TRUE, verbose = FALSE @@ -239,7 +241,7 @@ test_that("estimate_secondary works with zero_threshold set", { modified_data <- inc_cases[sample(1:30, 10), primary := 0] out <- estimate_secondary( modified_data, - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), zero_threshold = 10, verbose = FALSE diff --git a/tests/testthat/test-gp_opts.R b/tests/testthat/test-gp_opts.R index cd848b75c..47e4f6186 100644 --- a/tests/testthat/test-gp_opts.R +++ b/tests/testthat/test-gp_opts.R @@ -6,7 +6,7 @@ test_that("gp_opts returns correct default values", { expect_equal(gp$ls_sd, 7) expect_equal(gp$ls_min, 0) expect_equal(gp$ls_max, 60) - expect_equal(gp$alpha_sd, 0.01) + expect_equal(gp$alpha, Normal(0, 0.01)) expect_equal(gp$kernel, "matern") expect_equal(gp$matern_order, 3 / 2) expect_equal(gp$w0, 1.0) diff --git a/tests/testthat/test-obs_opts.R b/tests/testthat/test-obs_opts.R index a90b3e6fe..48b986e3b 100644 --- a/tests/testthat/test-obs_opts.R +++ b/tests/testthat/test-obs_opts.R @@ -6,7 +6,7 @@ test_that("obs_opts returns expected default values", { expect_equal(result$weight, 1) expect_true(result$week_effect) expect_equal(result$week_length, 7L) - expect_equal(result$scale, list(mean = 1, sd = 0)) + expect_equal(result$scale, Normal(mean = 1, sd = 0)) expect_equal(result$accumulate, 0) 
expect_true(result$likelihood) expect_false(result$return_likelihood) diff --git a/tests/testthat/test-rt_opts.R b/tests/testthat/test-rt_opts.R index 0a39be027..69fb1183c 100644 --- a/tests/testthat/test-rt_opts.R +++ b/tests/testthat/test-rt_opts.R @@ -2,7 +2,7 @@ test_that("rt_opts returns expected default values", { result <- rt_opts() expect_s3_class(result, "rt_opts") - expect_equal(result$prior, list(mean = 1, sd = 1)) + expect_equal(result$prior, LogNormal(mean = 1, sd = 1)) expect_true(result$use_rt) expect_equal(result$rw, 0) expect_true(result$use_breakpoints) @@ -12,17 +12,17 @@ test_that("rt_opts returns expected default values", { }) test_that("rt_opts handles custom inputs correctly", { - result <- rt_opts( - prior = list(mean = 2, sd = 0.5), + result <- suppressWarnings(rt_opts( + prior = LogNormal(mean = 2, sd = 0.5), use_rt = FALSE, rw = 7, use_breakpoints = FALSE, future = "project", gp_on = "R0", pop = 1000000 - ) + )) - expect_equal(result$prior, list(mean = 2, sd = 0.5)) + expect_null(result$prior) expect_false(result$use_rt) expect_equal(result$rw, 7) expect_true(result$use_breakpoints) # Should be TRUE when rw > 0 @@ -37,10 +37,15 @@ test_that("rt_opts sets use_breakpoints to TRUE when rw > 0", { }) test_that("rt_opts throws error for invalid prior", { - expect_error(rt_opts(prior = list(mean = 1)), - "must have both") - expect_error(rt_opts(prior = list(sd = 1)), - "must have both") + ## deprecated + expect_error( + suppressWarnings(rt_opts(prior = list(mean = 1))), + "must have both" + ) + expect_error( + suppressWarnings(rt_opts(prior = list(sd = 1))), + "must have both" + ) }) test_that("rt_opts validates gp_on argument", { diff --git a/tests/testthat/test-simulate-infections.R b/tests/testthat/test-simulate-infections.R index 0314806de..3c977d689 100644 --- a/tests/testthat/test-simulate-infections.R +++ b/tests/testthat/test-simulate-infections.R @@ -30,7 +30,7 @@ test_that("simulate_infections works as expected with additional 
parameters", { sim <- test_simulate_infections( generation_time = gt_opts(fix_parameters(example_generation_time)), delays = delay_opts(fix_parameters(example_reporting_delay)), - obs = obs_opts(family = "negbin", phi = list(mean = 0.5, sd = 0)), + obs = obs_opts(family = "negbin", phi = Normal(mean = 0.5, sd = 0)), seeding_time = 10 ) expect_equal(nrow(sim), 2 * nrow(R)) @@ -49,7 +49,7 @@ test_that("simulate_infections fails with uncertain parameters", { expect_error( test_simulate_infections( generation_time = gt_opts(Fixed(1)), - obs = obs_opts(scale = list(mean = 1, sd = 1)) + obs = obs_opts(scale = Normal(mean = 1, sd = 1)) ), "uncertain" ) diff --git a/tests/testthat/test-simulate-secondary.R b/tests/testthat/test-simulate-secondary.R index f78c91de7..d30e7bcf3 100644 --- a/tests/testthat/test-simulate-secondary.R +++ b/tests/testthat/test-simulate-secondary.R @@ -22,7 +22,7 @@ test_that("simulate_secondary works as expected with additional parameters", { set.seed(123) sim <- test_simulate_secondary( delays = delay_opts(fix_parameters(example_reporting_delay)), - obs = obs_opts(family = "negbin", phi = list(mean = 0.5, sd = 0)) + obs = obs_opts(family = "negbin", phi = Fixed(0.5)) ) expect_equal(nrow(sim), nrow(cases)) expect_snapshot_output(sim) @@ -36,7 +36,7 @@ test_that("simulate_secondary fails with uncertain parameters", { ) expect_error( test_simulate_secondary( - obs = obs_opts(scale = list(mean = 1, sd = 1)) + obs = obs_opts(scale = Normal(mean = 1, sd = 1)) ), "uncertain" ) diff --git a/tests/testthat/test-stan-rt.R b/tests/testthat/test-stan-rt.R index 1b4c40153..7d5fcb17b 100644 --- a/tests/testthat/test-stan-rt.R +++ b/tests/testthat/test-stan-rt.R @@ -4,57 +4,57 @@ skip_on_os("windows") # Test update_Rt test_that("update_Rt works to produce multiple Rt estimates with a static gaussian process", { expect_equal( - update_Rt(10, log(1.2), rep(0, 9), rep(10, 0), numeric(0), 0), + update_Rt(10, 1.2, rep(0, 9), rep(10, 0), numeric(0), 0), rep(1.2, 10) 
) }) test_that("update_Rt works to produce multiple Rt estimates with a non-static gaussian process", { expect_equal( - round(update_Rt(10, log(1.2), rep(0.1, 9), rep(10, 0), numeric(0), 0), 2), + round(update_Rt(10, 1.2, rep(0.1, 9), rep(10, 0), numeric(0), 0), 2), c(1.20, 1.33, 1.47, 1.62, 1.79, 1.98, 2.19, 2.42, 2.67, 2.95) ) }) test_that("update_Rt works to produce multiple Rt estimates with a non-static stationary gaussian process", { expect_equal( - round(update_Rt(10, log(1.2), rep(0.1, 10), rep(10, 0), numeric(0), 1), 3), + round(update_Rt(10, 1.2, rep(0.1, 10), rep(10, 0), numeric(0), 1), 3), c(1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326) ) }) test_that("update_Rt works when Rt is fixed", { expect_equal( - round(update_Rt(10, log(1.2), numeric(0), rep(10, 0), numeric(0), 0), 2), + round(update_Rt(10, 1.2, numeric(0), rep(10, 0), numeric(0), 0), 2), rep(1.2, 10) ) expect_equal( - round(update_Rt(10, log(1.2), numeric(0), rep(10, 0), numeric(0), 1), 2), + round(update_Rt(10, 1.2, numeric(0), rep(10, 0), numeric(0), 1), 2), rep(1.2, 10) ) }) test_that("update_Rt works when Rt is fixed but a breakpoint is present", { expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 0), 2), + round(update_Rt(5, 1.2, numeric(0), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 1), 2), + round(update_Rt(5, 1.2, numeric(0), c(1, 1, 2, 2, 2), 0.1, 1), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(1, 2, 3, 3, 3), rep(0.1, 2), 0), 2), + round(update_Rt(5, 1.2, numeric(0), c(1, 2, 3, 3, 3), rep(0.1, 2), 0), 2), c(1.2, 1.33, rep(1.47, 3)) ) }) test_that("update_Rt works when Rt is variable and a breakpoint is present", { expect_equal( - round(update_Rt(5, log(1.2), rep(0, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), + round(update_Rt(5, 1.2, rep(0, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.2, 1.2, rep(1.33, 3)) 
) expect_equal( - round(update_Rt(5, log(1.2), rep(0, 5), c(1, 1, 2, 2, 2), 0.1, 1), 2), + round(update_Rt(5, 1.2, rep(0, 5), c(1, 1, 2, 2, 2), 0.1, 1), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), + round(update_Rt(5, 1.2, rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.20, 1.33, 1.62, 1.79, 1.98) ) }) diff --git a/vignettes/EpiNow2.Rmd.orig b/vignettes/EpiNow2.Rmd.orig index 0b27fa285..69d10bcdd 100644 --- a/vignettes/EpiNow2.Rmd.orig +++ b/vignettes/EpiNow2.Rmd.orig @@ -94,7 +94,7 @@ estimates <- epinow( data = reported_cases, generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.2)), stan = stan_opts(cores = 4, control = list(adapt_delta = 0.99)), verbose = interactive() ) @@ -148,7 +148,7 @@ estimates <- regional_epinow( data = reported_cases, generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.2), rw = 7), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.2), rw = 7), gp = NULL, stan = stan_opts(cores = 4, warmup = 250, samples = 1000) ) diff --git a/vignettes/epinow.Rmd.orig b/vignettes/epinow.Rmd.orig index e8da3777f..d8960bfbd 100644 --- a/vignettes/epinow.Rmd.orig +++ b/vignettes/epinow.Rmd.orig @@ -40,7 +40,7 @@ options(mc.cores = 4) reported_cases <- example_confirmed[1:60] reporting_delay <- LogNormal(mean = 2, sd = 1, max = 10) delay <- example_incubation_period + reporting_delay -rt_prior <- list(mean = 2, sd = 0.1) +rt_prior <- LogNormal(mean = 2, sd = 0.1) ``` We can then run the `epinow()` function with the same arguments as `estimate_infections()`. 
diff --git a/vignettes/estimate_infections_options.Rmd.orig b/vignettes/estimate_infections_options.Rmd.orig index bdfaaa138..25706703e 100644 --- a/vignettes/estimate_infections_options.Rmd.orig +++ b/vignettes/estimate_infections_options.Rmd.orig @@ -97,7 +97,7 @@ example_generation_time Lastly we need to choose a prior for the initial value of the reproduction number. This is assumed by the model to be normally distributed and we can set the mean and the standard deviation. We decide to set the mean to 2 and the standard deviation to 1. ```{r initial_r} -rt_prior <- list(mean = 2, sd = 0.1) +rt_prior <- LogNormal(mean = 2, sd = 0.1) ``` # Running the model diff --git a/vignettes/estimate_infections_workflow.Rmd.orig b/vignettes/estimate_infections_workflow.Rmd.orig index ca1297472..2c8f48a19 100644 --- a/vignettes/estimate_infections_workflow.Rmd.orig +++ b/vignettes/estimate_infections_workflow.Rmd.orig @@ -196,7 +196,7 @@ In _EpiNow2_ we can specify the proportion of infections that we expect to be ob For example, if we think that 40% (with standard deviation 1%) of infections end up in the data as observations we could specify. ```{r results = 'hide'} -obs_scale <- list(mean = 0.4, sd = 0.01) +obs_scale <- Normal(mean = 0.4, sd = 0.01) obs_opts(scale = obs_scale) ``` @@ -209,7 +209,7 @@ It can be changed using the `rt_opts()` function. For example, if the user believes that at the very start of the data the reproduction number was 2, with uncertainty in this belief represented by a standard deviation of 1, they would use ```{r results = 'hide'} -rt_prior <- list(mean = 2, sd = 1) +rt_prior <- LogNormal(mean = 2, sd = 1) rt_opts(prior = rt_prior) ```