diff --git a/NEWS.md b/NEWS.md index 55e00eea7..c68a5bc48 100644 --- a/NEWS.md +++ b/NEWS.md @@ -16,6 +16,7 @@ - a bug was fixed that caused delay option functions to report an error if only the tolerance was specified. By @sbfnk in #716 and reviewed by @jamesmbaazam. - a bug was fixed where `forecast_secondary()` did not work with fixed delays. By @sbfnk in #717 and reviewed by @seabbs. - a bug was fixed that caused delay option functions to report an error if only the tolerance was specified. By @sbfnk. +- a bug was fixed that led to the truncation PMF being shortened from the wrong side when the truncation PMF was longer than the supplied data. By @seabbs in #736 and reviewed by @sbfnk and @jamesmbaazam. ## Documentation diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 4bd1b8fb2..0ba64cfd5 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -143,7 +143,7 @@ transformed parameters { ); } profile("truncate") { - obs_reports = truncate(reports[1:ot], trunc_rev_cmf, 0); + obs_reports = truncate_obs(reports[1:ot], trunc_rev_cmf, 0); } } else { obs_reports = reports[1:ot]; diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index 70fcc8d4a..8f5081fb0 100644 --- a/inst/stan/estimate_secondary.stan +++ b/inst/stan/estimate_secondary.stan @@ -78,7 +78,7 @@ transformed parameters { delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, 0, 1, 1 ); - secondary = truncate(secondary, trunc_rev_cmf, 0); + secondary = truncate_obs(secondary, trunc_rev_cmf, 0); } } diff --git a/inst/stan/estimate_truncation.stan b/inst/stan/estimate_truncation.stan index 3432c094f..aab08f481 100644 --- a/inst/stan/estimate_truncation.stan +++ b/inst/stan/estimate_truncation.stan @@ -47,12 +47,12 @@ transformed parameters{ vector[t] last_obs; // reconstruct latest data without truncation - last_obs = truncate(to_vector(obs[, obs_sets]), trunc_rev_cmf, 1); + last_obs = 
truncate_obs(to_vector(obs[, obs_sets]), trunc_rev_cmf, 1); // apply truncation to latest dataset to map back to previous data sets and // add noise term for (i in 1:(obs_sets - 1)) { trunc_obs[1:(end_t[i] - start_t[i] + 1), i] = - truncate(last_obs[start_t[i]:end_t[i]], trunc_rev_cmf, 0) + sigma; + truncate_obs(last_obs[start_t[i]:end_t[i]], trunc_rev_cmf, 0) + sigma; } } } @@ -80,7 +80,7 @@ generated quantities { matrix[delay_type_max[trunc_id] + 1, obs_sets - 1] gen_obs; // reconstruct all truncated datasets using posterior of the truncation distribution for (i in 1:obs_sets) { - recon_obs[1:(end_t[i] - start_t[i] + 1), i] = truncate( + recon_obs[1:(end_t[i] - start_t[i] + 1), i] = truncate_obs( to_vector(obs[start_t[i]:end_t[i], i]), trunc_rev_cmf, 1 ); } diff --git a/inst/stan/functions/observation_model.stan b/inst/stan/functions/observation_model.stan index ac54496c7..f55210281 100644 --- a/inst/stan/functions/observation_model.stan +++ b/inst/stan/functions/observation_model.stan @@ -1,4 +1,14 @@ -// apply day of the week effect +/** + * Apply day of the week effect to reports + * + * This function applies a day of the week effect to a vector of reports. + * + * @param reports Vector of reports to be adjusted. + * @param day_of_week Array of integers representing the day of the week for each report. + * @param effect Vector of day of week effects. + * + * @return A vector of reports adjusted for day of the week effects. + */ vector day_of_week_effect(vector reports, array[] int day_of_week, vector effect) { int t = num_elements(reports); int wl = num_elements(effect); @@ -11,30 +21,65 @@ vector day_of_week_effect(vector reports, array[] int day_of_week, vector effect } return(scaled_reports); } -// Scale observations by fraction reported and update log density of -// fraction reported + +/** + * Scale observations by fraction reported + * + * This function scales a vector of reports by a fraction observed. 
+ * + * @param reports Vector of reports to be scaled. + * @param frac_obs Real value representing the fraction observed. + * + * @return A vector of scaled reports. + */ vector scale_obs(vector reports, real frac_obs) { int t = num_elements(reports); vector[t] scaled_reports; scaled_reports = reports * frac_obs; return(scaled_reports); } -// Truncate observed data by some truncation distribution -vector truncate(vector reports, vector trunc_rev_cmf, int reconstruct) { + +/** + * Truncate observed data by a truncation distribution + * + * This function multiplies (truncate) or divides (reconstruct) the last entries of a vector of reports by the trailing entries of the reverse CMF of a truncation distribution; earlier entries are returned unchanged. + * + * @param reports Vector of reports to be truncated or reconstructed. + * @param trunc_rev_cmf Vector representing the reverse cumulative mass function of the truncation distribution. + * @param reconstruct Integer flag indicating whether to reconstruct (1) or truncate (0) the data. + * + * @return A vector of truncated (or reconstructed) reports with the same length as `reports`. + */ +vector truncate_obs(vector reports, vector trunc_rev_cmf, int reconstruct) { int t = num_elements(reports); + int trunc_max = num_elements(trunc_rev_cmf); vector[t] trunc_reports = reports; // Calculate cmf of truncation delay - int trunc_max = min(t, num_elements(trunc_rev_cmf)); - int first_t = t - trunc_max + 1; + int joint_max = min(t, trunc_max); + int first_t = t - joint_max + 1; + int first_trunc = trunc_max - joint_max + 1; + // Apply cdf of truncation delay to truncation max last entries in reports if (reconstruct) { - trunc_reports[first_t:t] ./= trunc_rev_cmf[1:trunc_max]; + trunc_reports[first_t:t] ./= trunc_rev_cmf[first_trunc:trunc_max]; } else { - trunc_reports[first_t:t] .*= trunc_rev_cmf[1:trunc_max]; + trunc_reports[first_t:t] .*= trunc_rev_cmf[first_trunc:trunc_max]; } return(trunc_reports); } -// Truncation distribution priors + +/** + * Update log density for truncation distribution priors + * + * This function updates the log density for truncation distribution priors.
+ * + * @param truncation_mean Array of real values for truncation mean. + * @param truncation_sd Array of real values for truncation standard deviation. + * @param trunc_mean_mean Array of real values for mean of truncation mean prior. + * @param trunc_mean_sd Array of real values for standard deviation of truncation mean prior. + * @param trunc_sd_mean Array of real values for mean of truncation standard deviation prior. + * @param trunc_sd_sd Array of real values for standard deviation of truncation standard deviation prior. + */ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd, array[] real trunc_mean_mean, array[] real trunc_mean_sd, array[] real trunc_sd_mean, array[] real trunc_sd_sd) { @@ -50,7 +95,22 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd, } } } -// update log density for reported cases + +/** + * Update log density for reported cases + * + * This function updates the log density for reported cases based on the specified model type. + * + * @param cases Array of integer observed cases. + * @param cases_time Array of integer time indices for observed cases. + * @param reports Vector of expected reports. + * @param rep_phi Array of real values for reporting overdispersion. + * @param phi_mean Real value for mean of reporting overdispersion prior. + * @param phi_sd Real value for standard deviation of reporting overdispersion prior. + * @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial). + * @param weight Real value for weighting the log density contribution. + * @param accumulate Integer flag indicating whether to accumulate reports (1) or not (0). 
+ */ void report_lp(array[] int cases, array[] int cases_time, vector reports, array[] real rep_phi, real phi_mean, real phi_sd, int model_type, real weight, int accumulate) { @@ -96,7 +156,20 @@ void report_lp(array[] int cases, array[] int cases_time, vector reports, } } } -// update log likelihood (as above but not vectorised and returning log likelihood) + +/** + * Calculate log likelihood for reported cases + * + * This function calculates the log likelihood for reported cases based on the specified model type. + * + * @param cases Array of integer observed cases. + * @param reports Vector of expected reports. + * @param rep_phi Array of real values for reporting overdispersion. + * @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial). + * @param weight Real value for weighting the log likelihood contribution. + * + * @return A vector of log likelihoods for each time point. + */ vector report_log_lik(array[] int cases, vector reports, array[] real rep_phi, int model_type, real weight) { int t = num_elements(reports); @@ -115,7 +188,18 @@ vector report_log_lik(array[] int cases, vector reports, } return(log_lik); } -// sample reported cases from the observation model + +/** + * Generate random samples of reported cases + * + * This function generates random samples of reported cases based on the specified model type. + * + * @param reports Vector of expected reports. + * @param rep_phi Array of real values for reporting overdispersion. + * @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial). + * + * @return An array of integer sampled reports. 
+ */ array[] int report_rng(vector reports, array[] real rep_phi, int model_type) { int t = num_elements(reports); array[t] int sampled_reports; diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index 1f4f65cb9..245f80c49 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -82,7 +82,7 @@ generated quantities { delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, 0, 1, 1 ); - reports[i] = to_row_vector(truncate( + reports[i] = to_row_vector(truncate_obs( to_vector(reports[i]), trunc_rev_cmf, 0) ); } diff --git a/inst/stan/simulate_secondary.stan b/inst/stan/simulate_secondary.stan index d59f1d484..8bd4386f1 100644 --- a/inst/stan/simulate_secondary.stan +++ b/inst/stan/simulate_secondary.stan @@ -73,7 +73,7 @@ generated quantities { delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, 0, 1, 1 ); - secondary = truncate( + secondary = truncate_obs( secondary, trunc_rev_cmf, 0 ); } diff --git a/tests/testthat/test-stan-truncate.R b/tests/testthat/test-stan-truncate.R new file mode 100644 index 000000000..c78f9e8bb --- /dev/null +++ b/tests/testthat/test-stan-truncate.R @@ -0,0 +1,30 @@ +skip_on_cran() +skip_on_os("windows") + +test_that("truncate_obs() can perform truncation as expected", { + reports <- c(10, 20, 30, 40, 50) + trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2) + expected <- c(reports[1], reports[2:5] * trunc_rev_cmf) + expect_equal(truncate_obs(reports, trunc_rev_cmf, FALSE), expected) +}) + +test_that("truncate_obs() can perform reconstruction as expected", { + reports <- c(10, 20, 15, 8, 10) + trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2) + expected <- c(reports[1], reports[2:5] / trunc_rev_cmf) + expect_equal(truncate_obs(reports, trunc_rev_cmf, TRUE), expected) +}) + +test_that("truncate_obs() can handle longer trunc_rev_cmf than reports", { + reports <- c(10, 20, 30) + trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2, 0.1) + expected <- reports * trunc_rev_cmf[3:5] + 
expect_equal(truncate_obs(reports, trunc_rev_cmf, FALSE), expected) +}) + +test_that("truncate_obs() can handle reconstruction with longer trunc_rev_cmf than reports", { + reports <- c(10, 16, 15) + trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2, 0.1) + expected <- reports / trunc_rev_cmf[3:5] + expect_equal(truncate_obs(reports, trunc_rev_cmf, TRUE), expected) +})