diff --git a/R/create.R b/R/create.R index bf54f3612..c8c299503 100644 --- a/R/create.R +++ b/R/create.R @@ -685,10 +685,12 @@ create_stan_args <- function(stan = stan_opts(), ##' Create delay variables for stan ##' ##' @param ... Named delay distributions. The names are assigned to IDs -##' @param weight Numeric, weight associated with delay priors; default: 1 +##' @param time_points Integer, the number of time points in the data; +##' determines weight associated with weighted delay priors; default: 1 ##' @return A list of variables as expected by the stan model ##' @importFrom purrr transpose map flatten -create_stan_delays <- function(..., weight = 1) { +create_stan_delays <- function(..., time_points = 1L) { + delays <- list(...) ## discretise delays <- map(list(...), discretise) ## convolve where appropriate @@ -706,23 +708,23 @@ create_stan_delays <- function(..., weight = 1) { ids[type_n > 0] <- seq_len(sum(type_n > 0)) names(ids) <- paste(names(type_n), "id", sep = "_") - delays <- flatten(delays) - parametric <- unname( - vapply(delays, function(x) x$distribution != "nonparametric", logical(1)) - ) - param_length <- unname(vapply(delays[parametric], function(x) { + flat_delays <- flatten(delays) + parametric <- unname(vapply( + flat_delays, function(x) x$distribution != "nonparametric", logical(1) + )) + param_length <- unname(vapply(flat_delays[parametric], function(x) { length(x$parameters) }, numeric(1))) - nonparam_length <- unname(vapply(delays[!parametric], function(x) { + nonparam_length <- unname(vapply(flat_delays[!parametric], function(x) { length(x$pmf) }, numeric(1))) distributions <- unname(as.character( - map(delays[parametric], ~ .x$distribution) + map(flat_delays[parametric], ~ .x$distribution) )) ## create stan object ret <- list( - n = length(delays), + n = length(flat_delays), n_p = sum(parametric), n_np = sum(!parametric), types = sum(type_n > 0), @@ -738,15 +740,15 @@ create_stan_delays <- function(..., weight = 1) { ret$types_groups <- array(c(0, cumsum(unname(type_n[type_n > 0]))) + 1) ret$params_mean <- array(unname(as.numeric( - map(flatten(map(delays[parametric], ~ .x$parameters)), mean) + map(flatten(map(flat_delays[parametric], ~ .x$parameters)), mean) ))) ret$params_sd <- array(unname(as.numeric( - map(flatten(map(delays[parametric], ~ .x$parameters)), sd_dist) + map(flatten(map(flat_delays[parametric], ~ .x$parameters)), sd_dist) ))) ret$max <- array(max_delay[parametric]) ret$np_pmf <- array(unname(as.numeric( - flatten(map(delays[!parametric], ~ .x$pmf)) + flatten(map(flat_delays[!parametric], ~ .x$pmf)) ))) ## get non zero length delay pmf lengths ret$np_pmf_groups <- array(c(0, cumsum(nonparam_length)) + 1) @@ -758,12 +760,16 @@ create_stan_delays <- function(..., weight = 1) { ret$params_length <- sum(param_length) ## set lower bounds ret$params_lower <- array(unname(as.numeric(flatten( - map(delays[parametric], function(x) { + map(flat_delays[parametric], function(x) { lower_bounds(x$distribution)[names(x$parameters)] }) )))) ## assign prior weights - ret$weight <- array(rep(weight, ret$n_p)) + weigh_priors <- vapply( + delays[parametric], attr, "weigh_prior", FUN.VALUE = logical(1) + ) + ret$weight <- array(rep(1, ret$n_p)) + ret$weight[weigh_priors] <- time_points ## assign distribution ret$dist <- array(match(distributions, c("lognormal", "gamma")) - 1L) diff --git a/R/dist_spec.R b/R/dist_spec.R index a1d1645b8..45526e0b5 100644 --- a/R/dist_spec.R +++ b/R/dist_spec.R @@ -472,7 +472,8 @@ discretise <- function(x, silent = TRUE) { } } }) - attr(ret, "class") <- c("dist_spec", "list") + ## preserve attributes + attributes(ret) <- attributes(x) return(ret) } #' @rdname discretise @@ -543,7 +544,7 @@ apply_tolerance <- function(x, tolerance) { if (!is(x, "dist_spec")) { stop("Can only apply tolerance to distributions in a .") } - x <- lapply(x, function(x) { + y <- lapply(x, function(x) { if (x$distribution == "nonparametric") { cmf <- cumsum(x$pmf) new_pmf <- x$pmf[c(TRUE, (1 - cmf[-length(cmf)]) >= tolerance)] @@ -554,8 +555,9 @@ apply_tolerance <- function(x, tolerance) { } }) - attr(x, "class") <- c("dist_spec", "list") - return(x) + ## preserve attributes + attributes(y) <- attributes(x) + return(y) } #' Prints the parameters of one or more delay distributions diff --git a/R/estimate_infections.R b/R/estimate_infections.R index b69390ed9..44345c9ca 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -203,9 +203,7 @@ estimate_infections <- function(reported_cases, gt = generation_time, delay = delays, trunc = truncation, - weight = ifelse( - weigh_delay_priors, data$t - data$seeding_time - data$horizon, 1 - ) + time_points = data$t - data$seeding_time - data$horizon )) # Set up default settings diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 7c02611d9..6eb9658cc 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -140,7 +140,7 @@ estimate_secondary <- function(reports, meanlog = Normal(2.5, 0.5), sdlog = Normal(0.47, 0.25), max = 30 - ) + ), weigh_prior = FALSE ), truncation = trunc_opts(), obs = obs_opts(), @@ -198,7 +198,7 @@ estimate_secondary <- function(reports, data <- c(data, create_stan_delays( delay = delays, trunc = truncation, - weight = ifelse(weigh_delay_priors, data$t, 1) + time_points = data$t )) # observation model data diff --git a/R/estimate_truncation.R b/R/estimate_truncation.R index f42422620..7aedb5d58 100644 --- a/R/estimate_truncation.R +++ b/R/estimate_truncation.R @@ -47,12 +47,8 @@ #' @param model A compiled stan model to override the default model. May be #' useful for package developers or those developing extensions. #' -#' @param weigh_delay_priors Logical. If TRUE, all delay distribution priors -#' will be weighted by the number of observation data points, in doing so -#' approximately placing an independent prior at each time step and usually -#' preventing the posteriors from shifting. If FALSE (default), no weight will -#' be applied, i.e. delay distributions will be treated as a single -#' parameters. +#' @param weigh_delay_priors Deprecated; use the `weigh_priors` option in +#' [trunc_opts()] instead. #' #' @param verbose Logical, should model fitting progress be returned. #' @@ -121,7 +117,15 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10, "estimate_truncation(stan)" ) } - # Validate inputs + if (!missing(weigh_delay_priors)) { + lifecycle::deprecate_warn( + "1.5.0", + "estimate_truncation(weigh_delay_priors)", + "trunc_opts(weigh_prior)", + detail = "This argument will be removed completely in version 2.0.0" + ) + } + # Validate inputs walk(obs, check_reports_valid, model = "estimate_truncation") assert_class(truncation, "dist_spec") assert_class(model, "stanfit", null.ok = TRUE) @@ -216,7 +220,7 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10, data <- c(data, create_stan_delays( trunc = truncation, - weight = ifelse(weigh_delay_priors, data$t, 1) + time_points = data$t )) # initial conditions diff --git a/R/opts.R b/R/opts.R index 4d0aaae1b..f340e70ae 100644 --- a/R/opts.R +++ b/R/opts.R @@ -12,8 +12,13 @@ #' @param max deprecated; use `dist` instead #' @param fixed deprecated; use `dist` instead #' @param prior_weight deprecated; prior weights are now specified as a -#' model option. Use the `weigh_delay_priors` argument of -#' [estimate_infections()] instead. +#' model option. Use the `weigh_prior` argument instead +#' @param weigh_prior Logical; if TRUE (default), the generation time prior +#' will be weighted by the number of observation data points, in doing so +#' approximately placing an independent prior at each time step and usually +#' preventing the posteriors from shifting. If FALSE , no weight will be +#' applied, i.e. the generation time distribution will be treated as a single +#' parameter. #' @inheritParams apply_tolerance #' @return A `` object summarising the input delay #' distributions. @@ -40,7 +45,8 @@ #' generation_time_opts(example_generation_time) generation_time_opts <- function(dist = Fixed(1), ..., disease, source, max = 14, fixed = FALSE, - prior_weight, tolerance = 0.001) { + prior_weight, tolerance = 0.001, + weigh_prior = TRUE) { deprecated_options_given <- FALSE dot_options <- list(...) @@ -82,7 +88,7 @@ generation_time_opts <- function(dist = Fixed(1), ..., if (!missing(prior_weight)) { deprecate_warn( "1.4.0", "generation_time_opts(prior_weight)", - "estimate_infections(weigh_delay_prior)", + "generation_time_opts(weigh_prior)", "This argument will be removed in version 2.0.0." ) } @@ -107,6 +113,7 @@ generation_time_opts <- function(dist = Fixed(1), ..., } check_stan_delay(dist) attr(dist, "tolerance") <- tolerance + attr(dist, "weigh_prior") <- weigh_prior attr(dist, "class") <- c("generation_time_opts", class(dist)) return(dist) } @@ -189,6 +196,7 @@ secondary_opts <- function(type = c("incidence", "prevalence"), ...) { #' @param ... deprecated; use `dist` instead #' @param fixed deprecated; use `dist` instead #' @inheritParams apply_tolerance +#' @inheritParams generation_time_opts #' @return A `` object summarising the input delay distributions. #' @seealso [convert_to_logmean()] [convert_to_logsd()] #' [bootstrapped_dist_fit()] [dist_spec()] @@ -207,7 +215,8 @@ secondary_opts <- function(type = c("incidence", "prevalence"), ...) { #' #' # Multiple delays (in this case twice the same) #' delay_opts(delay + delay) -delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) { +delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001, + weigh_prior = TRUE) { dot_options <- list(...) if (!is(dist, "dist_spec")) { ## could be old syntax if (is.list(dist)) { @@ -240,6 +249,7 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) { } check_stan_delay(dist) attr(dist, "tolerance") <- tolerance + attr(dist, "weigh_prior") <- weigh_prior attr(dist, "class") <- c("delay_opts", class(dist)) return(dist) } @@ -254,6 +264,12 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) { #' @param dist A delay distribution or series of delay distributions reflecting #' the truncation generated using [dist_spec()] or [estimate_truncation()]. #' Default is fixed distribution with maximum 0, i.e. no truncation +#' @param weigh_prior Logical; if TRUE, the truncation prior will be weighted +#' by the number of observation data points, in doing so approximately placing +#' an independent prior at each time step and usually preventing the +#' posteriors from shifting. If FALSE (default), no weight will be applied, +#' i.e. the truncation distribution will be treated as a single parameter. +#' #' @inheritParams apply_tolerance #' @return A `` object summarising the input truncation #' distribution. @@ -267,7 +283,8 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) { #' #' # truncation dist #' trunc_opts(dist = LogNormal(mean = 3, sd = 2, max = 10)) -trunc_opts <- function(dist = Fixed(0), tolerance = 0.001) { +trunc_opts <- function(dist = Fixed(0), tolerance = 0.001, + weigh_prior = FALSE) { if (!is(dist, "dist_spec")) { if (is.list(dist)) { dist <- do.call(dist_spec, dist) @@ -285,6 +302,7 @@ trunc_opts <- function(dist = Fixed(0), tolerance = 0.001) { } check_stan_delay(dist) attr(dist, "tolerance") <- tolerance + attr(dist, "weigh_prior") <- weigh_prior attr(dist, "class") <- c("trunc_opts", class(dist)) return(dist) } diff --git a/tests/testthat/test-delays.R b/tests/testthat/test-delays.R index 6cb415724..639d584ff 100644 --- a/tests/testthat/test-delays.R +++ b/tests/testthat/test-delays.R @@ -6,7 +6,7 @@ test_stan_delays <- function(generation_time = generation_time_opts(Fixed(1)), generation_time = generation_time, delays = delays, truncation = truncation, - weight = 10 + time_points = 10 ) return(unlist(unname(data[params]))) } diff --git a/tests/testthat/test-epinow.R b/tests/testthat/test-epinow.R index 83dd87c69..f9b146259 100644 --- a/tests/testthat/test-epinow.R +++ b/tests/testthat/test-epinow.R @@ -84,7 +84,9 @@ test_that("epinow runs without error when saving to disk", { test_that("epinow can produce partial output as specified", { out <- suppressWarnings(epinow( reported_cases = reported_cases, - generation_time = generation_time_opts(example_generation_time), + generation_time = generation_time_opts( + example_generation_time, weigh_prior = FALSE + ), delays = delay_opts(example_incubation_period + reporting_delay), stan = stan_opts( samples = 25, warmup = 25,