Skip to content

Commit

Permalink
Change all priors to use <dist_spec> (#871)
Browse files Browse the repository at this point in the history
* new parameter interface in stan code

* adapt R code to new param interface

* render docs

* update tests

* update examples and other code snippets

* add news item

* add progressr to lintr workflow

* switch benchmarks back to previous syntax

otherwise they won't work on main

* change benchmark back

it won't work anyway as the R code has changed too much

* bumping all vignettes

* add reviewer

Co-authored-by: Sam Abbott <[email protected]>

* cherry pick vignettes from main

* make all priors consistent with previous versions

---------

Co-authored-by: Sam <[email protected]>
  • Loading branch information
sbfnk and seabbs authored Dec 6, 2024
1 parent b7be177 commit 9d733d2
Show file tree
Hide file tree
Showing 57 changed files with 809 additions and 450 deletions.
1 change: 1 addition & 0 deletions .github/workflows/lint-only-changed-files.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
any::gh
any::lintr
any::purrr
progressr
- name: Add lintr options
run: |
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method("!=",dist_spec)
S3method("+",dist_spec)
S3method("==",dist_spec)
S3method(c,dist_spec)
S3method(collapse,dist_spec)
S3method(collapse,multi_dist_spec)
Expand Down
5 changes: 3 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
estimate_infections()
```

- A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs.
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs.
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and reviewed by @seabbs.

## Documentation

Expand Down
166 changes: 121 additions & 45 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,6 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,

# map settings to underlying gp stan requirements
rt_data <- list(
r_mean = rt$prior$mean,
r_sd = rt$prior$sd,
estimate_r = as.numeric(rt$use_rt),
bp_n = ifelse(rt$use_breakpoints, max(breakpoints) - 1, 0),
breakpoints = breakpoints,
Expand Down Expand Up @@ -433,8 +431,6 @@ create_gp_data <- function(gp = gp_opts(), data) {
ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd),
ls_min = gp$ls_min,
ls_max = gp$ls_max,
alpha_mean = gp$alpha_mean,
alpha_sd = gp$alpha_sd,
gp_type = data.table::fcase(
gp$kernel == "se", 0,
gp$kernel == "periodic", 1,
Expand Down Expand Up @@ -472,7 +468,7 @@ create_gp_data <- function(gp = gp_opts(), data) {
#'
#' # Applying a observation scaling to the data
#' create_obs_model(
#' obs_opts(scale = list(mean = 0.4, sd = 0.01)), dates = dates
#' obs_opts(scale = Normal(mean = 0.4, sd = 0.01)), dates = dates
#' )
#'
#' # Apply a custom week week length
Expand All @@ -481,13 +477,9 @@ create_gp_data <- function(gp = gp_opts(), data) {
create_obs_model <- function(obs = obs_opts(), dates) {
data <- list(
model_type = as.numeric(obs$family == "negbin"),
phi_mean = obs$phi$mean,
phi_sd = obs$phi$sd,
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
obs_weight = obs$weight,
obs_scale = as.integer(obs$scale$sd > 0 || obs$scale$mean != 1),
obs_scale_mean = obs$scale$mean,
obs_scale_sd = obs$scale$sd,
obs_scale = as.integer(obs$scale != Fixed(1)),
likelihood = as.numeric(obs$likelihood),
return_likelihood = as.numeric(obs$return_likelihood)
)
Expand Down Expand Up @@ -589,15 +581,30 @@ create_stan_data <- function(data, seeding_time,
)
)

# parameters
stan_data <- c(
stan_data,
create_stan_params(
alpha = gp$alpha,
R0 = rt$prior,
frac_obs = obs$scale,
rep_phi = obs$phi,
lower_bounds = c(
alpha = 0,
R0 = 0,
frac_obs = 0,
rep_phi = 0
)
)
)

# rescale mean shifted prior for back calculation if observation scaling is
# used
if (stan_data$obs_scale == 1) {
stan_data$shifted_cases <-
stan_data$shifted_cases / stan_data$obs_scale_mean
stan_data$prior_infections <- log(
exp(stan_data$prior_infections) / stan_data$obs_scale_mean
)
}
stan_data$shifted_cases <-
stan_data$shifted_cases / mean(obs$scale)
stan_data$prior_infections <- log(
exp(stan_data$prior_infections) / mean(obs$scale)
)
return(stan_data)
}

Expand Down Expand Up @@ -647,34 +654,15 @@ create_initial_conditions <- function(data) {
out$rescaled_rho < data$ls_min, data$ls_min + 0.001,
default = out$rescaled_rho
))

out$alpha <- array(
truncnorm::rtruncnorm(
1, a = 0, mean = data$alpha_mean, sd = data$alpha_sd
)
)
} else {
out$eta <- array(numeric(0))
out$rescaled_rho <- array(numeric(0))
out$alpha <- array(numeric(0))
}
if (data$model_type == 1) {
out$rep_phi <- array(
truncnorm::rtruncnorm(
1,
a = 0, mean = data$phi_mean, sd = data$phi_sd
)
)
}
if (data$estimate_r == 1) {
out$initial_infections <- array(rnorm(1, data$prior_infections, 0.2))
if (data$seeding_time > 1) {
out$initial_growth <- array(rnorm(1, data$prior_growth, 0.02))
}
out$log_R <- array(rnorm(
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd)
))
}

if (data$bp_n > 0) {
Expand All @@ -684,20 +672,17 @@ create_initial_conditions <- function(data) {
out$bp_sd <- array(numeric(0))
out$bp_effects <- array(numeric(0))
}
if (data$obs_scale_sd > 0) {
out$frac_obs <- array(truncnorm::rtruncnorm(1,
a = 0, b = 1,
mean = data$obs_scale_mean,
sd = data$obs_scale_sd
))
} else {
out$frac_obs <- array(numeric(0))
}
if (data$week_effect > 0) {
out$day_of_week_simplex <- array(
rep(1 / data$week_effect, data$week_effect)
)
}
out$params <- array(truncnorm::rtruncnorm(
data$n_params_variable,
a = data$params_lower,
b = data$params_upper,
mean = 0, sd = 1
))
return(out)
}
return(init_fun)
Expand Down Expand Up @@ -877,3 +862,94 @@ create_stan_delays <- function(..., time_points = 1L) {

return(ret)
}

##' Create parameters for stan
##'
##' @param ... Named delay distributions. The names are assigned to IDs
##' @param lower_bounds Named vector of lower bounds for any delay(s). The names
##' have to correspond to the names given to the delay distributions passed.
##' If `NULL` (default) no parameters are given a lower bound.
##' @return A list of variables as expected by the stan model
##' @importFrom data.table fcase
##' @keywords internal
create_stan_params <- function(..., lower_bounds = NULL) {
params <- list(...)

## set IDs of any parameters that is NULL to 0 and remove
null_params <- vapply(params, is.null, logical(1))
null_ids <- rep(0, sum(null_params))
if (length(null_ids) > 0) {
names(null_ids) <- paste(names(null_params)[null_params], "id", sep = "_")
params <- params[!null_params]
}

## initialise variables
params_fixed_lookup <- rep(0L, length(params))
params_variable_lookup <- rep(0L, length(params))

## identify fixed/variable parameters
fixed <- vapply(params, get_distribution, character(1)) == "fixed"
params_fixed_lookup[fixed] <- seq_along(which(fixed))
params_variable_lookup[!fixed] <- seq_along(which(!fixed))

## lower bounds
params_lower <- rep(-Inf, length(params[!fixed]))
names(params_lower) <- names(params[!fixed])
lower_bounds <- lower_bounds[names(params_lower)]
params_lower[names(lower_bounds)] <- lower_bounds

## upper bounds
params_upper <- vapply(params[!fixed], max, numeric(1))

## prior distributions
prior_dist_name <- vapply(params[!fixed], get_distribution, character(1))
prior_dist <- fcase(
prior_dist_name == "lognormal", 0L,
prior_dist_name == "gamma", 1L,
prior_dist_name == "normal", 2L
)
## parameters
prior_dist_params <- lapply(params[!fixed], get_parameters)
prior_dist_params_lengths <- lengths(prior_dist_params)

## check none of the parameters are uncertain
prior_uncertain <- vapply(prior_dist_params, function(x) {
!all(vapply(x, is.numeric, logical(1)))
}, logical(1))
if (any(prior_uncertain)) {
uncertain_priors <- names(params[!fixed])[prior_uncertain] # nolint: object_usage_linter
cli_abort(
c(
"!" = "Parameter prior distribution{?s} for {.var {uncertain_priors}}
cannot have uncertain parameters."
)
)
}

prior_dist_params <- unlist(prior_dist_params)
if (is.null(prior_dist_params)) {
prior_dist_params <- numeric(0)
}

## extract distributions and parameters
ret <- list(
n_params_variable = length(params) - sum(fixed),
n_params_fixed = sum(fixed),
params_lower = array(params_lower),
params_upper = array(params_upper),
params_fixed_lookup = array(params_fixed_lookup),
params_variable_lookup = array(params_variable_lookup),
params_value = array(vapply(
params[fixed], \(x) get_parameters(x)$value, numeric(1)
)),
prior_dist = array(prior_dist),
prior_dist_params_length = sum(prior_dist_params_lengths),
prior_dist_params = array(prior_dist_params)
)
ids <- seq_along(params)
if (length(ids) > 0) {
names(ids) <- paste(names(params), "id", sep = "_")
}
ret <- c(ret, as.list(ids), as.list(null_ids))
return(ret)
}
51 changes: 51 additions & 0 deletions R/dist_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,57 @@ discrete_pmf <- function(distribution =
c(e1, e2)
}

##' Compares two delay distributions
##'
##' @param e1 The first delay distribution (of type <dist_spec>) to
##' combine.
##'
##' @param e2 The second delay distribution (of type <dist_spec>) to
##' combine.
##' @method == dist_spec
##' @return TRUE or FALSE
##' @export
##' @examples
##' Fixed(1) == Normal(1, 0.5)
## nolint start: cyclocomp_linter
`==.dist_spec` <- function(e1, e2) {
## both must have same number of distributions
if (ndist(e1) != ndist(e2)) return(FALSE)
## loop over constituent distributions
for (i in seq_len(ndist(e1))) {
## distributions need to be the same
if (get_distribution(e1, i) != get_distribution(e2, i)) return(FALSE)
if (get_distribution(e1, i) == "nonparametric") {
## if nonparametric then PMFs need to be the same
if (!identical(get_pmf(e1, i), get_pmf(e2, i))) return(FALSE)
} else {
## if parametric then all parameters need to be the same
params1 <- get_parameters(e1, i)
params2 <- get_parameters(e2, i)
for (param in names(params1)) {
## all parameters must be the same type
if ((is(params1[[param]], "dist_spec") &&
is(params2[[param]], "dist_spec")) ||
(is.numeric(params1[[param]]) && is.numeric(params2[[param]]))) {
## if parameters are the same type they need to be same value
if (!(params1[[param]] == params2[[param]])) return(FALSE)
} else {
return(FALSE)
}
}
}
}
return(TRUE)
}
## nolint end: cyclocomp_linter

##' @rdname equals-.dist_spec
##' @method != dist_spec
##' @export
`!=.dist_spec` <- function(e1, e2) {
!(e1 == e2)
}

#' Combines multiple delay distributions for further processing
#'
#' @description `r lifecycle::badge("experimental")`
Expand Down
2 changes: 1 addition & 1 deletion R/epinow.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
#' out <- epinow(
#' data = reported_cases,
#' generation_time = gt_opts(generation_time),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)),
#' rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1)),
#' delays = delay_opts(incubation_period + reporting_delay)
#' )
#' # summary of the latest estimates
Expand Down
2 changes: 1 addition & 1 deletion R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
#' def <- estimate_infections(reported_cases,
#' generation_time = gt_opts(generation_time),
#' delays = delay_opts(incubation_period + reporting_delay),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1))
#' rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1))
#' )
#' # real time estimates
#' summary(def)
Expand Down
15 changes: 12 additions & 3 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
#' # fit model to example data specifying a weak prior for fraction reported
#' # with a secondary case
#' inc <- estimate_secondary(cases[1:60],
#' obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE)
#' obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE)
#' )
#' plot(inc, primary = TRUE)
#'
Expand Down Expand Up @@ -129,7 +129,7 @@
#' secondary = secondary_opts(type = "prevalence"),
#' obs = obs_opts(
#' week_effect = FALSE,
#' scale = list(mean = 0.4, sd = 0.1)
#' scale = Normal(mean = 0.4, sd = 0.1)
#' )
#' )
#' plot(prev, primary = TRUE)
Expand Down Expand Up @@ -250,6 +250,15 @@ estimate_secondary <- function(data,
# observation model data
stan_data <- c(stan_data, create_obs_model(obs, dates = reports$date))

stan_data <- c(stan_data, create_stan_params(
frac_obs = obs$scale,
rep_phi = obs$phi,
lower_bounds = c(
frac_obs = 0,
rep_phi = 0
)
))

# update data to use specified priors rather than defaults
stan_data <- update_secondary_args(stan_data,
priors = priors, verbose = verbose
Expand Down Expand Up @@ -674,7 +683,7 @@ forecast_secondary <- function(estimate,

# allocate empty parameters
data <- allocate_empty(
data, c("frac_obs", "delay_params", "rep_phi"),
data, c("params", "delay_params"),
n = data$n
)
data$all_dates <- as.integer(all_dates)
Expand Down
Loading

0 comments on commit 9d733d2

Please sign in to comment.