Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve early calculation #903

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and reviewed by @seabbs.
- A bug was fixed where `plot.dist_spec()` wasn't throwing an informative error due to an incomplete check for the max of the specified delay. By @jamesmbaazam in #858 and reviewed by @.
- Updated the early dynamics calculation to use the full linear model if available. Also changd the prior for initial infections to be approximately Poisson. By @sbfnk in # and reviewed by

## Package changes

Expand Down
30 changes: 18 additions & 12 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -455,28 +455,34 @@ create_obs_model <- function(obs = obs_opts(), dates) {
#' @return A list containing `prior_infections` and `prior_growth`.
#' @keywords internal
estimate_early_dynamics <- function(cases, seeding_time) {
first_week <- data.table::data.table(
confirm = cases[seq_len(min(7, length(cases)))],
t = seq_len(min(7, length(cases)))
initial_period <- data.table::data.table(
confirm = cases[seq_len(min(7, seeding_time, length(cases)))],
t = seq_len(min(7, seeding_time, length(cases))) - 1
)[!is.na(confirm)]

# Calculate prior infections
prior_infections <- log(mean(first_week$confirm, na.rm = TRUE))
prior_infections <- ifelse(
is.na(prior_infections) || is.null(prior_infections),
0, prior_infections
)

prior_infections <- 0
# Calculate prior growth
if (seeding_time > 1 && nrow(first_week) > 1) {
if (seeding_time > 1 && nrow(initial_period) > 1) {
safe_lm <- purrr::safely(stats::lm)
prior_growth <- safe_lm(log(confirm) ~ t, data = first_week)[[1]]
prior_growth <- safe_lm(log(confirm) ~ t, data = initial_period)[[1]]
prior_infections <- ifelse(
is.null(prior_growth), 0, prior_growth$coefficients[1]
)
prior_growth <- ifelse(
is.null(prior_growth), 0, prior_growth$coefficients[2]
)
} else {
prior_growth <- 0
}

# Calculate prior infections
if (prior_infections == 0) {
prior_infections <- log(mean(initial_period$confirm, na.rm = TRUE))
if (is.na(prior_infections) || is.null(prior_infections)) {
prior_infections <- 0
}
}

return(list(
prior_infections = prior_infections,
prior_growth = prior_growth
Expand Down
6 changes: 5 additions & 1 deletion inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ transformed parameters {
);
}
profile("infections") {
real frac_obs = get_param(
frac_obs_id, params_fixed_lookup, params_variable_lookup, params_value,
params
);
infections = generate_infections(
R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop,
future_time
future_time, obs_scale, frac_obs
);
}
} else {
Expand Down
5 changes: 4 additions & 1 deletion inst/stan/functions/infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ real update_infectiousness(vector infections, vector gt_rev_pmf,
// generate infections by using Rt = Rt-1 * sum(reversed generation time pmf * infections)
vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
array[] real initial_infections, array[] real initial_growth,
int pop, int ht) {
int pop, int ht, int obs_scale, real frac_obs) {
// time indices and storage
int ot = num_elements(oR);
int nht = ot - ht;
Expand All @@ -32,6 +32,9 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
vector[ot] infectiousness;
// Initialise infections using daily growth
infections[1] = exp(initial_infections[1]);
if (obs_scale) {
infections[1] = infections[1] / frac_obs;
}
if (uot > 1) {
real growth = exp(initial_growth[1]);
for (s in 2:uot) {
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/functions/rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void rt_lp(array[] real initial_infections, array[] real initial_growth,
bp_effects ~ normal(0, bp_sd[1]);
}
// initial infections
initial_infections ~ normal(prior_infections, 0.2);
initial_infections ~ normal(prior_infections, sqrt(prior_infections));

if (seeding_time > 1) {
initial_growth ~ normal(prior_growth, 0.2);
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ generated quantities {

infections[i] = to_row_vector(generate_infections(
to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i],
initial_growth[i], pop, future_time
initial_growth[i], pop, future_time, obs_scale, frac_obs[i]
));

if (delay_id) {
Expand Down
14 changes: 7 additions & 7 deletions tests/testthat/test-estimate-early-dynamics.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ test_that("estimate_early_dynamics works", {
# Check values
expect_identical(
round(prior_estimates$prior_infections, 2),
4.53
3.21
)
expect_identical(
round(prior_estimates$prior_growth, 2),
Expand All @@ -21,29 +21,29 @@ test_that("estimate_early_dynamics works", {
test_that("estimate_early_dynamics handles NA values correctly", {
cases <- c(10, 20, NA, 40, 50, NA, 70)
prior_estimates <- estimate_early_dynamics(cases, 7)
expect_equal(
prior_estimates$prior_infections,
log(mean(c(10, 20, 40, 50, 70), na.rm = TRUE))
expect_identical(
round(prior_estimates$prior_infections, 2),
2.55
)
expect_true(!is.na(prior_estimates$prior_growth))
})

test_that("estimate_early_dynamics handles exponential growth", {
cases <- 2^(c(0:6)) # Exponential growth
prior_estimates <- estimate_early_dynamics(cases, 7)
expect_equal(prior_estimates$prior_infections, log(mean(cases[1:7])))
expect_equal(prior_estimates$prior_infections, log(2^0))
expect_true(prior_estimates$prior_growth > 0) # Growth should be positive
})

test_that("estimate_early_dynamics handles exponential decline", {
cases <- rev(2^(c(0:6))) # Exponential decline
prior_estimates <- estimate_early_dynamics(cases, 7)
expect_equal(prior_estimates$prior_infections, log(mean(cases[1:7])))
expect_equal(prior_estimates$prior_infections, log(2^6))
expect_true(prior_estimates$prior_growth < 0) # Growth should be negative
})

test_that("estimate_early_dynamics correctly handles seeding time less than 2", {
cases <- c(5, 10, 20) # Less than 7 days of data
prior_estimates <- estimate_early_dynamics(cases, 1)
expect_equal(prior_estimates$prior_growth, 0) # Growth should be 0 if seeding time is <= 1
})
})
16 changes: 8 additions & 8 deletions tests/testthat/test-stan-infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,35 @@ gt_rev_pmf <- get_delay_rev_pmf(
# test generate infections
test_that("generate_infections works as expected", {
expect_equal(
round(generate_infections(c(1, rep(1, 9)), 10, gt_rev_pmf, log(1000), 0, 0, 0), 0),
round(generate_infections(c(1, rep(1, 9)), 10, gt_rev_pmf, log(1000), 0, 0, 0, 0, 0), 0),
c(rep(1000, 10), 995, 996, rep(997, 8))
)
expect_equal(
round(generate_infections(c(1, rep(1.1, 9)), 10, gt_rev_pmf, log(20), 0.03, 0, 0), 0),
round(generate_infections(c(1, rep(1.1, 9)), 10, gt_rev_pmf, log(20), 0.03, 0, 0, 0, 0), 0),
c(20, 21, 21, 22, 23, 23, 24, 25, 25, 26, 24, 27, 28, 29, 30, 30, 31, 32, 33, 34)
)
expect_equal(
round(generate_infections(c(1, rep(1.1, 9)), 10, gt_rev_pmf, log(100), 0, 0, 0), 0),
round(generate_infections(c(1, rep(1.1, 9)), 10, gt_rev_pmf, log(100), 0, 0, 0, 0, 0), 0),
c(rep(100, 10), 99, 110, 112, 115, 119, 122, 126, 130, 134, 138)
)
expect_equal(
round(generate_infections(c(1, rep(1, 9)), 4, gt_rev_pmf, log(500), -0.02, 0, 0), 0),
round(generate_infections(c(1, rep(1, 9)), 4, gt_rev_pmf, log(500), -0.02, 0, 0, 0, 0), 0),
c(500, 490, 480, 471, 382, 403, 408, rep(409, 7))
)
expect_equal(
round(generate_infections(c(1, rep(1.1, 9)), 4, gt_rev_pmf, log(500), 0, 0, 0), 0),
round(generate_infections(c(1, rep(1.1, 9)), 4, gt_rev_pmf, log(500), 0, 0, 0, 0, 0), 0),
c(rep(500, 4), 394, 460, 475, 489, 505, 520, 536, 553, 570, 588)
)
expect_equal(
round(generate_infections(c(1, rep(1, 9)), 1, gt_rev_pmf, log(40), numeric(0), 0, 0), 0),
round(generate_infections(c(1, rep(1, 9)), 1, gt_rev_pmf, log(40), numeric(0), 0, 0, 0, 0), 0),
c(40, 8, 11, 12, 12, rep(13, 6))
)
expect_equal(
round(generate_infections(c(1, rep(1.1, 9)), 1, gt_rev_pmf, log(100), 0.01, 0, 0), 0),
round(generate_infections(c(1, rep(1.1, 9)), 1, gt_rev_pmf, log(100), 0.01, 0, 0, 0, 0), 0),
c(100, 20, 31, 35, 36, 37, 38, 39, 41, 42, 43)
)
expect_equal(
round(generate_infections(c(1, rep(1, 9)), 10, gt_rev_pmf, log(1000), 0, 100000, 4), 0),
round(generate_infections(c(1, rep(1, 9)), 10, gt_rev_pmf, log(1000), 0, 100000, 4, 0, 0), 0),
c(rep(1000, 10), 995, 996, rep(997, 4), 980, 965, 947, 926)
)
})
4 changes: 2 additions & 2 deletions vignettes/estimate_infections.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ These infections are then mapped to observations via discrete convolutions with
The model is initialised before the first observed data point by assuming constant exponential growth for the mean of modelled delays from infection to case report (called `seeding_time` $t_\mathrm{seed}$ in the model):

\begin{align}
I_0 &\sim \mathrm{LogNormal}(I_\mathrm{obs}, 0.2) \\
I_0 &\sim \mathrm{LogNormal}(I_\mathrm{obs}, \sqrt(I_\mathrm{obs})) \\
r &\sim \mathrm{Normal}(r_\mathrm{obs}, 0.2)\\
I_{0 < t \leq t_\mathrm{seed}} &= I_0 \exp \left(r t \right)
\end{align}

where $I_{t}$ is the number of latent infections on day $t$, $r$ is the estimate of the initial growth rate, and $I_\mathrm{obs}$ and $r_\mathrm{obs}$ are estimated from the first week of observed data: $I_\mathrm{obs}$ as the mean of reported cases in the first 7 days (or the mean of all cases if fewer than 7 days of data are given), divided by the prior mean reporting fraction if less than 1 (see [Delays and scaling]); and $r_\mathrm{obs}$ as the point estimate from fitting a linear regression model to the first 7 days of data (or all data if fewer than 7 days of data are given),
where $I_{t}$ is the number of latent infections on day $t$, $r$ is the estimate of the initial growth rate, and $I_\mathrm{obs}$ and $r_\mathrm{obs}$ are estimated from the first week of observed data, respectively, as as the point estimates of intercept and slope from fitting a linear regression model to the first 7 days of data (or all data if fewer than 7 days of data are given),

\begin{equation}
log(C_t) = a + r_\mathrm{obs} t + \epsilon_t
Expand Down
Loading