Skip to content

Commit

Permalink
add forecast_opts() (#901)
Browse files Browse the repository at this point in the history
* add `forecast_opts()`

* update stan model

* render docs

* update tests

* add news item

* Apply suggestions from code review

Co-authored-by: James Azam <[email protected]>

* imputed_time -> imputed_times

* cases_time -> case_times

* bring back gp_opts

---------

Co-authored-by: James Azam <[email protected]>
  • Loading branch information
sbfnk and jamesmbaazam authored Dec 20, 2024
1 parent 89124a7 commit a5da083
Show file tree
Hide file tree
Showing 34 changed files with 360 additions and 113 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ export(filter_leading_zeros)
export(fix_dist)
export(fix_parameters)
export(forecast_infections)
export(forecast_opts)
export(forecast_secondary)
export(gamma_dist_def)
export(generation_time_opts)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

- The internal functions `create_clean_reported_cases()` has been broken up into several functions, with relevant ones `filter_leading_zeros()`, `add_breakpoints()` and `apply_zero_threshold()` exposed to the user. By @sbfnk in #884 and reviewed by @seabbs and @jamesmbaazam.
- The step of estimating early infections and growth in the internal function `create_stan_data()` has been separated into a new internal function `estimate_early_dynamics()`. By @jamesmbaazam in #888 and reviewed by @sbfnk.
- `estimate_infections()` and `epinow()` gain the `forecast` argument for setting the forecast horizon (`horizon`) and accumulation of forecasts. `forecast` is set with the `forecast_opts()` function similar to the other settings arguments. By @sbfnk in #901 and reviewed by @jamesmbaazam.

## Documentation

Expand Down
91 changes: 57 additions & 34 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,19 @@
#' \dontrun{
#' create_clean_reported_cases(example_confirmed, 7)
#' }
create_clean_reported_cases <- function(data, horizon = 0,
create_clean_reported_cases <- function(data,
filter_leading_zeros = TRUE,
zero_threshold = Inf,
fill = NA_integer_,
add_breakpoints = TRUE) {
reported_cases <- add_horizon(data, horizon = horizon)
if (add_breakpoints) {
reported_cases <- add_breakpoints(reported_cases)
data <- add_breakpoints(data)
}
if (filter_leading_zeros) {
reported_cases <- filter_leading_zeros(reported_cases)
data <- filter_leading_zeros(data)
}
reported_cases <- apply_zero_threshold(reported_cases, zero_threshold)
return(reported_cases[])
}

#' Create complete cases
#' @description `r lifecycle::badge("stable")`
#' Creates a complete data set without NA values and appropriate indices
#'
#' @param cases data frame with a column "confirm" that may contain NA values
#'
#' @return A data frame without NA values, with two columns: confirm (number)
#' @importFrom data.table setDT
#' @keywords internal
create_complete_cases <- function(cases) {
cases <- setDT(cases)
cases[, lookup := seq_len(.N)]
cases <- cases[!is.na(cases$confirm)]
return(cases[])
data <- apply_zero_threshold(data, zero_threshold)
return(data[])
}

#' Create Delay Shifted Cases
Expand Down Expand Up @@ -428,6 +411,38 @@ create_obs_model <- function(obs = obs_opts(), dates) {
return(data)
}

##' Create forecast settings
##'
##' @param forecast A list of options as generated by [forecast_opts()] defining
##' the forecast opitions. Defaults to [forecast_opts()]. If NULL then no
##' forecasting will be done.
##' @inheritParams create_stan_data
##' @return A list of settings ready to be passed to stan defining
##' the Observation Model
##' @keywords internal
create_forecast_data <- function(forecast = forecast_opts(), data) {
if (is.null(forecast)) {
forecast <- forecast_opts(horizon = 0)
}
if (forecast$infer_accumulate && any(data$accumulate)) {
accumulation_times <- which(!data$accumulate)
gaps <- unique(diff(accumulation_times))
if (length(gaps) == 1 && gaps > 1) { ## all gaps are the same
forecast$accumulate <- gaps
cli_inform(c(
"i" = "Forecasts accumulated every {gaps} days, same as accumulation
used in the likelihood. To change this behaviour or silence this
message set {.var accumulate} explicitly in {.fn forecast_opts}."
))
}
}
data <- list(
horizon = forecast$horizon,
future_accumulate = forecast$accumulate
)
return(data)
}

#' Calculate prior infections and fit early growth
#'
#' @description Calculates the prior infections and growth rate based on the
Expand Down Expand Up @@ -484,6 +499,7 @@ estimate_early_dynamics <- function(cases, seeding_time) {
#' @inheritParams create_obs_model
#' @inheritParams create_rt_data
#' @inheritParams create_backcalc_data
#' @inheritParams create_forecast_data
#' @importFrom stats lm
#' @importFrom purrr safely
#' @return A list of stan data
Expand All @@ -495,26 +511,33 @@ estimate_early_dynamics <- function(cases, seeding_time) {
#' backcalc_opts(), create_shifted_cases(example_confirmed, 7, 14, 7)
#' )
#' }
create_stan_data <- function(data, seeding_time,
rt, gp, obs, horizon,
backcalc, shifted_cases) {
cases <- data[(seeding_time + 1):(.N - horizon)]
complete_cases <- create_complete_cases(cases)
cases <- cases$confirm
accumulate <- data[-(1:seeding_time)]$accumulate
create_stan_data <- function(data, seeding_time, rt, gp, obs, backcalc,
shifted_cases, forecast) {
cases <- data[(seeding_time + 1):.N]
cases[, lookup := seq_len(.N)]
case_times <- cases[!is.na(confirm), lookup]
imputed_times <- cases[!(accumulate), lookup]
accumulate <- cases$accumulate
confirmed_cases <- cases[1:(.N - forecast$horizon)]$confirm

stan_data <- list(
cases = complete_cases$confirm,
cases_time = complete_cases$lookup,
cases = confirmed_cases[!is.na(confirmed_cases)],
any_accumulate = as.integer(any(accumulate)),
case_times = as.integer(case_times),
imputed_times = as.integer(imputed_times),
accumulate = as.integer(accumulate),
lt = nrow(complete_cases),
lt = length(case_times),
it = length(imputed_times),
shifted_cases = shifted_cases,
t = length(data$date),
horizon = horizon,
burn_in = 0,
seeding_time = seeding_time
)
# add forecast data
stan_data <- c(
stan_data,
create_forecast_data(forecast, cases)
)
# add Rt data
stan_data <- c(
stan_data,
Expand All @@ -526,7 +549,7 @@ create_stan_data <- function(data, seeding_time,
# calculate prior infections and fit early growth
stan_data <- c(
stan_data,
estimate_early_dynamics(cases, seeding_time)
estimate_early_dynamics(confirmed_cases, seeding_time)
)
# backcalculation settings
stan_data <- c(stan_data, create_backcalc_data(backcalc))
Expand Down
2 changes: 2 additions & 0 deletions R/epinow-internal.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#' beyond the target date.
#' @inheritParams setup_target_folder
#' @inheritParams estimate_infections
#' @param horizon Numeric, defaults to 7. Number of days into the future to
#' forecast.
#' @return Numeric forecast horizon adjusted for the users intention
#' @keywords internal
update_horizon <- function(horizon, target_date, data) {
Expand Down
21 changes: 18 additions & 3 deletions R/epinow.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ epinow <- function(data,
backcalc = backcalc_opts(),
gp = gp_opts(),
obs = obs_opts(),
forecast = forecast_opts(),
stan = stan_opts(),
horizon = 7,
horizon,
CrIs = c(0.2, 0.5, 0.9),
filter_leading_zeros = TRUE,
zero_threshold = Inf,
Expand All @@ -105,7 +106,21 @@ epinow <- function(data,
"epinow(data)"
)
}
if (!missing(horizon)) {
lifecycle::deprecate_warn(
"1.7.0",
"epinow(horizon)",
"epinow(forecast)",
details = "The `horizon` argument passed to `epinow()` will
override any `horizon` argument passed via `forecast_opts()`."
)
}
# Check inputs
## deprecated
if (!missing(horizon)) {
assert_numeric(horizon, lower = 0)
forecast$horizon <- horizon
}
assert_logical(return_output)
stopifnot("target_folder is not a directory" =
!is.null(target_folder) || isDirectory(target_folder)
Expand Down Expand Up @@ -174,7 +189,7 @@ epinow <- function(data,
save_input(reported_cases, target_folder)

# make sure the horizon is as specified from the target date --------------
horizon <- update_horizon(horizon, target_date, reported_cases)
horizon <- update_horizon(forecast$horizon, target_date, reported_cases)

# estimate infections and Reproduction no ---------------------------------
estimates <- estimate_infections(
Expand All @@ -186,11 +201,11 @@ epinow <- function(data,
backcalc = backcalc,
gp = gp,
obs = obs,
forecast = forecast,
stan = stan,
CrIs = CrIs,
filter_leading_zeros = filter_leading_zeros,
zero_threshold = zero_threshold,
horizon = horizon,
verbose = verbose,
id = id
)
Expand Down
44 changes: 34 additions & 10 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
#' used as the `truncation` argument here, thereby propagating the uncertainty
#' in the estimate.
#'
#' @param horizon Numeric, defaults to 7. Number of days into the future to
#' forecast.
#' @param horizon Deprecated; use `forecast` instead to specify the predictive
#' horizon
#'
#' @param weigh_delay_priors Logical. If TRUE (default), all delay distribution
#' priors will be weighted by the number of observation data points, in doing so
Expand All @@ -65,7 +65,7 @@
#' [estimate_truncation()]
#' @inheritParams create_stan_args
#' @inheritParams create_stan_data
#' @inheritParams create_stan_data
#' @inheritParams create_forecast_data
#' @inheritParams create_gp_data
#' @inheritParams fit_model_with_nuts
#' @inheritParams create_clean_reported_cases
Expand Down Expand Up @@ -123,8 +123,9 @@ estimate_infections <- function(data,
backcalc = backcalc_opts(),
gp = gp_opts(),
obs = obs_opts(),
forecast = forecast_opts(),
stan = stan_opts(),
horizon = 7,
horizon,
CrIs = c(0.2, 0.5, 0.9),
filter_leading_zeros = TRUE,
zero_threshold = Inf,
Expand Down Expand Up @@ -154,6 +155,15 @@ estimate_infections <- function(data,
"apply_zero_threshold()"
)
}
if (!missing(horizon)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_infections(horizon)",
"estimate_infections(forecast)",
details = "The `horizon` argument passed to `estimate_infections()` will
override any `horizon` argument passed via `forecast_opts()`."
)
}
# Validate inputs
check_reports_valid(data, model = "estimate_infections")
assert_class(generation_time, "generation_time_opts")
Expand All @@ -163,8 +173,13 @@ estimate_infections <- function(data,
assert_class(backcalc, "backcalc_opts")
assert_class(gp, "gp_opts", null.ok = TRUE)
assert_class(obs, "obs_opts")
assert_class(forecast, "forecast_opts", null.ok = TRUE)
assert_class(stan, "stan_opts")
assert_numeric(horizon, lower = 0)
## deprecated
if (!missing(horizon)) {
assert_numeric(horizon, lower = 0)
forecast$horizon <- horizon
}
assert_numeric(CrIs, lower = 0, upper = 1)
assert_logical(filter_leading_zeros)
assert_numeric(zero_threshold, lower = 0)
Expand Down Expand Up @@ -211,9 +226,16 @@ estimate_infections <- function(data,
))
}

## add forecast horizon
if (!is.null(forecast)) {
reported_cases <- add_horizon(
reported_cases, forecast$horizon, forecast$accumulate
)
}

# Create clean and complete cases
reported_cases <- create_clean_reported_cases(
reported_cases, horizon,
reported_cases,
filter_leading_zeros = filter_leading_zeros,
zero_threshold = zero_threshold
)
Expand All @@ -240,7 +262,7 @@ estimate_infections <- function(data,
reported_cases,
seeding_time,
backcalc$prior_window,
horizon
forecast$horizon
)
reported_cases <- reported_cases[-(1:backcalc$prior_window)]

Expand All @@ -253,7 +275,7 @@ estimate_infections <- function(data,
obs = obs,
backcalc = backcalc,
shifted_cases = shifted_cases$confirm,
horizon = horizon
forecast = forecast
)

stan_data <- c(stan_data, create_stan_delays(
Expand Down Expand Up @@ -293,7 +315,9 @@ estimate_infections <- function(data,
# Extract parameters of interest from the fit
out <- extract_parameter_samples(fit, stan_data,
reported_inf_dates = reported_cases$date,
reported_dates = reported_cases$date[-(1:stan_data$seeding_time)]
reported_dates = reported_cases$date[-(1:stan_data$seeding_time)],
imputed_dates =
reported_cases$date[-(1:stan_data$seeding_time)][stan_data$imputed_times]
)

## Add prior infections
Expand All @@ -309,7 +333,7 @@ estimate_infections <- function(data,
# Format output
format_out <- format_fit(
posterior_samples = out,
horizon = horizon,
horizon = stan_data$horizon,
shift = stan_data$seeding_time,
burn_in = 0,
start_date = start_date,
Expand Down
3 changes: 2 additions & 1 deletion R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ estimate_secondary <- function(data,
zero_threshold = zero_threshold
)
## fill in missing data (required if fitting to prevalence)
complete_secondary <- create_complete_cases(secondary_reports)
secondary_reports[, lookup := seq_len(.N)]
complete_secondary <- secondary_reports[!is.na(confirm)]
## fill down
secondary_reports[, confirm := nafill(confirm, type = "locf")]
## fill any early data up
Expand Down
1 change: 0 additions & 1 deletion R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ estimate_truncation <- function(data,
dirty_obs <- purrr::map(data, data.table::as.data.table)
dirty_obs <- purrr::map(dirty_obs,
create_clean_reported_cases,
horizon = 0,
filter_leading_zeros = filter_leading_zeros,
zero_threshold = zero_threshold,
add_breakpoints = FALSE
Expand Down
8 changes: 5 additions & 3 deletions R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ extract_samples <- function(stan_fit, pars = NULL, include = TRUE) {
#'
#' @param reported_dates A vector of dates to report estimates for.
#'
#' @param reported_inf_dates A vector of dates to report infection estimates
#' @param imputed_dates A vector of dates to report imputed reports for.
#'
##' @param reported_inf_dates A vector of dates to report infection estimates
#' for.
#'
#' @param drop_length_1 Logical; whether the first dimension should be dropped
Expand All @@ -155,7 +157,7 @@ extract_samples <- function(stan_fit, pars = NULL, include = TRUE) {
#' @importFrom data.table data.table
#' @keywords internal
extract_parameter_samples <- function(stan_fit, data, reported_dates,
reported_inf_dates,
imputed_dates, reported_inf_dates,
drop_length_1 = FALSE, merge = FALSE) {
# extract sample from stan object
samples <- extract_samples(stan_fit)
Expand Down Expand Up @@ -186,7 +188,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
out$reported_cases <- extract_parameter(
"imputed_reports",
samples,
reported_dates
imputed_dates
)
if ("estimate_r" %in% names(data)) {
if (data$estimate_r == 1) {
Expand Down
Loading

0 comments on commit a5da083

Please sign in to comment.