From a435a071c8b2d2454823f9a9a93360c05f933116 Mon Sep 17 00:00:00 2001
From: Sebastian Funk
Date: Sat, 20 Jan 2024 09:19:56 +0000
Subject: [PATCH] make NA option work with estimate_secondary

Add an internal helper, create_complete_cases(), that drops rows with a
missing "confirm" value and records the original row position of every
retained row in a "lookup" column. Use it in create_stan_data() and in
estimate_secondary() so that only the time indices of non-missing
observations (obs_time, lt) are passed to the stan model, and size
report_lp() by the number of observation times.

---
 DESCRIPTION                                |  2 +-
 NAMESPACE                                  |  1 +
 R/create.R                                 | 32 ++++++++++++++++------
 R/estimate_secondary.R                     | 18 +++++++++---
 inst/stan/estimate_secondary.stan          |  9 ++++--
 inst/stan/functions/observation_model.stan | 16 ++++++-----
 man/create_clean_reported_cases.Rd         |  2 +-
 man/create_complete_cases.Rd               | 25 +++++++++++++++++
 8 files changed, 81 insertions(+), 24 deletions(-)
 create mode 100644 man/create_complete_cases.Rd

diff --git a/DESCRIPTION b/DESCRIPTION
index a675ac158..c60d0e338 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -149,7 +149,7 @@ Encoding: UTF-8
 Language: en-GB
 LazyData: true
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.2.3
+RoxygenNote: 7.3.0
 NeedsCompilation: yes
 SystemRequirements: GNU make
   C++17
diff --git a/NAMESPACE b/NAMESPACE
index 3836bbb8c..07e858a65 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -128,6 +128,7 @@ importFrom(data.table,fwrite)
 importFrom(data.table,getDTthreads)
 importFrom(data.table,melt)
 importFrom(data.table,merge.data.table)
+importFrom(data.table,nafill)
 importFrom(data.table,rbindlist)
 importFrom(data.table,setDT)
 importFrom(data.table,setDTthreads)
diff --git a/R/create.R b/R/create.R
index b4767547b..d000a576c 100644
--- a/R/create.R
+++ b/R/create.R
@@ -26,7 +26,7 @@
 #' @export
 #' @examples
 #' create_clean_reported_cases(example_confirmed, 7)
