Skip to content

Commit

Permalink
break up create_clean_reported_cases()
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Dec 7, 2024
1 parent 9d733d2 commit b12cb14
Show file tree
Hide file tree
Showing 15 changed files with 387 additions and 61 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export(LogNormal)
export(NonParametric)
export(Normal)
export(R_to_growth)
export(add_breakpoints)
export(adjust_infection_to_report)
export(apply_tolerance)
export(backcalc_opts)
Expand Down Expand Up @@ -64,6 +65,7 @@ export(extract_inits)
export(extract_samples)
export(extract_stan_param)
export(fill_missing)
export(filter_leading_zeros)
export(fix_dist)
export(fix_parameters)
export(forecast_infections)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and reviewed by @seabbs.

## Package changes

- The internal functions `create_clean_reported_cases()` has been broken up into several functions, with relevant ones `filter_leading_zeros()`, `add_breakpoints()` and `apply_zero_threshold()` exposed to the user. By @sbfnk in # and reviewed by @.

## Documentation

- Brought the docs on `alpha_sd` up to date with the code change from prior PR #853. By @zsusswein in #862 and reviewed by @jamesmbaazam.
Expand Down
62 changes: 9 additions & 53 deletions R/create.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#' Create Clean Reported Cases
#' @description `r lifecycle::badge("stable")`
#' @description `r lifecycle::badge("deprecated")`
#' Filters leading zeros, completes dates, and applies an optional threshold at
#' which point 0 cases are replaced with a user supplied value (defaults to
#' `NA`).
Expand All @@ -12,16 +12,12 @@
#' number of cases based on the 7-day average. If the average is above this
#' threshold then the zero is replaced using `fill`.
#'
#' @param fill Numeric, defaults to NA. Value to use to replace NA values or
#' zeroes that are flagged because the 7-day average is above the
#' `zero_threshold`. If the default NA is used then dates with NA values or with
#' 7-day averages above the `zero_threshold` will be skipped in model fitting.
#' If this is set to 0 then the only effect is to replace NA values with 0.
#' @param fill Deprecated; zero dates with 7-day averages above the
#' `zero_threshold` will be skipped in model fitting.
#' @param add_breakpoints Logical, defaults to TRUE. Should a breakpoint column
#' be added to the data frame if it does not exist.
#'
#' @inheritParams estimate_infections
#' @importFrom data.table copy merge.data.table setorder setDT frollsum
#' @return A cleaned data frame of reported cases
#' @keywords internal
#' @examples
Expand All @@ -33,55 +29,15 @@ create_clean_reported_cases <- function(data, horizon = 0,
zero_threshold = Inf,
fill = NA_integer_,
add_breakpoints = TRUE) {
reported_cases <- data.table::setDT(data)
reported_cases_grid <- data.table::copy(reported_cases)[,
.(date = seq(min(date), max(date) + horizon, by = "days"))
]

reported_cases <- data.table::merge.data.table(
reported_cases, reported_cases_grid,
by = "date", all.y = TRUE
)

if (is.null(reported_cases$breakpoint) && add_breakpoints) {
reported_cases$breakpoint <- 0
reported_cases <- add_horizon(data, horizon = horizon)
if (add_breakpoints) {
reported_cases <- add_breakpoints(reported_cases)
}
if (!is.null(reported_cases$breakpoint)) {
reported_cases[is.na(breakpoint), breakpoint := 0]
}
reported_cases <- data.table::setorder(reported_cases, date)
## Filter out 0 reported cases from the beginning of the data
if (filter_leading_zeros) {
reported_cases <- reported_cases[order(date)][
date >= min(date[confirm[!is.na(confirm)] > 0])
]
}
# Calculate `average_7_day` which for rows with `confirm == 0`
# (the only instance where this is being used) equates to the 7-day
# right-aligned moving average at the previous data point.
reported_cases <-
reported_cases[
,
`:=`(average_7_day = (
data.table::frollsum(confirm, n = 8, na.rm = TRUE)
) / 7
)
]
# Check case counts preceding zero case counts and set to 7 day average if
# average over last 7 days is greater than a threshold
if (!is.infinite(zero_threshold)) {
reported_cases <- reported_cases[
confirm == 0 & average_7_day > zero_threshold,
confirm := NA_integer_
]
}
reported_cases[is.na(confirm), confirm := fill]
reported_cases[, "average_7_day" := NULL]
## set accumulate to FALSE in added rows
if ("accumulate" %in% colnames(reported_cases)) {
reported_cases[is.na(accumulate), accumulate := FALSE]
reported_cases <- filter_leading_zeros(reported_cases)
}
return(reported_cases)
reported_cases <- apply_zero_threshold(reported_cases, zero_threshold)
return(reported_cases[])
}

#' Create complete cases
Expand Down
25 changes: 25 additions & 0 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,20 @@ estimate_infections <- function(data,
"estimate_infections(data)"
)
}
if (!missing(filter_leading_zeros)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_infections(filter_leading_zeros)",
"filter_leading_zeros()"
)
}
if (!missing(zero_threshold)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_infections(zero_threshold)",
"apply_zero_threshold()"
)
}
# Validate inputs
check_reports_valid(data, model = "estimate_infections")
assert_class(generation_time, "generation_time_opts")
Expand Down Expand Up @@ -184,6 +198,17 @@ estimate_infections <- function(data,
)
# Fill missing dates
reported_cases <- default_fill_missing_obs(data, obs, "confirm")
# Check initial zeros to check for deprecated filter zero functionality
if (reported_cases[date == min(date), "confirm"] == 0) {
cli_warn(c(
"!" = "Filtering initial zero observations in the data. This
functionality will be removed in future versions of EpiNow2. In order
to retain the default behaviour and filter initial zero observations
use the {.fn filter_leading_zeros()} function on the data before
calling {.fn estimate_infections()}."
))
}

