Skip to content

Commit

Permalink
weigh delays separately
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Mar 22, 2024
1 parent e9eab37 commit 0b07718
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 40 deletions.
36 changes: 21 additions & 15 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -685,10 +685,12 @@ create_stan_args <- function(stan = stan_opts(),
##' Create delay variables for stan
##'
##' @param ... Named delay distributions. The names are assigned to IDs
##' @param weight Numeric, weight associated with delay priors; default: 1
##' @param time_points Integer, the number of time points in the data;
##' determines weight associated with weighted delay priors; default: 1
##' @return A list of variables as expected by the stan model
##' @importFrom purrr transpose map flatten
create_stan_delays <- function(..., weight = 1) {
create_stan_delays <- function(..., time_points = 1L) {
delays <- list(...)
## discretise
delays <- map(list(...), discretise)
## convolve where appropriate
Expand All @@ -706,23 +708,23 @@ create_stan_delays <- function(..., weight = 1) {
ids[type_n > 0] <- seq_len(sum(type_n > 0))
names(ids) <- paste(names(type_n), "id", sep = "_")

delays <- flatten(delays)
parametric <- unname(
vapply(delays, function(x) x$distribution != "nonparametric", logical(1))
)
param_length <- unname(vapply(delays[parametric], function(x) {
flat_delays <- flatten(delays)
parametric <- unname(vapply(
flat_delays, function(x) x$distribution != "nonparametric", logical(1)
))
param_length <- unname(vapply(flat_delays[parametric], function(x) {
length(x$parameters)
}, numeric(1)))
nonparam_length <- unname(vapply(delays[!parametric], function(x) {
nonparam_length <- unname(vapply(flat_delays[!parametric], function(x) {
length(x$pmf)
}, numeric(1)))
distributions <- unname(as.character(
map(delays[parametric], ~ .x$distribution)
map(flat_delays[parametric], ~ .x$distribution)
))

## create stan object
ret <- list(
n = length(delays),
n = length(flat_delays),
n_p = sum(parametric),
n_np = sum(!parametric),
types = sum(type_n > 0),
Expand All @@ -738,15 +740,15 @@ create_stan_delays <- function(..., weight = 1) {
ret$types_groups <- array(c(0, cumsum(unname(type_n[type_n > 0]))) + 1)

ret$params_mean <- array(unname(as.numeric(
map(flatten(map(delays[parametric], ~ .x$parameters)), mean)
map(flatten(map(flat_delays[parametric], ~ .x$parameters)), mean)
)))
ret$params_sd <- array(unname(as.numeric(
map(flatten(map(delays[parametric], ~ .x$parameters)), sd_dist)
map(flatten(map(flat_delays[parametric], ~ .x$parameters)), sd_dist)
)))
ret$max <- array(max_delay[parametric])

ret$np_pmf <- array(unname(as.numeric(
flatten(map(delays[!parametric], ~ .x$pmf))
flatten(map(flat_delays[!parametric], ~ .x$pmf))
)))
## get non zero length delay pmf lengths
ret$np_pmf_groups <- array(c(0, cumsum(nonparam_length)) + 1)
Expand All @@ -758,12 +760,16 @@ create_stan_delays <- function(..., weight = 1) {
ret$params_length <- sum(param_length)
## set lower bounds
ret$params_lower <- array(unname(as.numeric(flatten(
map(delays[parametric], function(x) {
map(flat_delays[parametric], function(x) {
lower_bounds(x$distribution)[names(x$parameters)]
})
))))
## assign prior weights
ret$weight <- array(rep(weight, ret$n_p))
weigh_priors <- vapply(
delays[parametric], attr, "weigh_prior", FUN.VALUE = logical(1)
)
ret$weight <- array(rep(1, ret$n_p))
ret$weight[weigh_priors] <- time_points
## assign distribution
ret$dist <- array(match(distributions, c("lognormal", "gamma")) - 1L)

Expand Down
10 changes: 6 additions & 4 deletions R/dist_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,8 @@ discretise <- function(x, silent = TRUE) {
}
}
})
attr(ret, "class") <- c("dist_spec", "list")
## preserve attributes
attributes(ret) <- attributes(x)
return(ret)
}
#' @rdname discretise
Expand Down Expand Up @@ -543,7 +544,7 @@ apply_tolerance <- function(x, tolerance) {
if (!is(x, "dist_spec")) {
stop("Can only apply tolerance to distributions in a <dist_spec>.")
}
x <- lapply(x, function(x) {
y <- lapply(x, function(x) {
if (x$distribution == "nonparametric") {
cmf <- cumsum(x$pmf)
new_pmf <- x$pmf[c(TRUE, (1 - cmf[-length(cmf)]) >= tolerance)]
Expand All @@ -554,8 +555,9 @@ apply_tolerance <- function(x, tolerance) {
}
})

attr(x, "class") <- c("dist_spec", "list")
return(x)
## preserve attributes
attributes(y) <- attributes(x)
return(y)
}

#' Prints the parameters of one or more delay distributions
Expand Down
4 changes: 1 addition & 3 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,7 @@ estimate_infections <- function(reported_cases,
gt = generation_time,
delay = delays,
trunc = truncation,
weight = ifelse(
weigh_delay_priors, data$t - data$seeding_time - data$horizon, 1
)
time_points = data$t - data$seeding_time - data$horizon
))

# Set up default settings
Expand Down
4 changes: 2 additions & 2 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ estimate_secondary <- function(reports,
meanlog = Normal(2.5, 0.5),
sdlog = Normal(0.47, 0.25),
max = 30
)
), weigh_prior = FALSE
),
truncation = trunc_opts(),
obs = obs_opts(),
Expand Down Expand Up @@ -198,7 +198,7 @@ estimate_secondary <- function(reports,
data <- c(data, create_stan_delays(
delay = delays,
trunc = truncation,
weight = ifelse(weigh_delay_priors, data$t, 1)
time_points = data$t
))

# observation model data
Expand Down
20 changes: 12 additions & 8 deletions R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,8 @@
#' @param model A compiled stan model to override the default model. May be
#' useful for package developers or those developing extensions.
#'
#' @param weigh_delay_priors Logical. If TRUE, all delay distribution priors
#' will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE (default), no weight will
#' be applied, i.e. delay distributions will be treated as a single
#' parameters.
#' @param weigh_delay_priors Deprecated; use the `weigh_priors` option in
#' [trunc_opts()] instead.
#'
#' @param verbose Logical, should model fitting progress be returned.
#'
Expand Down Expand Up @@ -121,7 +117,15 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
"estimate_truncation(stan)"
)
}
# Validate inputs
if (!missing(weigh_delay_priors)) {
lifecycle::deprecate_warn(
"1.5.0",
"estimate_truncation(weigh_delay_priors)",
"trunc_opts(weigh_prior)",
detail = "This argument will be removed completely in version 2.0.0"
)
}
# Validate inputs
walk(obs, check_reports_valid, model = "estimate_truncation")
assert_class(truncation, "dist_spec")
assert_class(model, "stanfit", null.ok = TRUE)
Expand Down Expand Up @@ -216,7 +220,7 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,

data <- c(data, create_stan_delays(
trunc = truncation,
weight = ifelse(weigh_delay_priors, data$t, 1)
time_points = data$t
))

# initial conditions
Expand Down
30 changes: 24 additions & 6 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
#' @param max deprecated; use `dist` instead
#' @param fixed deprecated; use `dist` instead
#' @param prior_weight deprecated; prior weights are now specified as a
#' model option. Use the `weigh_delay_priors` argument of
#' [estimate_infections()] instead.
#' model option. Use the `weigh_prior` argument instead
#' @param weigh_prior Logical; if TRUE (default), the generation time prior
#' will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE , no weight will be
#' applied, i.e. the generation time distribution will be treated as a single
#' parameter.
#' @inheritParams apply_tolerance
#' @return A `<generation_time_opts>` object summarising the input delay
#' distributions.
Expand All @@ -40,7 +45,8 @@
#' generation_time_opts(example_generation_time)
generation_time_opts <- function(dist = Fixed(1), ...,
disease, source, max = 14, fixed = FALSE,
prior_weight, tolerance = 0.001) {
prior_weight, tolerance = 0.001,
weigh_prior = TRUE) {
deprecated_options_given <- FALSE
dot_options <- list(...)

Expand Down Expand Up @@ -82,7 +88,7 @@ generation_time_opts <- function(dist = Fixed(1), ...,
if (!missing(prior_weight)) {
deprecate_warn(
"1.4.0", "generation_time_opts(prior_weight)",
"estimate_infections(weigh_delay_prior)",
"generation_time_opts(weigh_prior)",
"This argument will be removed in version 2.0.0."
)
}
Expand All @@ -107,6 +113,7 @@ generation_time_opts <- function(dist = Fixed(1), ...,
}
check_stan_delay(dist)
attr(dist, "tolerance") <- tolerance
attr(dist, "weigh_prior") <- weigh_prior
attr(dist, "class") <- c("generation_time_opts", class(dist))
return(dist)
}
Expand Down Expand Up @@ -189,6 +196,7 @@ secondary_opts <- function(type = c("incidence", "prevalence"), ...) {
#' @param ... deprecated; use `dist` instead
#' @param fixed deprecated; use `dist` instead
#' @inheritParams apply_tolerance
#' @inheritParams generation_time_opts
#' @return A `<delay_opts>` object summarising the input delay distributions.
#' @seealso [convert_to_logmean()] [convert_to_logsd()]
#' [bootstrapped_dist_fit()] [dist_spec()]
Expand All @@ -207,7 +215,8 @@ secondary_opts <- function(type = c("incidence", "prevalence"), ...) {
#'
#' # Multiple delays (in this case twice the same)
#' delay_opts(delay + delay)
delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) {
delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001,
weigh_prior = TRUE) {
dot_options <- list(...)
if (!is(dist, "dist_spec")) { ## could be old syntax
if (is.list(dist)) {
Expand Down Expand Up @@ -240,6 +249,7 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) {
}
check_stan_delay(dist)
attr(dist, "tolerance") <- tolerance
attr(dist, "weigh_prior") <- weigh_prior
attr(dist, "class") <- c("delay_opts", class(dist))
return(dist)
}
Expand All @@ -254,6 +264,12 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) {
#' @param dist A delay distribution or series of delay distributions reflecting
#' the truncation generated using [dist_spec()] or [estimate_truncation()].
#' Default is fixed distribution with maximum 0, i.e. no truncation
#' @param weigh_prior Logical; if TRUE, the truncation prior will be weighted
#' by the number of observation data points, in doing so approximately placing
#' an independent prior at each time step and usually preventing the
#' posteriors from shifting. If FALSE (default), no weight will be applied,
#' i.e. the truncation distribution will be treated as a single parameter.
#'
#' @inheritParams apply_tolerance
#' @return A `<trunc_opts>` object summarising the input truncation
#' distribution.
Expand All @@ -267,7 +283,8 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) {
#'
#' # truncation dist
#' trunc_opts(dist = LogNormal(mean = 3, sd = 2, max = 10))
trunc_opts <- function(dist = Fixed(0), tolerance = 0.001) {
trunc_opts <- function(dist = Fixed(0), tolerance = 0.001,
weigh_prior = FALSE) {
if (!is(dist, "dist_spec")) {
if (is.list(dist)) {
dist <- do.call(dist_spec, dist)
Expand All @@ -285,6 +302,7 @@ trunc_opts <- function(dist = Fixed(0), tolerance = 0.001) {
}
check_stan_delay(dist)
attr(dist, "tolerance") <- tolerance
attr(dist, "weigh_prior") <- weigh_prior
attr(dist, "class") <- c("trunc_opts", class(dist))
return(dist)
}
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-delays.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ test_stan_delays <- function(generation_time = generation_time_opts(Fixed(1)),
generation_time = generation_time,
delays = delays,
truncation = truncation,
weight = 10
time_points = 10
)
return(unlist(unname(data[params])))
}
Expand Down
4 changes: 3 additions & 1 deletion tests/testthat/test-epinow.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ test_that("epinow runs without error when saving to disk", {
test_that("epinow can produce partial output as specified", {
out <- suppressWarnings(epinow(
reported_cases = reported_cases,
generation_time = generation_time_opts(example_generation_time),
generation_time = generation_time_opts(
example_generation_time, weigh_prior = FALSE
),
delays = delay_opts(example_incubation_period + reporting_delay),
stan = stan_opts(
samples = 25, warmup = 25,
Expand Down

0 comments on commit 0b07718

Please sign in to comment.