-create_clean_reported_cases <- function(reported_cases, horizon,
+create_clean_reported_cases <- function(reported_cases, horizon = 0,
                                         filter_leading_zeros = TRUE,
                                         zero_threshold = Inf,
                                         fill = NA_integer_) {
@@ -75,6 +75,25 @@ create_clean_reported_cases <- function(reported_cases, horizon,
   return(reported_cases)
 }
 
+#' Create complete cases
+#' @description `r lifecycle::badge("stable")`
+#' Creates a complete data set without NA values and appropriate indices
+#'
+#' @param cases data frame with a column "confirm" that may contain NA values
+#' @return A data frame without NA values in "confirm", containing the
+#' columns of the input and an added integer column "lookup" that gives,
+#' for each retained row, its position in the input before rows with
+#' missing "confirm" were removed
+#' @author Sebastian Funk
+#' @importFrom data.table setDT
+#' @keywords internal
+create_complete_cases <- function(cases) {
+  cases <- setDT(cases)
+  cases[, lookup := seq_len(.N)]
+  cases <- cases[!is.na(cases$confirm)]
+  return(cases[])
+}
+
 #' Create Delay Shifted Cases
 #'
 #' @description `r lifecycle::badge("stable")`
@@ -448,16 +467,13 @@ create_stan_data <- function(reported_cases, seeding_time,
                              backcalc, shifted_cases) {
 
   cases <- reported_cases[(seeding_time + 1):(.N - horizon)]
-  cases[, lookup := seq_len(.N)]
-  complete_cases <- cases[!is.na(cases$confirm)]
-  cases_time <- complete_cases$lookup
-  complete_cases <- complete_cases$confirm
+  complete_cases <- create_complete_cases(cases)
   cases <- cases$confirm
 
   data <- list(
-    cases = complete_cases,
-    cases_time = cases_time,
-    lt = length(cases_time),
+    cases = complete_cases$confirm,
+    cases_time = complete_cases$lookup,
+    lt = nrow(complete_cases),
     shifted_cases = shifted_cases,
     t = length(reported_cases$date),
     horizon = horizon,
diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R
index fecbf8b4b..7936ed2ef 100644
--- a/R/estimate_secondary.R
+++ b/R/estimate_secondary.R
@@ -64,7 +64,7 @@
 #' @inheritParams calc_CrIs
 #' @importFrom rstan sampling
 #' @importFrom lubridate wday
-#' @importFrom data.table as.data.table merge.data.table
+#' @importFrom data.table as.data.table merge.data.table nafill
 #' @importFrom utils modifyList
 #' @importFrom checkmate assert_class assert_numeric assert_data_frame
 #' assert_logical
@@ -165,6 +165,15 @@ estimate_secondary <- function(reports,
   assert_logical(verbose)
 
   reports <- data.table::as.data.table(reports)
+  secondary_reports <- reports[, list(date, confirm = secondary)]
+  secondary_reports <- create_clean_reported_cases(secondary_reports)
+  ## record which observations are non-missing (used for obs_time and lt)
+  complete_secondary <- create_complete_cases(secondary_reports)
+
+  ## fill in missing data (required if fitting to prevalence): fill down
+  secondary_reports[, confirm := nafill(confirm, type = "locf")]
+  ## fill any early data up
+  secondary_reports[, confirm := nafill(confirm, type = "nocb")]
 
   if (burn_in >= nrow(reports)) {
     stop("burn_in is greater or equal to the number of observations.
@@ -173,9 +182,10 @@ estimate_secondary <- function(reports,
   # observation and control data
   data <- list(
     t = nrow(reports),
-    obs = reports$secondary,
-    obs_time = seq_along(reports$secondary),
     primary = reports$primary,
+    obs = secondary_reports$confirm,
+    obs_time = complete_secondary[lookup > burn_in]$lookup - burn_in,
+    lt = sum(complete_secondary$lookup > burn_in),
     burn_in = burn_in,
     seeding_time = 0
   )
@@ -396,7 +406,7 @@ plot.estimate_secondary <- function(x, primary = FALSE,
                                     from = NULL, to = NULL,
                                     new_obs = NULL,
                                     ...) {
-  predictions <- data.table::copy(x$predictions)
+  predictions <- data.table::copy(x$predictions)[!is.na(secondary)]
 
   if (!is.null(new_obs)) {
     new_obs <- data.table::as.data.table(new_obs)
diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan
index 6cd359a7c..77507a9ed 100644
--- a/inst/stan/estimate_secondary.stan
+++ b/inst/stan/estimate_secondary.stan
@@ -8,8 +8,9 @@ functions {
 
 data {
   int t;                             // time of observations
+  int lt;                            // number of non-missing observations after burn_in
   array[t] int obs;                  // observed secondary data
-  array[t] int obs_time;             // observed secondary data
+  array[lt] int obs_time;            // time indices (after burn_in) of non-missing observations
   vector[t] primary;                 // observed primary data
   int burn_in;                       // time period to not use for fitting
 #include data/secondary.stan
@@ -84,8 +85,10 @@ model {
   }
   // observed secondary reports from mean of secondary reports (update likelihood)
   if (likelihood) {
-    report_lp(obs[(burn_in + 1):t], obs_time, secondary[(burn_in + 1):t],
-              rep_phi, phi_mean, phi_sd, model_type, 1, accumulate);
+    report_lp(
+      obs[(burn_in + 1):t][obs_time], obs_time, secondary[(burn_in + 1):t],
+      rep_phi, phi_mean, phi_sd, model_type, 1, accumulate
+    );
   }
 }
 
diff --git a/inst/stan/functions/observation_model.stan b/inst/stan/functions/observation_model.stan
index ec5822d3b..6535c07f7 100644
--- a/inst/stan/functions/observation_model.stan
+++ b/inst/stan/functions/observation_model.stan
@@ -54,28 +54,30 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
 void report_lp(array[] int cases, array[] int cases_time, vector reports,
                array[] real rep_phi, real phi_mean, real phi_sd,
                int model_type, real weight, int accumulate) {
-  int n = num_elements(cases) - accumulate; // number of observations
+  int n = num_elements(cases_time) - accumulate; // number of observations
   vector[n] obs_reports; // reports at observation time
   array[n] int obs_cases; // observed cases at observation time
   if (accumulate) {
     int t = num_elements(reports);
+    int i = 0;
     int current_obs = 0;
     obs_reports = rep_vector(0, n);
-    for (i in 1:t) {
-      if (current_obs > 0) { // first observation gets ignored when acucmulating
+    while (i <= t && current_obs <= n) {
+      if (current_obs > 0) { // first observation gets ignored when accumulating
         obs_reports[current_obs] += reports[i];
       }
-      if (i == cases_time[current_obs]) {
+      if (i == cases_time[current_obs + 1]) {
         current_obs += 1;
       }
+      i += 1;
     }
-    obs_cases = cases[2:(n - 1)];
+    obs_cases = cases[2:(n + 1)];
   } else {
     obs_reports = reports[cases_time];
     obs_cases = cases;
   }
   if (model_type) {
-    real dispersion = 1 / pow(rep_phi[model_type], 2); 
+    real dispersion = 1 / pow(rep_phi[model_type], 2);
     rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,];
     if (weight == 1) {
       obs_cases ~ neg_binomial_2(obs_reports, dispersion);
@@ -119,7 +121,7 @@ array[] int report_rng(vector reports, array[] real rep_phi, int model_type) {
   if (model_type) {
     dispersion = 1 / pow(rep_phi[model_type], 2);
   }
-  
+
   for (s in 1:t) {
     if (reports[s] < 1e-8) {
       sampled_reports[s] = 0;
diff --git a/man/create_clean_reported_cases.Rd b/man/create_clean_reported_cases.Rd
index c53830c0c..5daf45877 100644
--- a/man/create_clean_reported_cases.Rd
+++ b/man/create_clean_reported_cases.Rd
@@ -6,7 +6,7 @@
 \usage{
 create_clean_reported_cases(
   reported_cases,
-  horizon,
+  horizon = 0,
   filter_leading_zeros = TRUE,
   zero_threshold = Inf,
   fill = NA_integer_
diff --git a/man/create_complete_cases.Rd b/man/create_complete_cases.Rd
new file mode 100644
index 000000000..79eb108c6
--- /dev/null
+++ b/man/create_complete_cases.Rd
@@ -0,0 +1,25 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/create.R
+\name{create_complete_cases}
+\alias{create_complete_cases}
+\title{Create complete cases}
+\usage{
+create_complete_cases(cases)
+}
+\arguments{
+\item{cases}{data frame with a column "confirm" that may contain NA values}
+}
+\value{
+A data frame without NA values in "confirm", containing the
+columns of the input and an added integer column "lookup" that gives,
+for each retained row, its position in the input before rows with
+missing "confirm" were removed
+}
+\description{
+\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}}
+Creates a complete data set without NA values and appropriate indices
+}
+\author{
+Sebastian Funk
+}
+\keyword{internal}