# Create clean and complete cases
reported_cases <- create_clean_reported_cases(
reported_cases, horizon,
Expand Down
23 changes: 23 additions & 0 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ estimate_secondary <- function(data,
"estimate_secondary(data)"
)
}
if (!missing(filter_leading_zeros)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_secondary(filter_leading_zeros)",
"filter_leading_zeros()"
)
}
if (!missing(zero_threshold)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_secondary(zero_threshold)",
"apply_zero_threshold()"
)
}
# Validate the inputs
check_reports_valid(data, model = "estimate_secondary")
assert_class(secondary, "secondary_opts")
Expand Down Expand Up @@ -200,6 +214,15 @@ estimate_secondary <- function(data,

secondary_reports_dirty <-
reports[, list(date, confirm = secondary, accumulate)]
if (secondary_reports_dirty[date == min(date), "confirm"] == 0) {
cli_warn(
"!" = "Filtering initial zero observations in the data. This
functionality will be removed in future versions of EpiNow2. In order
to retain the default behaviour and filter initial zero observations
use the {.fn filter_leading_zeros()} function on the data before
calling {.fn estimate_secondary()."
)
}
secondary_reports <- create_clean_reported_cases(
secondary_reports_dirty,
filter_leading_zeros = filter_leading_zeros,
Expand Down
153 changes: 152 additions & 1 deletion R/preprocessing.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
##' using a data set that has multiple columns of hwich one of them
##' corresponds to observations that are to be processed here.
##' @param by Character vector. Name(s) of any additional column(s) where
##' missing data should be processed separately for each value in the column.
##' data processing should be done separately for each value in the column.
##' This is useful when using data representing e.g. multiple geographies. If
##' NULL (default) no such grouping is done.
##' @return a data.table with an `accumulate` column that indicates whether
Expand Down Expand Up @@ -177,3 +177,154 @@ default_fill_missing_obs <- function(data, obs, obs_column) {
}
return(data)
}

##' Add missing values for future dates
##'
##' @param accumulate The number of days to accumulate when generating posterior
##' prediction, e.g. 7 for weekly accumulated forecasts.
##' @inheritParams add_horizon
##' @inheritParams estimate_infections
##' @importFrom data.table copy merge.data.table setDT
##' @return A data.table with missing values for future dates
##' @keywords internal
add_horizon <- function(data, horizon, accumulate = 1L,
obs_column = "confirm", by = NULL) {
assert_data_frame(data)
assert_character(obs_column)
assert_character(by, null.ok = TRUE)
assert_names(
colnames(data),
must.include = c("date", by, obs_column)
)
assert_integerish(horizon, lower = 0)
assert_integerish(accumulate, lower = 1)
assert_date(data$date, any.missing = FALSE)

reported_cases <- data.table::setDT(data)
if (horizon > 0) {
reported_cases_grid <- data.table::copy(reported_cases)[,
.(date = seq(max(date) + 1, max(date) + horizon, by = "days")),
by = by
]
## if we accumulate add the column
if (accumulate > 1 || "accumulate" %in% colnames(data)) {
reported_cases_grid[, accumulate := TRUE]
## set accumulation to FALSE where appropriate
if (horizon >= accumulate) {
reported_cases_grid[
as.integer(date - min(date) - 1) %% accumulate == 0,
accumulate := FALSE
]
}
}
## fill any missing columns
reported_cases_grid <- data.table::merge.data.table(
reported_cases, reported_cases_grid,
by = "date", all.y = TRUE
)
}
return(reported_cases[])
}

##' Add breakpoints to certain dates in a data set.
##'
##' @param dates A vector of dates to use as breakpoints.
##' @inheritParams estimate_infections
##' @return A data.table with `breakpoint` set to 1 on each of the specified
##' dates.
##' @export
##' @importFrom data.table setDT
##' @examples
##' reported_cases <- add_breakpoints(example_confirmed, as.Date("2020-03-26"))
add_breakpoints <- function(data, dates = as.Date(character(0))) {
assert_data_frame(data)
assert_names(colnames(data), must.include = "date")
assert_date(dates)
assert_date(data$date, any.missing = FALSE)
reported_cases <- data.table::setDT(data)
if (is.null(reported_cases$breakpoint)) {
reported_cases$breakpoint <- 0
}
missing_dates <- setdiff(dates, data$date)
if (length(missing_dates) > 0) {
cli_abort("Breakpoint date{?s} not found in data: {.var {missing_dates}}")
}
reported_cases[date %in% dates, breakpoint := 1]
return(reported_cases)
}

##' Filter leading zeros from a data set.
##'
##' @inheritParams estimate_infections
##' @inheritParams fill_missing
##' @return A data.table with leading zeros removed.
##' @export
##' @importFrom data.table setDT
##' @examples
##' cases <- data.table(
##' date = as.Date("2020-01-01") + 0:10,
##' confirm = c(0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
##' )
##' filter_leading_zeros(cases)
filter_leading_zeros <- function(data, obs_column = "confirm", by = NULL) {
assert_data_frame(data)
assert_character(obs_column)
assert_character(by, null.ok = TRUE)
assert_names(
colnames(data),
must.include = c("date", by, obs_column)
)
reported_cases <- data.table::setDT(data)
reported_cases <- reported_cases[order(date)][
date >= min(date[get(obs_column)[!is.na(get(obs_column))] > 0])
]
return(reported_cases[])
}

##' Converts zero case counts to NA (missing) if the 7-day average is above a
##' threshold.
##'
##' This function aims to detect spurious zeroes by comparing the 7-day average
##' of the case counts to a threshold. If the 7-day average is above the
##' threshold, the zero case count is replaced with NA.
##'
##' @param threshold Numeric, defaults to Inf. Indicates if detected zero cases
##' are meaningful by using a threshold number of cases based on the 7-day
##' average. If the average is above this threshold at the time of a zero
##' observation count then the zero is replaced with a missing (`NA`) count
##' and thus ignored in the likelihood.
##'
##' @inheritParams estimate_infections
##' @inheritParams fill_missing
##' @importFrom data.table setDT frollsum
##' @return A data.table with the zero threshold applied.
##' @author Sebastian Funk
apply_zero_threshold <- function(data, threshold = Inf,
obs_column = "confirm") {
assert_data_frame(data)
assert_numeric(threshold)
reported_cases <- data.table::setDT(data)

# Calculate `average_7_day` which for rows with `confirm == 0`
# (the only instance where this is being used) equates to the 7-day
# right-aligned moving average at the previous data point.
reported_cases <-
reported_cases[
,
`:=`(average_7_day = (
data.table::frollsum(get(obs_column), n = 8, na.rm = TRUE)
) / 7
)
]
# Check case counts preceding zero case counts and set to 7 day average if
# average over last 7 days is greater than a threshold
if (!is.infinite(threshold)) {
reported_cases <- reported_cases[
get(obs_column) == 0 & average_7_day > threshold,
paste(obs_column) := NA_integer_
]
}
reported_cases[is.na(get(obs_column)), paste(obs_column) := NA_integer_]
reported_cases[, "average_7_day" := NULL]
return(reported_cases[])
}
7 changes: 7 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ reference:
contents:
- contains("_opts")
- opts_list
- title: Preprocess data
desc: Functions used for prepropcessing data
contents:
- fill_missing
- add_breakpoints
- filter_leading_zeros
- apply_zero_threshold
- title: Summarise Across Regions
desc: Functions used for summarising across regions (designed for use with regional_epinow)
contents:
Expand Down
30 changes: 30 additions & 0 deletions man/add_breakpoints.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit b12cb14

Please sign in to comment.