diff --git a/NEWS.md b/NEWS.md index 660c682e7..2153948bd 100644 --- a/NEWS.md +++ b/NEWS.md @@ -47,6 +47,7 @@ * Added an `na` argument to `obs_opts()` that allows the user to specify whether NA values in the data should be interpreted as missing or accumulated in the next non-NA data point. By @sbfnk in #534 and reviewed by @seabbs. * Growth rates are now calculated directly from the infection trajectory as `log I(t) - log I(t - 1)`. Originally by @seabbs in #213, finished by @sbfnk in #610 and reviewed by @seabbs. * Fixed a bug when using the nonmechanistic model that could lead to explosive growth. By @sbfnk in #612 and reviewed by @jamesmbaazam. +* Added the arguments `filter_leading_zeros` and `zero_threshold` to `estimate_secondary()` and `estimate_truncation()` to allow the user to specify whether to filter leading zeros in the data and the threshold for replacing zero cases. These arguments were already used in `estimate_infections()`, `epinow()`, and `regional_epinow()`. See `?estimate_secondary` and `?estimate_truncation` for more details. By @jamesmbaazam in #608 and reviewed by @sbfnk. # EpiNow2 1.4.0 diff --git a/R/create.R b/R/create.R index cdc0b855b..d0fb3ed2f 100644 --- a/R/create.R +++ b/R/create.R @@ -17,6 +17,8 @@ #' `zero_threshold`. If the default NA is used then dates with NA values or with #' 7-day averages above the `zero_threshold` will be skipped in model fitting. #' If this is set to 0 then the only effect is to replace NA values with 0. +#' @param add_breakpoints Logical, defaults to TRUE. Should a breakpoint column +#' be added to the data frame if it does not exist. 
#' #' @inheritParams estimate_infections #' @importFrom data.table copy merge.data.table setorder setDT frollsum @@ -27,7 +29,8 @@ create_clean_reported_cases <- function(reported_cases, horizon = 0, filter_leading_zeros = TRUE, zero_threshold = Inf, - fill = NA_integer_) { + fill = NA_integer_, + add_breakpoints = TRUE) { reported_cases <- data.table::setDT(reported_cases) reported_cases_grid <- data.table::copy(reported_cases)[, .(date = seq(min(date), max(date) + horizon, by = "days")) @@ -38,10 +41,12 @@ create_clean_reported_cases <- function(reported_cases, horizon = 0, by = "date", all.y = TRUE ) - if (is.null(reported_cases$breakpoint)) { + if (is.null(reported_cases$breakpoint) && add_breakpoints) { reported_cases$breakpoint <- 0 } - reported_cases[is.na(breakpoint), breakpoint := 0] + if (!is.null(reported_cases$breakpoint)) { + reported_cases[is.na(breakpoint), breakpoint := 0] + } reported_cases <- data.table::setorder(reported_cases, date) ## Filter out 0 reported cases from the beginning of the data if (filter_leading_zeros) { diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 30068b73a..6213ca6a0 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -147,6 +147,8 @@ estimate_secondary <- function(reports, stan = stan_opts(), burn_in = 14, CrIs = c(0.2, 0.5, 0.9), + filter_leading_zeros = FALSE, + zero_threshold = Inf, priors = NULL, model = NULL, weigh_delay_priors = FALSE, @@ -160,6 +162,8 @@ estimate_secondary <- function(reports, assert_class(obs, "obs_opts") assert_numeric(burn_in, lower = 0) assert_numeric(CrIs, lower = 0, upper = 1) + assert_logical(filter_leading_zeros) + assert_numeric(zero_threshold, lower = 0) assert_data_frame(priors, null.ok = TRUE) assert_class(model, "stanfit", null.ok = TRUE) assert_logical(weigh_delay_priors) @@ -168,7 +172,9 @@ estimate_secondary <- function(reports, reports <- data.table::as.data.table(reports) secondary_reports <- reports[, list(date, confirm = secondary)] 
secondary_reports <- create_clean_reported_cases( - secondary_reports, filter_leading_zeros = FALSE + secondary_reports, + filter_leading_zeros = filter_leading_zeros, + zero_threshold = zero_threshold ) ## fill in missing data (required if fitting to prevalence) complete_secondary <- create_complete_cases(secondary_reports) @@ -178,6 +184,11 @@ estimate_secondary <- function(reports, ## fill any early data up secondary_reports[, confirm := nafill(confirm, type = "nocb")] + # Ensure that reports and secondary_reports are aligned + reports <- merge.data.table( + reports, secondary_reports[, list(date)], by = "date" + ) + if (burn_in >= nrow(reports)) { stop("burn_in is greater or equal to the number of observations. Some observations must be used in fitting") diff --git a/R/estimate_truncation.R b/R/estimate_truncation.R index f42422620..92c0f4aef 100644 --- a/R/estimate_truncation.R +++ b/R/estimate_truncation.R @@ -110,6 +110,8 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10, model = NULL, stan = stan_opts(), CrIs = c(0.2, 0.5, 0.9), + filter_leading_zeros = FALSE, + zero_threshold = Inf, weigh_delay_priors = FALSE, verbose = TRUE, ...) 
{ @@ -126,6 +128,8 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10, assert_class(truncation, "dist_spec") assert_class(model, "stanfit", null.ok = TRUE) assert_numeric(CrIs, lower = 0, upper = 1) + assert_logical(filter_leading_zeros) + assert_numeric(zero_threshold, lower = 0) assert_logical(weigh_delay_priors) assert_logical(verbose) @@ -193,6 +197,19 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10, # combine into ordered matrix dirty_obs <- purrr::map(obs, data.table::as.data.table) + dirty_obs <- purrr::map(dirty_obs, + create_clean_reported_cases, + horizon = 0, + filter_leading_zeros = filter_leading_zeros, + zero_threshold = zero_threshold, + add_breakpoints = FALSE + ) + earliest_date <- max( + as.Date( + purrr::map_chr(dirty_obs, function(x) x[, as.character(min(date))]) + ) + ) + dirty_obs <- purrr::map(dirty_obs, function(x) x[date >= earliest_date]) nrow_obs <- order(purrr::map_dbl(dirty_obs, nrow)) dirty_obs <- dirty_obs[nrow_obs] obs <- purrr::map(dirty_obs, data.table::copy) diff --git a/man/create_clean_reported_cases.Rd b/man/create_clean_reported_cases.Rd index 811150d9f..8e33f6798 100644 --- a/man/create_clean_reported_cases.Rd +++ b/man/create_clean_reported_cases.Rd @@ -9,7 +9,8 @@ create_clean_reported_cases( horizon = 0, filter_leading_zeros = TRUE, zero_threshold = Inf, - fill = NA_integer_ + fill = NA_integer_, + add_breakpoints = TRUE ) } \arguments{ @@ -32,6 +33,9 @@ zeroes that are flagged because the 7-day average is above the \code{zero_threshold}. If the default NA is used then dates with NA values or with 7-day averages above the \code{zero_threshold} will be skipped in model fitting. If this is set to 0 then the only effect is to replace NA values with 0.} + +\item{add_breakpoints}{Logical, defaults to TRUE. 
Should a breakpoint column +be added to the data frame if it does not exist.} } \value{ A cleaned data frame of reported cases diff --git a/man/estimate_secondary.Rd b/man/estimate_secondary.Rd index ac10191eb..e8ecfd100 100644 --- a/man/estimate_secondary.Rd +++ b/man/estimate_secondary.Rd @@ -14,6 +14,8 @@ estimate_secondary( stan = stan_opts(), burn_in = 14, CrIs = c(0.2, 0.5, 0.9), + filter_leading_zeros = FALSE, + zero_threshold = Inf, priors = NULL, model = NULL, weigh_delay_priors = FALSE, @@ -53,6 +55,14 @@ This must be less than the number of observations.} \item{CrIs}{Numeric vector of credible intervals to calculate.} +\item{filter_leading_zeros}{Logical, defaults to FALSE. Should zeros at the +start of the time series be filtered out.} + +\item{zero_threshold}{\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} Numeric defaults +to Inf. Indicates if detected zero cases are meaningful by using a threshold +number of cases based on the 7-day average. If the average is above this +threshold then the zero is replaced using \code{fill}.} + \item{priors}{A \verb{<data.frame>} of named priors to be used in model fitting rather than the defaults supplied from other arguments. This is typically useful if wanting to inform an estimate from the posterior of another model diff --git a/man/estimate_truncation.Rd b/man/estimate_truncation.Rd index a46ca2fc5..884d280f7 100644 --- a/man/estimate_truncation.Rd +++ b/man/estimate_truncation.Rd @@ -14,6 +14,8 @@ estimate_truncation( model = NULL, stan = stan_opts(), CrIs = c(0.2, 0.5, 0.9), + filter_leading_zeros = FALSE, + zero_threshold = Inf, weigh_delay_priors = FALSE, verbose = TRUE, ... @@ -44,6 +46,14 @@ settings if desired.} \item{CrIs}{Numeric vector of credible intervals to calculate.} +\item{filter_leading_zeros}{Logical, defaults to FALSE. 
Should zeros at the +start of the time series be filtered out.} + +\item{zero_threshold}{\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} Numeric defaults +to Inf. Indicates if detected zero cases are meaningful by using a threshold +number of cases based on the 7-day average. If the average is above this +threshold then the zero is replaced using \code{fill}.} + \item{weigh_delay_priors}{Logical. If TRUE, all delay distribution priors will be weighted by the number of observation data points, in doing so approximately placing an independent prior at each time step and usually diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index 12fc95a20..d4c64e948 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -202,3 +202,30 @@ test_that("estimate_secondary works with weigh_delay_priors = TRUE", { ) expect_s3_class(inc_weigh, "estimate_secondary") }) + +test_that("estimate_secondary works with filter_leading_zeros set", { + modified_data <- inc_cases[1:10, secondary := 0] + out <- estimate_secondary( + modified_data, + obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), + week_effect = FALSE), + filter_leading_zeros = TRUE, + verbose = FALSE + ) + expect_s3_class(out, "estimate_secondary") + expect_named(out, c("predictions", "posterior", "data", "fit")) + expect_equal(out$predictions$primary, modified_data$primary[-(1:10)]) +}) + +test_that("estimate_secondary works with zero_threshold set", { + modified_data <- inc_cases[sample(1:30, 10), primary := 0] + out <- estimate_secondary( + modified_data, + obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), + week_effect = FALSE), + zero_threshold = 10, + verbose = FALSE + ) + expect_s3_class(out, "estimate_secondary") + expect_named(out, c("predictions", "posterior", "data", "fit")) +}) diff --git 
a/tests/testthat/test-estimate_truncation.R b/tests/testthat/test-estimate_truncation.R index 9cd2cc35f..b5bbbbd70 100644 --- a/tests/testthat/test-estimate_truncation.R +++ b/tests/testthat/test-estimate_truncation.R @@ -37,6 +37,54 @@ test_that("estimate_truncation can return values from simulated data with the expect_error(plot(est), NA) }) +test_that("estimate_truncation works with filter_leading_zeros set", { + skip_on_os("windows") + # Modify the first three rows of the first dataset to have zero cases + # and fit the model with filter_leading_zeros = TRUE. This should + # be the same as fitting the model to the original dataset because the + # earlier dataset is corrected to be the same as the final dataset. + modified_data <- data.table::copy(example_truncated) + modified_data[[1]][1:3, confirm := 0] + modified_data_fit <- estimate_truncation( + modified_data, + verbose = FALSE, chains = 2, iter = 1000, warmup = 250, + filter_leading_zeros = TRUE + ) + # fit model to original dataset + original_data_fit <- estimate_truncation( + example_truncated, + verbose = FALSE, chains = 2, iter = 1000, warmup = 250 + ) + expect_named( + modified_data_fit, + c("dist", "obs", "last_obs", "cmf", "data", "fit") + ) + # Compare the results of the two fits + expect_equal( + original_data_fit$dist$dist, + modified_data_fit$dist$dist + ) + expect_equal( + original_data_fit$data$obs_dist, + modified_data_fit$data$obs_dist + ) +}) + +test_that("estimate_truncation works with zero_threshold set", { + skip_on_os("windows") + # fit model to a modified version of example_truncated in which six randomly + # chosen rows among the first ten of each dataset are set to zero, using zero_threshold = 1 + modified_data <- example_truncated + modified_data <- purrr::map(modified_data, function(x) x[sample(1:10, 6), confirm := 0]) + out <- estimate_truncation(modified_data, + verbose = FALSE, chains = 2, iter = 1000, warmup = 250, + stan = stan_opts(backend = "cmdstanr"), + zero_threshold = 1 + ) + expect_named(out, c("dist", "obs", "last_obs", 
"cmf", "data", "fit")) + expect_s3_class(out$dist, "dist_spec") +}) + test_that("deprecated arguments are recognised", { options(warn = 2) expect_error(estimate_truncation(example_truncated,