From 380a979becb370f581541b75e87251de4a058a03 Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Tue, 9 Jan 2024 09:29:44 +0000 Subject: [PATCH] Add support for missing NAs in estimate_infection() model (#528) * add a lookup to estimate_infections * add R side support * don't internally impute missing as zero * update news * fix data preprocessing order * correction data ingestion * clean up filtering of leading zeros * error check create_clean_reported_cases and add unit tests to cover function * correct handling of missing data in data preprocessing: * refine data preprocessing * update news and tests * update global variables * Update NEWS.md Co-authored-by: Sebastian Funk * Update R/create.R Co-authored-by: Sebastian Funk * Update R/create.R Co-authored-by: Sebastian Funk * Update R/create.R Co-authored-by: Sebastian Funk * fix line length linting * Document --------- Co-authored-by: Sebastian Funk Co-authored-by: GitHub Actions --- NEWS.md | 1 + R/create.R | 76 ++++++++++++------- R/estimate_infections.R | 2 +- R/utilities.R | 4 +- inst/stan/data/observations.stan | 4 +- inst/stan/estimate_infections.stan | 5 +- man/create_clean_reported_cases.Rd | 23 ++++-- man/create_stan_data.Rd | 6 ++ man/epinow.Rd | 5 +- man/estimate_infections.Rd | 5 +- .../test-create_clean_reported_cases.R | 34 +++++++++ tests/testthat/test-create_stan_data.R | 0 tests/testthat/test-estimate_infections.R | 8 ++ 13 files changed, 126 insertions(+), 47 deletions(-) create mode 100644 tests/testthat/test-create_clean_reported_cases.R create mode 100644 tests/testthat/test-create_stan_data.R diff --git a/NEWS.md b/NEWS.md index db310583f..d6c62f148 100644 --- a/NEWS.md +++ b/NEWS.md @@ -5,6 +5,7 @@ * The functions `get_dist`, `get_generation_time`, `get_incubation_period` have been deprecated and replaced with examples. By @sbfnk in #481 and reviewed by @seabbs. * The utility function `update_list()` has been deprecated in favour of `utils::modifyList()` because it comes with an installation of R. By @jamesmbaazam in #491 and reviewed by @seabbs. * The `fixed` argument to `dist_spec` has been deprecated and replaced by a `fix_dist()` function. By @sbfnk in #503 and reviewed by @seabbs. +* Updated `estimate_infections()` so that rather than imputing missing data, it now skips these data points in the likelihood. This is a breaking change as it alters the behaviour of the model when dates are missing from a time series but are known to be zero. We recommend that users check their results when updating to this version but expect this to in most cases improve performance. By @seabbs in #528 and reviewed by @sbfnk. ## Documentation diff --git a/R/create.R b/R/create.R index 5b6a9ee83..1b73f26e9 100644 --- a/R/create.R +++ b/R/create.R @@ -1,17 +1,22 @@ #' Create Clean Reported Cases #' @description `r lifecycle::badge("stable")` -#' Cleans a data frame of reported cases by replacing missing dates with 0 -#' cases and applies an optional threshold at which point 0 cases are replaced -#' with a moving average of observed cases. See `zero_threshold` for details. +#' Filters leading zeros, completes dates, and applies an optional threshold at +#' which point 0 cases are replaced with a user supplied value (defaults to +#' `NA`). #' #' @param filter_leading_zeros Logical, defaults to TRUE. Should zeros at the #' start of the time series be filtered out. #' #' @param zero_threshold `r lifecycle::badge("experimental")` Numeric defaults #' to Inf. Indicates if detected zero cases are meaningful by using a threshold -#' number of cases based on the 7 day average. If the average is above this -#' threshold then the zero is replaced with the backwards looking rolling -#' average. If set to infinity then no changes are made. +#' number of cases based on the 7-day average. If the average is above this +#' threshold then the zero is replaced using `fill`. +#' +#' @param fill Numeric, defaults to NA. Value to use to replace NA values or +#' zeroes that are flagged because the 7-day average is above the +#' `zero_threshold`. If the default NA is used then dates with NA values or with +#' 7-day averages above the `zero_threshold` will be skipped in model fitting. +#' If this is set to 0 then the only effect is to replace NA values with 0. #' #' @inheritParams estimate_infections #' @importFrom data.table copy merge.data.table setorder setDT frollsum @@ -19,9 +24,12 @@ #' @author Sam Abbott #' @author Lloyd Chapman #' @export +#' @examples +#' create_clean_reported_cases(example_confirmed, 7) create_clean_reported_cases <- function(reported_cases, horizon, filter_leading_zeros = TRUE, - zero_threshold = Inf) { + zero_threshold = Inf, + fill = NA_integer_) { reported_cases <- data.table::setDT(reported_cases) reported_cases_grid <- data.table::copy(reported_cases)[, .(date = seq(min(date), max(date) + horizon, by = "days")) @@ -35,35 +43,35 @@ create_clean_reported_cases <- function(reported_cases, horizon, if (is.null(reported_cases$breakpoint)) { reported_cases$breakpoint <- 0 } - reported_cases <- reported_cases[ - is.na(confirm), confirm := 0][, .(date = date, confirm, breakpoint) - ] - reported_cases <- reported_cases[is.na(breakpoint), breakpoint := 0] + reported_cases[is.na(breakpoint), breakpoint := 0] reported_cases <- data.table::setorder(reported_cases, date) ## Filter out 0 reported cases from the beginning of the data if (filter_leading_zeros) { reported_cases <- reported_cases[order(date)][ - , - cum_cases := cumsum(confirm) - ][cum_cases > 0][, cum_cases := NULL] + date >= min(date[confirm[!is.na(confirm)] > 0]) + ] } - + # Calculate `average_7_day` which for rows with `confirm == 0` + # (the only instance where this is being used) equates to the 7-day + # right-aligned moving average at the previous data point. + reported_cases <- + reported_cases[ + , + `:=`(average_7_day = ( + data.table::frollsum(confirm, n = 8, na.rm = TRUE) + ) / 7 + ) + ] # Check case counts preceding zero case counts and set to 7 day average if # average over last 7 days is greater than a threshold if (!is.infinite(zero_threshold)) { - reported_cases <- - reported_cases[ - , - `:=`(average_7 = (data.table::frollsum(confirm, n = 8)) / 7) - ] reported_cases <- reported_cases[ - confirm == 0 & average_7 > zero_threshold, - confirm := as.integer(average_7) - ][ - , - "average_7" := NULL + confirm == 0 & average_7_day > zero_threshold, + confirm := NA_integer_ ] } + reported_cases[is.na(confirm), confirm := fill] + reported_cases[, "average_7_day" := NULL] return(reported_cases) } @@ -429,14 +437,26 @@ create_obs_model <- function(obs = obs_opts(), dates) { #' @author Sam Abbott #' @author Sebastian Funk #' @export +#' @examples +#' create_stan_data( +#' example_confirmed, 7, rt_opts(), gp_opts(), obs_opts(), 7, +#' backcalc_opts(), create_shifted_cases(example_confirmed, 7, 14, 7) +#' ) create_stan_data <- function(reported_cases, seeding_time, rt, gp, obs, horizon, backcalc, shifted_cases) { - cases <- reported_cases[(seeding_time + 1):(.N - horizon)]$confirm + cases <- reported_cases[(seeding_time + 1):(.N - horizon)] + cases[, lookup := seq_len(.N)] + complete_cases <- cases[!is.na(cases$confirm)] + cases_time <- complete_cases$lookup + complete_cases <- complete_cases$confirm + cases <- cases$confirm data <- list( - cases = cases, + cases = complete_cases, + cases_time = cases_time, + lt = length(cases_time), shifted_cases = shifted_cases, t = length(reported_cases$date), horizon = horizon, @@ -455,7 +475,7 @@ create_stan_data <- function(reported_cases, seeding_time, first_week <- data.table::data.table( confirm = cases[seq_len(min(7, length(cases)))], t = seq_len(min(7, length(cases))) - ) + )[!is.na(confirm)] data$prior_infections <- log(mean(first_week$confirm, na.rm = TRUE)) data$prior_infections <- ifelse( is.na(data$prior_infections) || is.null(data$prior_infections), diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 69970ce3e..63a945919 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -165,7 +165,7 @@ estimate_infections <- function(reported_cases, name = "EpiNow2.epinow.estimate_infections" ) } - # Make sure there are no missing dates and order cases + # Order cases reported_cases <- create_clean_reported_cases( reported_cases, horizon, filter_leading_zeros = filter_leading_zeros, diff --git a/R/utilities.R b/R/utilities.R index 9c2a590cc..8dfde59a0 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -444,9 +444,9 @@ globalVariables( "New confirmed cases by infection date", "Data", "R", "reference", ".SD", "day_of_week", "forecast_type", "measure", "numeric_estimate", "point", "strat", "estimate", "breakpoint", "variable", "value.V1", - "central_lower", "central_upper", "mean_sd", "sd_sd", "average_7", + "central_lower", "central_upper", "mean_sd", "sd_sd", "average_7_day", "..lowers", "..upper_CrI", "..uppers", "timing", "dataset", "last_confirm", "report_date", "secondary", "id", "conv", "meanlog", "primary", "scaled", - "scaling", "sdlog" + "scaling", "sdlog", "lookup" ) ) diff --git a/inst/stan/data/observations.stan b/inst/stan/data/observations.stan index 654304f85..11fe8463c 100644 --- a/inst/stan/data/observations.stan +++ b/inst/stan/data/observations.stan @@ -1,6 +1,8 @@ int t; // unobserved time + int lt; // timepoints in the likelihood int seeding_time; // time period used for seeding and not observed int horizon; // forecast horizon int future_time; // time in future for Rt - array[t - horizon - seeding_time] int cases; // observed cases + array[lt] int cases; // observed cases + array[lt] int cases_time; // time of observed cases vector[t] shifted_cases; // prior infections (for backcalculation) diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index a97bb57e5..c1ac8c63e 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -148,7 +148,8 @@ model { // observed reports from mean of reports (update likelihood) if (likelihood) { report_lp( - cases, obs_reports, rep_phi, phi_mean, phi_sd, model_type, obs_weight + cases, obs_reports[cases_time], rep_phi, phi_mean, phi_sd, model_type, + obs_weight ); } } @@ -191,7 +192,7 @@ generated quantities { // log likelihood of model if (return_likelihood) { log_lik = report_log_lik( - cases, obs_reports, rep_phi, model_type, obs_weight + cases, obs_reports[cases_time], rep_phi, model_type, obs_weight ); } } diff --git a/man/create_clean_reported_cases.Rd b/man/create_clean_reported_cases.Rd index fddf60d2c..c53830c0c 100644 --- a/man/create_clean_reported_cases.Rd +++ b/man/create_clean_reported_cases.Rd @@ -8,7 +8,8 @@ create_clean_reported_cases( reported_cases, horizon, filter_leading_zeros = TRUE, - zero_threshold = Inf + zero_threshold = Inf, + fill = NA_integer_ ) } \arguments{ @@ -23,18 +24,26 @@ start of the time series be filtered out.} \item{zero_threshold}{\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} Numeric defaults to Inf. Indicates if detected zero cases are meaningful by using a threshold -number of cases based on the 7 day average. If the average is above this -threshold then the zero is replaced with the backwards looking rolling -average. If set to infinity then no changes are made.} +number of cases based on the 7-day average. If the average is above this +threshold then the zero is replaced using \code{fill}.} + +\item{fill}{Numeric, defaults to NA. Value to use to replace NA values or +zeroes that are flagged because the 7-day average is above the +\code{zero_threshold}. If the default NA is used then dates with NA values or with +7-day averages above the \code{zero_threshold} will be skipped in model fitting. +If this is set to 0 then the only effect is to replace NA values with 0.} } \value{ A cleaned data frame of reported cases } \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} -Cleans a data frame of reported cases by replacing missing dates with 0 -cases and applies an optional threshold at which point 0 cases are replaced -with a moving average of observed cases. See \code{zero_threshold} for details. +Filters leading zeros, completes dates, and applies an optional threshold at +which point 0 cases are replaced with a user supplied value (defaults to +\code{NA}). +} +\examples{ +create_clean_reported_cases(example_confirmed, 7) } \author{ Sam Abbott diff --git a/man/create_stan_data.Rd b/man/create_stan_data.Rd index b47ca066e..9ec1a8e3d 100644 --- a/man/create_stan_data.Rd +++ b/man/create_stan_data.Rd @@ -50,6 +50,12 @@ stan. Internally calls the other \code{create_} family of functions to construct a single list for input into stan with all data required present. } +\examples{ +create_stan_data( + example_confirmed, 7, rt_opts(), gp_opts(), obs_opts(), 7, + backcalc_opts(), create_shifted_cases(example_confirmed, 7, 14, 7) +) +} \author{ Sam Abbott diff --git a/man/epinow.Rd b/man/epinow.Rd index 36225af0b..a4a302f8a 100644 --- a/man/epinow.Rd +++ b/man/epinow.Rd @@ -72,9 +72,8 @@ start of the time series be filtered out.} \item{zero_threshold}{\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} Numeric defaults to Inf. Indicates if detected zero cases are meaningful by using a threshold -number of cases based on the 7 day average. If the average is above this -threshold then the zero is replaced with the backwards looking rolling -average. If set to infinity then no changes are made.} +number of cases based on the 7-day average. If the average is above this +threshold then the zero is replaced using \code{fill}.} \item{return_output}{Logical, defaults to FALSE. Should output be returned, this automatically updates to TRUE if no directory for saving is specified.} diff --git a/man/estimate_infections.Rd b/man/estimate_infections.Rd index 087043752..851e3a295 100644 --- a/man/estimate_infections.Rd +++ b/man/estimate_infections.Rd @@ -68,9 +68,8 @@ start of the time series be filtered out.} \item{zero_threshold}{\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} Numeric defaults to Inf. Indicates if detected zero cases are meaningful by using a threshold -number of cases based on the 7 day average. If the average is above this -threshold then the zero is replaced with the backwards looking rolling -average. If set to infinity then no changes are made.} +number of cases based on the 7-day average. If the average is above this +threshold then the zero is replaced using \code{fill}.} \item{weigh_delay_priors}{Logical. If TRUE (default), all delay distribution priors will be weighted by the number of observation data points, in doing so diff --git a/tests/testthat/test-create_clean_reported_cases.R b/tests/testthat/test-create_clean_reported_cases.R new file mode 100644 index 000000000..44bb0ef78 --- /dev/null +++ b/tests/testthat/test-create_clean_reported_cases.R @@ -0,0 +1,34 @@ + +test_that("create_clean_reported_cases runs without errors", { + expect_no_error(create_clean_reported_cases(example_confirmed, 7)) +}) + +test_that("create_clean_reported_cases returns a data table", { + result <- create_clean_reported_cases(example_confirmed, 7) + expect_s3_class(result, "data.table") +}) + +test_that("create_clean_reported_cases filters leading zeros correctly", { + # Modify example_confirmed to have leading zeros + modified_data <- example_confirmed + modified_data[1:3, "confirm"] <- 0 + + result <- create_clean_reported_cases(modified_data, 7) + # Check if the first row with non-zero cases is retained + expect_equal( + result$date[1], min(modified_data$date[modified_data$confirm > 0]) + ) +}) + +test_that("create_clean_reported_cases replaces zero cases correctly", { + # Modify example_confirmed to have zero cases that should be replaced + modified_data <- example_confirmed + modified_data$confirm[10:16] <- 0 + threshold <- 10 + + result <- create_clean_reported_cases( + modified_data, 0, zero_threshold = threshold + ) + # Check if zero cases within the threshold are replaced + expect_equal(sum(result$confirm == 0, na.rm = TRUE), 0) +}) diff --git a/tests/testthat/test-create_stan_data.R b/tests/testthat/test-create_stan_data.R new file mode 100644 index 000000000..e69de29bb diff --git a/tests/testthat/test-estimate_infections.R b/tests/testthat/test-estimate_infections.R index 69303de5d..ecb35a2d6 100644 --- a/tests/testthat/test-estimate_infections.R +++ b/tests/testthat/test-estimate_infections.R @@ -36,6 +36,14 @@ test_that("estimate_infections successfully returns estimates using default sett test_estimate_infections(reported_cases) }) +test_that("estimate_infections successfully returns estimates when passed NA values", { + skip_on_cran() + reported_cases_na <- data.table::copy(reported_cases) + reported_cases_na[sample(1:30, 5), confirm := NA] + test_estimate_infections(reported_cases_na) +}) + + test_that("estimate_infections successfully returns estimates using no delays", { skip_on_cran() test_estimate_infections(reported_cases, delay = FALSE)