Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

break up create_clean_reported_cases() #884

Merged
merged 20 commits into from
Dec 10, 2024
Merged
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export(LogNormal)
export(NonParametric)
export(Normal)
export(R_to_growth)
export(add_breakpoints)
export(adjust_infection_to_report)
export(apply_tolerance)
export(backcalc_opts)
Expand Down Expand Up @@ -64,6 +65,7 @@ export(extract_inits)
export(extract_samples)
export(extract_stan_param)
export(fill_missing)
export(filter_leading_zeros)
export(fix_dist)
export(fix_parameters)
export(forecast_infections)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and reviewed by @seabbs.

## Package changes

- The internal functions `create_clean_reported_cases()` has been broken up into several functions, with relevant ones `filter_leading_zeros()`, `add_breakpoints()` and `apply_zero_threshold()` exposed to the user. By @sbfnk in #884 and reviewed by @seabbs and @jamesmbaazam.

## Documentation

- Brought the docs on `alpha_sd` up to date with the code change from prior PR #853. By @zsusswein in #862 and reviewed by @jamesmbaazam.
Expand Down
62 changes: 9 additions & 53 deletions R/create.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#' Create Clean Reported Cases
#' @description `r lifecycle::badge("stable")`
#' @description `r lifecycle::badge("deprecated")`
#' Filters leading zeros, completes dates, and applies an optional threshold at
#' which point 0 cases are replaced with a user supplied value (defaults to
#' `NA`).
Expand All @@ -12,16 +12,12 @@
#' number of cases based on the 7-day average. If the average is above this
#' threshold then the zero is replaced using `fill`.
#'
#' @param fill Numeric, defaults to NA. Value to use to replace NA values or
#' zeroes that are flagged because the 7-day average is above the
#' `zero_threshold`. If the default NA is used then dates with NA values or with
#' 7-day averages above the `zero_threshold` will be skipped in model fitting.
#' If this is set to 0 then the only effect is to replace NA values with 0.
#' @param fill Deprecated; zero dates with 7-day averages above the
#' `zero_threshold` will be skipped in model fitting.
#' @param add_breakpoints Logical, defaults to TRUE. Should a breakpoint column
#' be added to the data frame if it does not exist.
#'
#' @inheritParams estimate_infections
#' @importFrom data.table copy merge.data.table setorder setDT frollsum
#' @return A cleaned data frame of reported cases
#' @keywords internal
#' @examples
Expand All @@ -33,55 +29,15 @@ create_clean_reported_cases <- function(data, horizon = 0,
zero_threshold = Inf,
fill = NA_integer_,
add_breakpoints = TRUE) {
reported_cases <- data.table::setDT(data)
reported_cases_grid <- data.table::copy(reported_cases)[,
.(date = seq(min(date), max(date) + horizon, by = "days"))
]

reported_cases <- data.table::merge.data.table(
reported_cases, reported_cases_grid,
by = "date", all.y = TRUE
)

if (is.null(reported_cases$breakpoint) && add_breakpoints) {
reported_cases$breakpoint <- 0
reported_cases <- add_horizon(data, horizon = horizon)
if (add_breakpoints) {
reported_cases <- add_breakpoints(reported_cases)
}
if (!is.null(reported_cases$breakpoint)) {
reported_cases[is.na(breakpoint), breakpoint := 0]
}
reported_cases <- data.table::setorder(reported_cases, date)
## Filter out 0 reported cases from the beginning of the data
if (filter_leading_zeros) {
reported_cases <- reported_cases[order(date)][
date >= min(date[confirm[!is.na(confirm)] > 0])
]
}
# Calculate `average_7_day` which for rows with `confirm == 0`
# (the only instance where this is being used) equates to the 7-day
# right-aligned moving average at the previous data point.
reported_cases <-
reported_cases[
,
`:=`(average_7_day = (
data.table::frollsum(confirm, n = 8, na.rm = TRUE)
) / 7
)
]
# Check case counts preceding zero case counts and set to 7 day average if
# average over last 7 days is greater than a threshold
if (!is.infinite(zero_threshold)) {
reported_cases <- reported_cases[
confirm == 0 & average_7_day > zero_threshold,
confirm := NA_integer_
]
}
reported_cases[is.na(confirm), confirm := fill]
reported_cases[, "average_7_day" := NULL]
## set accumulate to FALSE in added rows
if ("accumulate" %in% colnames(reported_cases)) {
reported_cases[is.na(accumulate), accumulate := FALSE]
reported_cases <- filter_leading_zeros(reported_cases)
}
return(reported_cases)
reported_cases <- apply_zero_threshold(reported_cases, zero_threshold)
return(reported_cases[])
}

#' Create complete cases
Expand Down
27 changes: 27 additions & 0 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,20 @@ estimate_infections <- function(data,
"estimate_infections(data)"
)
}
if (!missing(filter_leading_zeros)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_infections(filter_leading_zeros)",
"filter_leading_zeros()"
)
}
if (!missing(zero_threshold)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_infections(zero_threshold)",
"apply_zero_threshold()"
)
}
# Validate inputs
check_reports_valid(data, model = "estimate_infections")
assert_class(generation_time, "generation_time_opts")
Expand Down Expand Up @@ -184,6 +198,19 @@ estimate_infections <- function(data,
)
# Fill missing dates
reported_cases <- default_fill_missing_obs(data, obs, "confirm")
# Check initial zeros to check for deprecated filter zero functionality
if (filter_leading_zeros &&
!is.na(reported_cases[date == min(date), "confirm"]) &&
reported_cases[date == min(date), "confirm"] == 0) {
cli_warn(c(
"!" = "Filtering initial zero observations in the data. This
functionality will be removed in future versions of EpiNow2. In order
to retain the default behaviour and filter initial zero observations
use the {.fn filter_leading_zeros()} function on the data before
calling {.fn estimate_infections()}."
))
}

