Skip to content

Commit

Permalink
make NA option work with estimate_secondary
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Jan 22, 2024
1 parent 4d186fc commit a435a07
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 24 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ Encoding: UTF-8
Language: en-GB
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.0
NeedsCompilation: yes
SystemRequirements: GNU make
C++17
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ importFrom(data.table,fwrite)
importFrom(data.table,getDTthreads)
importFrom(data.table,melt)
importFrom(data.table,merge.data.table)
importFrom(data.table,nafill)
importFrom(data.table,rbindlist)
importFrom(data.table,setDT)
importFrom(data.table,setDTthreads)
Expand Down
32 changes: 24 additions & 8 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#' @export
#' @examples
#' create_clean_reported_cases(example_confirmed, 7)
create_clean_reported_cases <- function(reported_cases, horizon,
create_clean_reported_cases <- function(reported_cases, horizon = 0,
filter_leading_zeros = TRUE,
zero_threshold = Inf,
fill = NA_integer_) {
Expand Down Expand Up @@ -75,6 +75,25 @@ create_clean_reported_cases <- function(reported_cases, horizon,
return(reported_cases)
}

#' Create complete cases
#' @description `r lifecycle::badge("stable")`
#' Creates a complete data set without NA values and appropriate indices
#'
#' @param cases A data frame with a column "confirm" that may contain NA
#' values.
#'
#' @return A `data.table` containing only the rows of `cases` in which
#' `confirm` is not NA, with an additional integer column `lookup` giving the
#' position of each retained row in the original `cases`. The input is not
#' modified.
#' @author Sebastian Funk
#' @importFrom data.table as.data.table
#' @keywords internal
create_complete_cases <- function(cases) {
  cases <- data.table::as.data.table(cases)
  # positions (in the original data) of the rows that have an observation
  observed <- which(!is.na(cases[["confirm"]]))
  # subsetting by `i` allocates a new table, so the `:=` below cannot add the
  # `lookup` column to the caller's object by reference
  complete <- cases[observed]
  complete[, lookup := observed]
  # trailing `[]` so the result prints after the `:=` assignment
  complete[]
}

#' Create Delay Shifted Cases
#'
#' @description `r lifecycle::badge("stable")`
Expand Down Expand Up @@ -448,16 +467,13 @@ create_stan_data <- function(reported_cases, seeding_time,
backcalc, shifted_cases) {

cases <- reported_cases[(seeding_time + 1):(.N - horizon)]
cases[, lookup := seq_len(.N)]
complete_cases <- cases[!is.na(cases$confirm)]
cases_time <- complete_cases$lookup
complete_cases <- complete_cases$confirm
complete_cases <- create_complete_cases(cases)
cases <- cases$confirm

data <- list(
cases = complete_cases,
cases_time = cases_time,
lt = length(cases_time),
cases = complete_cases$confirm,
cases_time = complete_cases$lookup,
lt = nrow(complete_cases),
shifted_cases = shifted_cases,
t = length(reported_cases$date),
horizon = horizon,
Expand Down
18 changes: 14 additions & 4 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
#' @inheritParams calc_CrIs
#' @importFrom rstan sampling
#' @importFrom lubridate wday
#' @importFrom data.table as.data.table merge.data.table
#' @importFrom data.table as.data.table merge.data.table nafill
#' @importFrom utils modifyList
#' @importFrom checkmate assert_class assert_numeric assert_data_frame
#' assert_logical
Expand Down Expand Up @@ -165,6 +165,15 @@ estimate_secondary <- function(reports,
assert_logical(verbose)

reports <- data.table::as.data.table(reports)
secondary_reports <- reports[, list(date, confirm = secondary)]
secondary_reports <- create_clean_reported_cases(secondary_reports)
## fill in missing data (required if fitting to prevalence)
complete_secondary <- create_complete_cases(secondary_reports)

## fill down
secondary_reports[, confirm := nafill(confirm, type = "locf")]
## fill any early data up
secondary_reports[, confirm := nafill(confirm, type = "nocb")]

if (burn_in >= nrow(reports)) {
stop("burn_in is greater or equal to the number of observations.
Expand All @@ -173,9 +182,10 @@ estimate_secondary <- function(reports,
# observation and control data
data <- list(
t = nrow(reports),
obs = reports$secondary,
obs_time = seq_along(reports$secondary),
primary = reports$primary,
obs = secondary_reports$confirm,
obs_time = complete_secondary[lookup > burn_in]$lookup - burn_in,
lt = sum(complete_secondary$lookup > burn_in),
burn_in = burn_in,
seeding_time = 0
)
Expand Down Expand Up @@ -396,7 +406,7 @@ plot.estimate_secondary <- function(x, primary = FALSE,
from = NULL, to = NULL,
new_obs = NULL,
...) {
predictions <- data.table::copy(x$predictions)
predictions <- data.table::copy(x$predictions)[!is.na(secondary)]

if (!is.null(new_obs)) {
new_obs <- data.table::as.data.table(new_obs)
Expand Down
9 changes: 6 additions & 3 deletions inst/stan/estimate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ functions {

data {
int t; // time of observations
int lt; // number of non-missing observations after burn-in
array[t] int<lower = 0> obs; // observed secondary data
array[t] int obs_time; // observed secondary data
array[lt] int obs_time; // time indices of non-missing observations (relative to end of burn-in)
vector[t] primary; // observed primary data
int burn_in; // time period to not use for fitting
#include data/secondary.stan
Expand Down Expand Up @@ -84,8 +85,10 @@ model {
}
// observed secondary reports from mean of secondary reports (update likelihood)
if (likelihood) {
report_lp(obs[(burn_in + 1):t], obs_time, secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1, accumulate);
report_lp(
obs[(burn_in + 1):t][obs_time], obs_time, secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1, accumulate
);
}
}

Expand Down
16 changes: 9 additions & 7 deletions inst/stan/functions/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -54,28 +54,30 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
void report_lp(array[] int cases, array[] int cases_time, vector reports,
array[] real rep_phi, real phi_mean, real phi_sd,
int model_type, real weight, int accumulate) {
int n = num_elements(cases) - accumulate; // number of observations
int n = num_elements(cases_time) - accumulate; // number of observations
vector[n] obs_reports; // reports at observation time
array[n] int obs_cases; // observed cases at observation time
if (accumulate) {
int t = num_elements(reports);
int i = 0;
int current_obs = 0;
obs_reports = rep_vector(0, n);
for (i in 1:t) {
if (current_obs > 0) { // first observation gets ignored when acucmulating
while (i <= t && current_obs <= n) {
if (current_obs > 0) { // first observation gets ignored when accumulating
obs_reports[current_obs] += reports[i];
}
if (i == cases_time[current_obs]) {
if (i == cases_time[current_obs + 1]) {
current_obs += 1;
}
i += 1;
}
obs_cases = cases[2:(n - 1)];
obs_cases = cases[2:(n + 1)];
} else {
obs_reports = reports[cases_time];
obs_cases = cases;
}
if (model_type) {
real dispersion = 1 / pow(rep_phi[model_type], 2);
real dispersion = 1 / pow(rep_phi[model_type], 2);
rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,];
if (weight == 1) {
obs_cases ~ neg_binomial_2(obs_reports, dispersion);
Expand Down Expand Up @@ -119,7 +121,7 @@ array[] int report_rng(vector reports, array[] real rep_phi, int model_type) {
if (model_type) {
dispersion = 1 / pow(rep_phi[model_type], 2);
}

for (s in 1:t) {
if (reports[s] < 1e-8) {
sampled_reports[s] = 0;
Expand Down
2 changes: 1 addition & 1 deletion man/create_clean_reported_cases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/create_complete_cases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a435a07

Please sign in to comment.