Skip to content

Commit

Permalink
Add filter_leading_zeros and zero_threshold to `estimate_secondar…
Browse files Browse the repository at this point in the history
…y()` and `estimate_truncation()` (#608)

* Add filter_leading_zeros and zero_threshold to estimate_secondary()

* Add filter_leading_zeros and zero_threshold to estimate_truncation()

* Call create_clean_reported_cases() in estimate_truncation()

* Remove breakpoint column created by create_clean_reported_cases()

* Add tests for the new arguments

* Fix wrong object name

* Remove unnecessary anonymous function to map

* Introduce an add_breakpoints argument to control adding a breakpoint column

* Set add_breakpoint to FALSE to not create a breakpoint column

* Document add_breakpoint argument

* Merge reports with secondary_reports after filtering leading zeroes

* Fix nested ifelse statement

* Remove trailing whitespace

* Update tests/testthat/test-estimate_secondary.R

Co-authored-by: Sebastian Funk <[email protected]>

* Replace anonymous function shorthand

* Add PR and reviewer

Co-authored-by: Sebastian Funk <[email protected]>

* Only replace NA with 0 if breakpoint column is present

* Only merge by date on secondary_reports

Co-authored-by: Sebastian Funk <[email protected]>

* Fix test

Co-authored-by: Sebastian Funk <[email protected]>

* Make sure all the datasets have the same start date

* Improve test to compare aspects of fits to original dataset

* move filtering up

* bring back copying (don't modify input object)

* use rstan in tests

---------

Co-authored-by: Sebastian Funk <[email protected]>
  • Loading branch information
jamesmbaazam and sbfnk authored Mar 27, 2024
1 parent 7f25782 commit 0743eeb
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 5 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
* Added an `na` argument to `obs_opts()` that allows the user to specify whether NA values in the data should be interpreted as missing or accumulated in the next non-NA data point. By @sbfnk in #534 and reviewed by @seabbs.
* Growth rates are now calculated directly from the infection trajectory as `log I(t) - log I(t - 1)`. Originally by @seabbs in #213, finished by @sbfnk in #610 and reviewed by @seabbs.
* Fixed a bug when using the nonmechanistic model that could lead to explosive growth. By @sbfnk in #612 and reviewed by @jamesmbaazam.
* Added the arguments `filter_leading_zeros` and `zero_threshold` to `estimate_secondary()` and `estimate_truncation()` to allow the user to specify whether to filter leading zeros in the data and the threshold for replacing zero cases. These arguments were already used in `estimate_infections()`, `epinow()`, and `regional_epinow()`. See `?estimate_secondary` and `?estimate_truncation` for more details. By @jamesmbaazam in #608 and reviewed by @sbfnk.

# EpiNow2 1.4.0

Expand Down
11 changes: 8 additions & 3 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#' `zero_threshold`. If the default NA is used then dates with NA values or with
#' 7-day averages above the `zero_threshold` will be skipped in model fitting.
#' If this is set to 0 then the only effect is to replace NA values with 0.
#' @param add_breakpoints Logical, defaults to TRUE. Should a breakpoint column
#' be added to the data frame if it does not exist.
#'
#' @inheritParams estimate_infections
#' @importFrom data.table copy merge.data.table setorder setDT frollsum
Expand All @@ -27,7 +29,8 @@
create_clean_reported_cases <- function(reported_cases, horizon = 0,
filter_leading_zeros = TRUE,
zero_threshold = Inf,
fill = NA_integer_) {
fill = NA_integer_,
add_breakpoints = TRUE) {
reported_cases <- data.table::setDT(reported_cases)
reported_cases_grid <- data.table::copy(reported_cases)[,
.(date = seq(min(date), max(date) + horizon, by = "days"))
Expand All @@ -38,10 +41,12 @@ create_clean_reported_cases <- function(reported_cases, horizon = 0,
by = "date", all.y = TRUE
)

if (is.null(reported_cases$breakpoint)) {
if (is.null(reported_cases$breakpoint) && add_breakpoints) {
reported_cases$breakpoint <- 0
}
reported_cases[is.na(breakpoint), breakpoint := 0]
if (!is.null(reported_cases$breakpoint)) {
reported_cases[is.na(breakpoint), breakpoint := 0]
}
reported_cases <- data.table::setorder(reported_cases, date)
## Filter out 0 reported cases from the beginning of the data
if (filter_leading_zeros) {
Expand Down
13 changes: 12 additions & 1 deletion R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ estimate_secondary <- function(reports,
stan = stan_opts(),
burn_in = 14,
CrIs = c(0.2, 0.5, 0.9),
filter_leading_zeros = FALSE,
zero_threshold = Inf,
priors = NULL,
model = NULL,
weigh_delay_priors = FALSE,
Expand All @@ -160,6 +162,8 @@ estimate_secondary <- function(reports,
assert_class(obs, "obs_opts")
assert_numeric(burn_in, lower = 0)
assert_numeric(CrIs, lower = 0, upper = 1)
assert_logical(filter_leading_zeros)
assert_numeric(zero_threshold, lower = 0)
assert_data_frame(priors, null.ok = TRUE)
assert_class(model, "stanfit", null.ok = TRUE)
assert_logical(weigh_delay_priors)
Expand All @@ -168,7 +172,9 @@ estimate_secondary <- function(reports,
reports <- data.table::as.data.table(reports)
secondary_reports <- reports[, list(date, confirm = secondary)]
secondary_reports <- create_clean_reported_cases(
secondary_reports, filter_leading_zeros = FALSE
secondary_reports,
filter_leading_zeros = filter_leading_zeros,
zero_threshold = zero_threshold
)
## fill in missing data (required if fitting to prevalence)
complete_secondary <- create_complete_cases(secondary_reports)
Expand All @@ -178,6 +184,11 @@ estimate_secondary <- function(reports,
## fill any early data up
secondary_reports[, confirm := nafill(confirm, type = "nocb")]

# Ensure that reports and secondary_reports are aligned
reports <- merge.data.table(
reports, secondary_reports[, list(date)], by = "date"
)

if (burn_in >= nrow(reports)) {
stop("burn_in is greater or equal to the number of observations.
Some observations must be used in fitting")
Expand Down
17 changes: 17 additions & 0 deletions R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
model = NULL,
stan = stan_opts(),
CrIs = c(0.2, 0.5, 0.9),
filter_leading_zeros = FALSE,
zero_threshold = Inf,
weigh_delay_priors = FALSE,
verbose = TRUE,
...) {
Expand All @@ -126,6 +128,8 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
assert_class(truncation, "dist_spec")
assert_class(model, "stanfit", null.ok = TRUE)
assert_numeric(CrIs, lower = 0, upper = 1)
assert_logical(filter_leading_zeros)
assert_numeric(zero_threshold, lower = 0)
assert_logical(weigh_delay_priors)
assert_logical(verbose)

Expand Down Expand Up @@ -193,6 +197,19 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,

# combine into ordered matrix
dirty_obs <- purrr::map(obs, data.table::as.data.table)
dirty_obs <- purrr::map(dirty_obs,
create_clean_reported_cases,
horizon = 0,
filter_leading_zeros = filter_leading_zeros,
zero_threshold = zero_threshold,
add_breakpoints = FALSE
)
earliest_date <- max(
as.Date(
purrr::map_chr(dirty_obs, function(x) x[, as.character(min(date))])
)
)
dirty_obs <- purrr::map(dirty_obs, function(x) x[date >= earliest_date])
nrow_obs <- order(purrr::map_dbl(dirty_obs, nrow))
dirty_obs <- dirty_obs[nrow_obs]
obs <- purrr::map(dirty_obs, data.table::copy)
Expand Down
6 changes: 5 additions & 1 deletion man/create_clean_reported_cases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions man/estimate_secondary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions man/estimate_truncation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions tests/testthat/test-estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,30 @@ test_that("estimate_secondary works with weigh_delay_priors = TRUE", {
)
expect_s3_class(inc_weigh, "estimate_secondary")
})

test_that("estimate_secondary works with filter_leading_zeros set", {
modified_data <- inc_cases[1:10, secondary := 0]
out <- estimate_secondary(
modified_data,
obs = obs_opts(scale = list(mean = 0.2, sd = 0.2),
week_effect = FALSE),
filter_leading_zeros = TRUE,
verbose = FALSE
)
expect_s3_class(out, "estimate_secondary")
expect_named(out, c("predictions", "posterior", "data", "fit"))
expect_equal(out$predictions$primary, modified_data$primary[-(1:10)])
})

test_that("estimate_secondary works with zero_threshold set", {
modified_data <- inc_cases[sample(1:30, 10), primary := 0]
out <- estimate_secondary(
modified_data,
obs = obs_opts(scale = list(mean = 0.2, sd = 0.2),
week_effect = FALSE),
zero_threshold = 10,
verbose = FALSE
)
expect_s3_class(out, "estimate_secondary")
expect_named(out, c("predictions", "posterior", "data", "fit"))
})
48 changes: 48 additions & 0 deletions tests/testthat/test-estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,54 @@ test_that("estimate_truncation can return values from simulated data with the
expect_error(plot(est), NA)
})

test_that("estimate_truncation works with filter_leading_zeros set", {
skip_on_os("windows")
# Modify the first three rows of the first dataset to have zero cases
# and fit the model with filter_leading_zeros = TRUE. This should
# be the same as fitting the model to the original dataset because the
# earlier dataset is corrected to be the same as the final dataset.
modified_data <- data.table::copy(example_truncated)
modified_data[[1]][1:3, confirm := 0]
modified_data_fit <- estimate_truncation(
modified_data,
verbose = FALSE, chains = 2, iter = 1000, warmup = 250,
filter_leading_zeros = TRUE
)
# fit model to original dataset
original_data_fit <- estimate_truncation(
example_truncated,
verbose = FALSE, chains = 2, iter = 1000, warmup = 250
)
expect_named(
modified_data_fit,
c("dist", "obs", "last_obs", "cmf", "data", "fit")
)
# Compare the results of the two fits
expect_equal(
original_data_fit$dist$dist,
modified_data_fit$dist$dist
)
expect_equal(
original_data_fit$data$obs_dist,
modified_data_fit$data$obs_dist
)
})

test_that("estimate_truncation works with zero_threshold set", {
skip_on_os("windows")
# fit model to a modified version of example_data with zero leading cases
# but with filter_leading_zeros = TRUE
modified_data <- example_truncated
modified_data <- purrr::map(modified_data, function(x) x[sample(1:10, 6), confirm := 0])
out <- estimate_truncation(modified_data,
verbose = FALSE, chains = 2, iter = 1000, warmup = 250,
stan = stan_opts(backend = "cmdstanr"),
zero_threshold = 1
)
expect_named(out, c("dist", "obs", "last_obs", "cmf", "data", "fit"))
expect_s3_class(out$dist, "dist_spec")
})

test_that("deprecated arguments are recognised", {
options(warn = 2)
expect_error(estimate_truncation(example_truncated,
Expand Down

0 comments on commit 0743eeb

Please sign in to comment.