# Create clean and complete cases
reported_cases <- create_clean_reported_cases(
reported_cases, horizon,
Expand Down
24 changes: 24 additions & 0 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ estimate_secondary <- function(data,
"estimate_secondary(data)"
)
}
if (!missing(filter_leading_zeros)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_secondary(filter_leading_zeros)",
"filter_leading_zeros()"
)
}
if (!missing(zero_threshold)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_secondary(zero_threshold)",
"apply_zero_threshold()"
)
}
# Validate the inputs
check_reports_valid(data, model = "estimate_secondary")
assert_class(secondary, "secondary_opts")
Expand Down Expand Up @@ -200,6 +214,16 @@ estimate_secondary <- function(data,

secondary_reports_dirty <-
reports[, list(date, confirm = secondary, accumulate)]
if (filter_leading_zeros &&
!is.na(secondary_reports_dirty[date == min(date), "confirm"]) &&
secondary_reports_dirty[date == min(date), "confirm"] == 0) {
cli_warn(c(
"!" = "Filtering initial zero observations in the data. This
functionality will be removed in future versions of EpiNow2. In order
to filter initial zero observations use the {.fn filter_leading_zeros}
function on the data before calling {.fn estimate_secondary}."
))
}
secondary_reports <- create_clean_reported_cases(
secondary_reports_dirty,
filter_leading_zeros = filter_leading_zeros,
Expand Down
157 changes: 156 additions & 1 deletion R/preprocessing.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
##' using a data set that has multiple columns of hwich one of them
##' corresponds to observations that are to be processed here.
##' @param by Character vector. Name(s) of any additional column(s) where
##' missing data should be processed separately for each value in the column.
##' data processing should be done separately for each value in the column.
##' This is useful when using data representing e.g. multiple geographies. If
##' NULL (default) no such grouping is done.
##' @return a data.table with an `accumulate` column that indicates whether
Expand Down Expand Up @@ -177,3 +177,158 @@ default_fill_missing_obs <- function(data, obs, obs_column) {
}
return(data)
}

##' Add missing values for future dates
##'
##' @param data Data frame with a `date` column. The other columns depend on the
##' model that the data are to be used, e.g. [estimate_infections()] or
##' [estimate_secondary()]. See the documentation there for the expected
##' format.
##' @param accumulate The number of days to accumulate when generating posterior
##' prediction, e.g. 7 for weekly accumulated forecasts.
##' @inheritParams fill_missing
##' @inheritParams estimate_infections
##' @importFrom data.table copy merge.data.table setDT
##' @return A data.table with missing values for future dates
##' @keywords internal
add_horizon <- function(data, horizon, accumulate = 1L,
obs_column = "confirm", by = NULL) {
assert_data_frame(data)
assert_character(obs_column)
assert_character(by, null.ok = TRUE)
assert_names(
colnames(data),
must.include = c("date", by, obs_column)
)
assert_integerish(horizon, lower = 0)
assert_integerish(accumulate, lower = 1)
assert_date(data$date, any.missing = FALSE)

reported_cases <- data.table::setDT(data)
if (horizon > 0) {
reported_cases_future <- data.table::copy(reported_cases)[,
.(date = seq(max(date) + 1, max(date) + horizon, by = "days")),
by = by
]
## if we accumulate add the column
if (accumulate > 1 || "accumulate" %in% colnames(data)) {
reported_cases_future[, accumulate := TRUE]
## set accumulation to FALSE where appropriate
if (horizon >= accumulate) {
reported_cases_future[
as.integer(date - min(date) - 1) %% accumulate == 0,
accumulate := FALSE
]
}
}
## fill any missing columns
reported_cases <- rbind(
reported_cases, reported_cases_future,
fill = TRUE
)
}
return(reported_cases[])
}

##' Add breakpoints to certain dates in a data set.
##'
##' @param dates A vector of dates to use as breakpoints.
##' @inheritParams estimate_infections
##' @return A data.table with `breakpoint` set to 1 on each of the specified
##' dates.
##' @export
##' @importFrom data.table setDT
##' @examples
##' reported_cases <- add_breakpoints(example_confirmed, as.Date("2020-03-26"))
add_breakpoints <- function(data, dates = as.Date(character(0))) {
assert_data_frame(data)
assert_names(colnames(data), must.include = "date")
assert_date(dates)
assert_date(data$date, any.missing = FALSE)
reported_cases <- data.table::setDT(data)
if (is.null(reported_cases$breakpoint)) {
reported_cases$breakpoint <- 0
}
missing_dates <- setdiff(dates, data$date)
if (length(missing_dates) > 0) {
cli_abort("Breakpoint date{?s} not found in data: {.var {missing_dates}}")
}
reported_cases[date %in% dates, breakpoint := 1]
reported_cases[is.na(breakpoint), breakpoint := 0]
return(reported_cases)
}

##' Filter leading zeros from a data set.
##'
##' @inheritParams estimate_infections
##' @inheritParams fill_missing
##' @return A data.table with leading zeros removed.
##' @export
##' @importFrom data.table setDT
##' @examples
##' cases <- data.frame(
##' date = as.Date("2020-01-01") + 0:10,
##' confirm = c(0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
##' )
##' filter_leading_zeros(cases)
filter_leading_zeros <- function(data, obs_column = "confirm", by = NULL) {
assert_data_frame(data)
assert_character(obs_column)
assert_character(by, null.ok = TRUE)
assert_names(
colnames(data),
must.include = c("date", by, obs_column)
)
reported_cases <- data.table::setDT(data)
reported_cases <- reported_cases[order(date)][
date >= min(date[get(obs_column)[!is.na(get(obs_column))] > 0])
]
return(reported_cases[])
}

##' Convert zero case counts to `NA` (missing) if the 7-day average is above a
##' threshold.
##'
##' This function aims to detect spurious zeroes by comparing the 7-day average
##' of the case counts to a threshold. If the 7-day average is above the
##' threshold, the zero case count is replaced with `NA`.
##'
##' @param threshold Numeric, defaults to `Inf`. Indicates if detected zero
##' cases are meaningful by using a threshold number of cases based on the
##' 7-day average. If the average is above this threshold at the time of a
##' zero observation count then the zero is replaced with a missing (`NA`)
##' count and thus ignored in the likelihood.
##'
##' @inheritParams estimate_infections
##' @inheritParams fill_missing
##' @importFrom data.table setDT frollsum
##' @return A data.table with the zero threshold applied.
apply_zero_threshold <- function(data, threshold = Inf,
obs_column = "confirm") {
assert_data_frame(data)
assert_numeric(threshold)
reported_cases <- data.table::setDT(data)

# Calculate `average_7_day` which for rows with `confirm == 0`
# (the only instance where this is being used) equates to the 7-day
# right-aligned moving average at the previous data point.
reported_cases <-
reported_cases[
,
`:=`(average_7_day = (
data.table::frollsum(get(obs_column), n = 8, na.rm = TRUE)
) / 7
)
]
# Check case counts preceding zero case counts and set to 7 day average if
# average over last 7 days is greater than a threshold
if (!is.infinite(threshold)) {
reported_cases <- reported_cases[
get(obs_column) == 0 & average_7_day > threshold,
paste(obs_column) := NA_integer_
]
}
reported_cases[is.na(get(obs_column)), paste(obs_column) := NA_integer_]
reported_cases[, "average_7_day" := NULL]
return(reported_cases[])
}
2 changes: 1 addition & 1 deletion R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,6 @@ globalVariables(
"..lowers", "..upper_CrI", "..uppers", "timing", "dataset", "last_confirm",
"report_date", "secondary", "id", "conv", "meanlog", "primary", "scaled",
"scaling", "sdlog", "lookup", "new_draw", ".draw", "p", "distribution",
"accumulate", "..present"
"accumulate", "..present", "reported_cases"
)
)
7 changes: 7 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ reference:
contents:
- contains("_opts")
- opts_list
- title: Preprocess data
desc: Functions used for prepropcessing data
contents:
- fill_missing
- add_breakpoints
- filter_leading_zeros
- apply_zero_threshold
- title: Summarise Across Regions
desc: Functions used for summarising across regions (designed for use with regional_epinow)
contents:
Expand Down
Loading
Loading