Skip to content

Commit

Permalink
Add support for missing data (NAs) in the estimate_infections() model (#528)
Browse files Browse the repository at this point in the history
* add a lookup to estimate_infections

* add R side support

* don't internally impute missing as zero

* update news

* fix data preprocessing order

* correct data ingestion

* clean up filtering of leading zeros

* error check create_clean_reported_cases and add unit tests to cover function

* correct handling of missing data in data preprocessing

* refine data preprocessing

* update news and tests

* update global variables

* Update NEWS.md

Co-authored-by: Sebastian Funk <[email protected]>

* Update R/create.R

Co-authored-by: Sebastian Funk <[email protected]>

* Update R/create.R

Co-authored-by: Sebastian Funk <[email protected]>

* Update R/create.R

Co-authored-by: Sebastian Funk <[email protected]>

* fix line length linting

* Document

---------

Co-authored-by: Sebastian Funk <[email protected]>
Co-authored-by: GitHub Actions <[email protected]>
  • Loading branch information
3 people authored Jan 9, 2024
1 parent cb5289d commit 380a979
Show file tree
Hide file tree
Showing 13 changed files with 126 additions and 47 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* The functions `get_dist`, `get_generation_time`, `get_incubation_period` have been deprecated and replaced with examples. By @sbfnk in #481 and reviewed by @seabbs.
* The utility function `update_list()` has been deprecated in favour of `utils::modifyList()` because it comes with an installation of R. By @jamesmbaazam in #491 and reviewed by @seabbs.
* The `fixed` argument to `dist_spec` has been deprecated and replaced by a `fix_dist()` function. By @sbfnk in #503 and reviewed by @seabbs.
* Updated `estimate_infections()` so that rather than imputing missing data, it now skips these data points in the likelihood. This is a breaking change as it alters the behaviour of the model when dates are missing from a time series but are known to be zero. We recommend that users check their results when updating to this version, but we expect this change to improve performance in most cases. By @seabbs in #528 and reviewed by @sbfnk.

## Documentation

Expand Down
76 changes: 48 additions & 28 deletions R/create.R
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
#' Create Clean Reported Cases
#' @description `r lifecycle::badge("stable")`
#' Cleans a data frame of reported cases by replacing missing dates with 0
#' cases and applies an optional threshold at which point 0 cases are replaced
#' with a moving average of observed cases. See `zero_threshold` for details.
#' Filters leading zeros, completes dates, and applies an optional threshold at
#' which point 0 cases are replaced with a user supplied value (defaults to
#' `NA`).
#'
#' @param filter_leading_zeros Logical, defaults to TRUE. Should zeros at the
#' start of the time series be filtered out.
#'
#' @param zero_threshold `r lifecycle::badge("experimental")` Numeric defaults
#' to Inf. Indicates if detected zero cases are meaningful by using a threshold
#' number of cases based on the 7 day average. If the average is above this
#' threshold then the zero is replaced with the backwards looking rolling
#' average. If set to infinity then no changes are made.
#' number of cases based on the 7-day average. If the average is above this
#' threshold then the zero is replaced using `fill`.
#'
#' @param fill Numeric, defaults to NA. Value to use to replace NA values or
#' zeroes that are flagged because the 7-day average is above the
#' `zero_threshold`. If the default NA is used then dates with NA values or with
#' 7-day averages above the `zero_threshold` will be skipped in model fitting.
#' If this is set to 0 then the only effect is to replace NA values with 0.
#'
#' @inheritParams estimate_infections
#' @importFrom data.table copy merge.data.table setorder setDT frollsum
#' @return A cleaned data frame of reported cases
#' @author Sam Abbott
#' @author Lloyd Chapman
#' @export
#' @examples
#' create_clean_reported_cases(example_confirmed, 7)
create_clean_reported_cases <- function(reported_cases, horizon,
filter_leading_zeros = TRUE,
zero_threshold = Inf) {
zero_threshold = Inf,
fill = NA_integer_) {
reported_cases <- data.table::setDT(reported_cases)
reported_cases_grid <- data.table::copy(reported_cases)[,
.(date = seq(min(date), max(date) + horizon, by = "days"))
Expand All @@ -35,35 +43,35 @@ create_clean_reported_cases <- function(reported_cases, horizon,
if (is.null(reported_cases$breakpoint)) {
reported_cases$breakpoint <- 0
}
reported_cases <- reported_cases[
is.na(confirm), confirm := 0][, .(date = date, confirm, breakpoint)
]
reported_cases <- reported_cases[is.na(breakpoint), breakpoint := 0]
reported_cases[is.na(breakpoint), breakpoint := 0]
reported_cases <- data.table::setorder(reported_cases, date)
## Filter out 0 reported cases from the beginning of the data
if (filter_leading_zeros) {
reported_cases <- reported_cases[order(date)][
,
cum_cases := cumsum(confirm)
][cum_cases > 0][, cum_cases := NULL]
date >= min(date[confirm[!is.na(confirm)] > 0])
]
}

# Calculate `average_7_day` which for rows with `confirm == 0`
# (the only instance where this is being used) equates to the 7-day
# right-aligned moving average at the previous data point.
reported_cases <-
reported_cases[
,
`:=`(average_7_day = (
data.table::frollsum(confirm, n = 8, na.rm = TRUE)
) / 7
)
]
# Check case counts preceding zero case counts and flag the zero as missing
# (`NA`) if the average over the last 7 days is greater than the threshold
if (!is.infinite(zero_threshold)) {
reported_cases <-
reported_cases[
,
`:=`(average_7 = (data.table::frollsum(confirm, n = 8)) / 7)
]
reported_cases <- reported_cases[
confirm == 0 & average_7 > zero_threshold,
confirm := as.integer(average_7)
][
,
"average_7" := NULL
confirm == 0 & average_7_day > zero_threshold,
confirm := NA_integer_
]
}
reported_cases[is.na(confirm), confirm := fill]
reported_cases[, "average_7_day" := NULL]
return(reported_cases)
}

Expand Down Expand Up @@ -429,14 +437,26 @@ create_obs_model <- function(obs = obs_opts(), dates) {
#' @author Sam Abbott
#' @author Sebastian Funk
#' @export
#' @examples
#' create_stan_data(
#' example_confirmed, 7, rt_opts(), gp_opts(), obs_opts(), 7,
#' backcalc_opts(), create_shifted_cases(example_confirmed, 7, 14, 7)
#' )
create_stan_data <- function(reported_cases, seeding_time,
rt, gp, obs, horizon,
backcalc, shifted_cases) {

cases <- reported_cases[(seeding_time + 1):(.N - horizon)]$confirm
cases <- reported_cases[(seeding_time + 1):(.N - horizon)]
cases[, lookup := seq_len(.N)]
complete_cases <- cases[!is.na(cases$confirm)]
cases_time <- complete_cases$lookup
complete_cases <- complete_cases$confirm
cases <- cases$confirm

data <- list(
cases = cases,
cases = complete_cases,
cases_time = cases_time,
lt = length(cases_time),
shifted_cases = shifted_cases,
t = length(reported_cases$date),
horizon = horizon,
Expand All @@ -455,7 +475,7 @@ create_stan_data <- function(reported_cases, seeding_time,
first_week <- data.table::data.table(
confirm = cases[seq_len(min(7, length(cases)))],
t = seq_len(min(7, length(cases)))
)
)[!is.na(confirm)]
data$prior_infections <- log(mean(first_week$confirm, na.rm = TRUE))
data$prior_infections <- ifelse(
is.na(data$prior_infections) || is.null(data$prior_infections),
Expand Down
2 changes: 1 addition & 1 deletion R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ estimate_infections <- function(reported_cases,
name = "EpiNow2.epinow.estimate_infections"
)
}
# Make sure there are no missing dates and order cases
# Order cases
reported_cases <- create_clean_reported_cases(
reported_cases, horizon,
filter_leading_zeros = filter_leading_zeros,
Expand Down
4 changes: 2 additions & 2 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,9 @@ globalVariables(
"New confirmed cases by infection date", "Data", "R", "reference",
".SD", "day_of_week", "forecast_type", "measure", "numeric_estimate",
"point", "strat", "estimate", "breakpoint", "variable", "value.V1",
"central_lower", "central_upper", "mean_sd", "sd_sd", "average_7",
"central_lower", "central_upper", "mean_sd", "sd_sd", "average_7_day",
"..lowers", "..upper_CrI", "..uppers", "timing", "dataset", "last_confirm",
"report_date", "secondary", "id", "conv", "meanlog", "primary", "scaled",
"scaling", "sdlog"
"scaling", "sdlog", "lookup"
)
)
4 changes: 3 additions & 1 deletion inst/stan/data/observations.stan
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
int t; // unobserved time
int lt; // timepoints in the likelihood
int seeding_time; // time period used for seeding and not observed
int horizon; // forecast horizon
int future_time; // time in future for Rt
array[t - horizon - seeding_time] int<lower = 0> cases; // observed cases
array[lt] int<lower = 0> cases; // observed cases
array[lt] int cases_time; // time of observed cases
vector<lower = 0>[t] shifted_cases; // prior infections (for backcalculation)
5 changes: 3 additions & 2 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ model {
// observed reports from mean of reports (update likelihood)
if (likelihood) {
report_lp(
cases, obs_reports, rep_phi, phi_mean, phi_sd, model_type, obs_weight
cases, obs_reports[cases_time], rep_phi, phi_mean, phi_sd, model_type,
obs_weight
);
}
}
Expand Down Expand Up @@ -191,7 +192,7 @@ generated quantities {
// log likelihood of model
if (return_likelihood) {
log_lik = report_log_lik(
cases, obs_reports, rep_phi, model_type, obs_weight
cases, obs_reports[cases_time], rep_phi, model_type, obs_weight
);
}
}
23 changes: 16 additions & 7 deletions man/create_clean_reported_cases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/create_stan_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions man/epinow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions man/estimate_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions tests/testthat/test-create_clean_reported_cases.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

# Smoke test: cleaning the bundled example data with a 7-day horizon should
# not signal an error.
test_that("create_clean_reported_cases runs without errors", {
  expect_no_error(
    create_clean_reported_cases(example_confirmed, horizon = 7)
  )
})

# The cleaned output should carry the "data.table" class.
test_that("create_clean_reported_cases returns a data table", {
  cleaned <- create_clean_reported_cases(example_confirmed, horizon = 7)
  expect_s3_class(cleaned, "data.table")
})

# With the default `filter_leading_zeros`, the cleaned series is expected to
# begin on the first date that has a non-zero case count.
test_that("create_clean_reported_cases filters leading zeros correctly", {
  # Zero out the first three observations of the example data
  zero_start_cases <- example_confirmed
  zero_start_cases[1:3, "confirm"] <- 0

  cleaned <- create_clean_reported_cases(zero_start_cases, 7)

  # The earliest retained date should be the first date with positive counts
  positive_dates <- zero_start_cases$date[zero_start_cases$confirm > 0]
  expect_equal(cleaned$date[1], min(positive_dates))
})

# With a finite `zero_threshold`, the inserted run of zeros is expected to no
# longer appear as zeros in the cleaned output.
test_that("create_clean_reported_cases replaces zero cases correctly", {
  # Insert a week of zero counts into the example data
  zeroed_week_cases <- example_confirmed
  zeroed_week_cases$confirm[10:16] <- 0

  cleaned <- create_clean_reported_cases(
    zeroed_week_cases,
    horizon = 0,
    zero_threshold = 10
  )

  # No zero counts should remain (missing values are ignored in the count)
  remaining_zeros <- sum(cleaned$confirm == 0, na.rm = TRUE)
  expect_equal(remaining_zeros, 0)
})
Empty file.
8 changes: 8 additions & 0 deletions tests/testthat/test-estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ test_that("estimate_infections successfully returns estimates using default sett
test_estimate_infections(reported_cases)
})

# Case counts containing missing observations (`NA`) are passed through the
# `test_estimate_infections()` helper, which is defined outside this block.
test_that("estimate_infections successfully returns estimates when passed NA values", {
  skip_on_cran()
  # Copy first so that `:=` below does not modify the `reported_cases` object
  # defined outside this test by reference.
  reported_cases_na <- data.table::copy(reported_cases)
  # Use fixed rows rather than an unseeded `sample()` so that the missing
  # dates are the same on every run and any failure can be reproduced.
  na_rows <- c(4L, 11L, 17L, 23L, 29L)
  reported_cases_na[na_rows, confirm := NA]
  test_estimate_infections(reported_cases_na)
})


test_that("estimate_infections successfully returns estimates using no delays", {
skip_on_cran()
test_estimate_infections(reported_cases, delay = FALSE)
Expand Down

0 comments on commit 380a979

Please sign in to comment.